Introduction to Approximate Inference

Advanced Statistical Inference

Simone Rossi

EURECOM

Bayesian inference

\[ \require{physics} \definecolor{input}{rgb}{0.42, 0.55, 0.74} \definecolor{params}{rgb}{0.51,0.70,0.40} \definecolor{output}{rgb}{0.843, 0.608, 0} \definecolor{vparams}{rgb}{0.58, 0, 0.83} \definecolor{noise}{rgb}{0.0, 0.48, 0.65} \definecolor{latent}{rgb}{0.8, 0.0, 0.8} \definecolor{function}{rgb}{0.75, 0.75, 0.12} \]

Bayesian inference allows us to “transform” a prior distribution over the parameters into a posterior distribution after observing the data

Prior distribution \(p({\textcolor{params}{\boldsymbol{w}}})\)

Posterior distribution \(p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}})\)

Bayes’ rule:

\[ p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}}) = \frac{p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{w}}}, {\textcolor{input}{\boldsymbol{X}}}) p({\textcolor{params}{\boldsymbol{w}}})}{p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{input}{\boldsymbol{X}}})} \]

  • Prior: \(p({\textcolor{params}{\boldsymbol{w}}})\)
    • Encodes our beliefs about the parameters before observing the data
  • Likelihood: \(p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{w}}}, {\textcolor{input}{\boldsymbol{X}}})\)
    • Encodes our model of the data
  • Posterior: \(p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}})\)
    • Encodes our beliefs about the parameters after observing the data (i.e. conditioned on the data)
  • Evidence: \(p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{input}{\boldsymbol{X}}})\)
    • Normalizing constant, ensures that \(\int p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}}) \dd {\textcolor{params}{\boldsymbol{w}}}= 1\)

Bayesian linear regression (review)

We model each observation as a noisy realization of a linear combination of the features. As before, we assume a Gaussian likelihood

\[ p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{w}}}, {\textcolor{input}{\boldsymbol{X}}}) = {\mathcal{N}}({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{input}{\boldsymbol{X}}}{\textcolor{params}{\boldsymbol{w}}}, \sigma^2 {\boldsymbol{I}}) \]

For the prior, we use a Gaussian distribution over the model parameters

\[ p({\textcolor{params}{\boldsymbol{w}}}) = {\mathcal{N}}({\textcolor{params}{\boldsymbol{w}}}\mid {\boldsymbol{0}}, {\boldsymbol{S}}) \]

In practice, we often use a diagonal covariance matrix \({\boldsymbol{S}}= \sigma_{\textcolor{params}{\boldsymbol{w}}}^2 {\boldsymbol{I}}\)
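
To make the model concrete, here is a minimal NumPy sketch that samples weights from this prior and observations from the Gaussian likelihood; the dimensions and the values of \(\sigma\) and \(\sigma_{\textcolor{params}{\boldsymbol{w}}}\) are illustrative assumptions.

```python
# Sketch of the generative model: w ~ N(0, sigma_w^2 I), y ~ N(Xw, sigma^2 I)
# (toy sizes and hyperparameters chosen only for illustration)
import numpy as np

rng = np.random.default_rng(0)
N, D = 50, 3          # number of observations and number of features (assumed)
sigma = 0.5           # observation noise standard deviation (assumed)
sigma_w = 1.0         # prior standard deviation over the weights (assumed)

X = rng.normal(size=(N, D))                  # design matrix
w = rng.normal(0.0, sigma_w, size=D)         # sample from the prior p(w)
y = X @ w + rng.normal(0.0, sigma, size=N)   # sample from the likelihood p(y | w, X)
```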

When can we compute the posterior?

Definition

A prior is conjugate to a likelihood if the posterior is in the same family as the prior.

Conjugate prior–likelihood pairs exist only for a limited family of models, but when they do, they are very useful.

Examples:

  • Gaussian likelihood and Gaussian prior \(\Rightarrow\) Gaussian posterior
  • Binomial likelihood and Beta prior \(\Rightarrow\) Beta posterior

A full table is available on Wikipedia
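
As a quick illustration of the Beta–Binomial pair above, the following sketch performs the conjugate update in closed form; the prior hyperparameters and the observed counts are illustrative assumptions.

```python
# Conjugate update for a Binomial likelihood with a Beta prior:
# Beta(a, b) prior + k successes in n trials  =>  Beta(a + k, b + n - k) posterior
from scipy import stats

a, b = 2.0, 2.0   # Beta prior hyperparameters (assumed)
n, k = 20, 14     # n Bernoulli trials, k observed successes (assumed)

posterior = stats.beta(a + k, b + n - k)   # posterior in the same family as the prior
print(posterior.mean())                    # posterior mean of the success probability
print(posterior.interval(0.95))            # 95% credible interval
```

No integral over the evidence is needed: conjugacy tells us the family of the posterior, and the update reduces to adding counts to the prior hyperparameters.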

Why is this useful?

\[ p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}}) = \frac{p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{w}}}, {\textcolor{input}{\boldsymbol{X}}}) p({\textcolor{params}{\boldsymbol{w}}})}{p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{input}{\boldsymbol{X}}})} \]

  • Generally, the posterior is intractable to compute
    • We don’t know the form of the posterior \(p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}})\)
    • The evidence \(p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{input}{\boldsymbol{X}}})\) is an integral
      • without a closed-form solution
      • high-dimensional, hence intractable to compute numerically
  • Analytical solution thanks to conjugacy:
    • We know the form of the posterior
    • We know the form of the normalization constant
    • We don’t need to compute the evidence, just some algebra to get the posterior

From the likelihood and prior, we can write the unnormalized posterior as

\[ \begin{aligned} p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{w}}}, {\textcolor{input}{\boldsymbol{X}}})p({\textcolor{params}{\boldsymbol{w}}}) &\propto \exp\left(-\frac 1 {2\sigma^2} \norm{{\textcolor{output}{\boldsymbol{y}}}- {\textcolor{input}{\boldsymbol{X}}}{\textcolor{params}{\boldsymbol{w}}}}_2^2 - \frac 1 2 {\textcolor{params}{\boldsymbol{w}}}^\top {\boldsymbol{S}}^{-1} {\textcolor{params}{\boldsymbol{w}}}\right) \\ &\propto \exp\left(-\frac 1 2 \left( {\textcolor{params}{\boldsymbol{w}}}^\top \left(\frac 1 {\sigma^2} {\textcolor{input}{\boldsymbol{X}}}^\top {\textcolor{input}{\boldsymbol{X}}}+ {\boldsymbol{S}}^{-1}\right) {\textcolor{params}{\boldsymbol{w}}}- \frac 2 {\sigma^2}{\textcolor{params}{\boldsymbol{w}}}^\top {\textcolor{input}{\boldsymbol{X}}}^\top {\textcolor{output}{\boldsymbol{y}}}\right)\right) \end{aligned} \]

From conjugacy, we know that the posterior is Gaussian

\[ p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}}) \propto \exp\left(-\frac 1 2 \left( {\textcolor{params}{\boldsymbol{w}}}^\top {\boldsymbol{\Sigma}}^{-1} {\textcolor{params}{\boldsymbol{w}}}- 2 {\textcolor{params}{\boldsymbol{w}}}^\top {\boldsymbol{\Sigma}}^{-1} {\boldsymbol{\mu}}\right)\right) \]

We can identify the posterior mean and covariance

Posterior covariance \[ {\boldsymbol{\Sigma}}= \left(\frac 1 {\sigma^2} {\textcolor{input}{\boldsymbol{X}}}^\top {\textcolor{input}{\boldsymbol{X}}}+ {\boldsymbol{S}}^{-1}\right)^{-1} \]

Posterior mean \[ {\boldsymbol{\mu}}= \frac 1 {\sigma^2} {\boldsymbol{\Sigma}}{\textcolor{input}{\boldsymbol{X}}}^\top {\textcolor{output}{\boldsymbol{y}}} \]
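
Putting these two formulas together, a minimal NumPy sketch of the closed-form posterior on synthetic data could look as follows; the data size, noise level, and prior scale are illustrative assumptions.

```python
# Closed-form posterior for Bayesian linear regression:
# Sigma = (X^T X / sigma^2 + S^{-1})^{-1},  mu = Sigma X^T y / sigma^2
import numpy as np

rng = np.random.default_rng(0)
N, D = 50, 3
sigma, sigma_w = 0.5, 1.0                  # noise std and prior std (assumed)

X = rng.normal(size=(N, D))                # design matrix
w_true = rng.normal(0.0, sigma_w, size=D)  # weights used to generate the data
y = X @ w_true + rng.normal(0.0, sigma, size=N)

S_inv = np.eye(D) / sigma_w**2                     # prior precision S^{-1}
Sigma = np.linalg.inv(X.T @ X / sigma**2 + S_inv)  # posterior covariance
mu = Sigma @ X.T @ y / sigma**2                    # posterior mean

print(mu)       # should be close to w_true
print(w_true)
```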

Exact inference is rare

  • Exact inference is possible when the posterior distribution can be computed analytically

Example: Linear regression with Gaussian likelihood and Gaussian prior

  • This is the case for simple models with conjugate priors, …

  • … but most of the time, the posterior distribution is intractable

Examples: Logistic regression (binary classification), neural networks, …

Introduction to Approximate Inference

  • In this lecture, we will introduce the concept of approximate inference in the context of Bayesian models.

  • Approximate inference methods provide a way to approximate the posterior distribution when it is intractable

  • For the next 2 weeks, we will be model-agnostic and focus on the methods used to perform inference in complex and intractable models.

Why model-agnostic?

Solving a machine learning problem involves multiple steps:

  1. Modeling: Define a model that captures the underlying structure of the data

  2. Inference: Estimate the parameters of the model

  3. Prediction: Use the model to make predictions on new data

In these two weeks, we will focus on the inference step

Problem definition

In probabilistic models, all unknown quantities are treated as random variables

  • Observed quantities: \({\textcolor{output}{\boldsymbol{y}}}\in {\mathbb{R}}^N\) (vector of \(N\) observations);

  • Unknown variables: \({\textcolor{params}{\boldsymbol{\theta}}}\in {\mathbb{R}}^D\) (vector of \(D\) parameters)

Given a likelihood \(p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{\theta}}})\) and a prior \(p({\textcolor{params}{\boldsymbol{\theta}}})\), the goal is to compute the posterior distribution \(p({\textcolor{params}{\boldsymbol{\theta}}}\mid {\textcolor{output}{\boldsymbol{y}}})\)

\[ p({\textcolor{params}{\boldsymbol{\theta}}}\mid {\textcolor{output}{\boldsymbol{y}}}) = \frac{p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{\theta}}}) p({\textcolor{params}{\boldsymbol{\theta}}})}{p({\textcolor{output}{\boldsymbol{y}}})} \]

Note: We drop the conditioning on the data \({\textcolor{input}{\boldsymbol{X}}}\) for simplicity, but it is still present in the likelihood as the input to the model

Approximate inference

Approximate inference methods provide a way to approximate distributions when the exact computation is intractable

\[ p({\textcolor{params}{\boldsymbol{\theta}}}\mid {\textcolor{output}{\boldsymbol{y}}}) \approx q({\textcolor{params}{\boldsymbol{\theta}}}) \]

We will study two main classes of approximate inference methods

Sampling-based methods

  • Monte Carlo methods
  • Markov Chain Monte Carlo (MCMC)
  • Hamiltonian Monte Carlo (HMC)

Parametric methods

  • Variational inference
  • Laplace approximation

Grid approximation

Divide the parameter space into \(K\) regions \(\mathcal R_1,\ldots,\mathcal R_K\) of equal volume \(\Delta\). Approximate the posterior mass of each region with a one-point Riemann approximation:

\[ p({\textcolor{params}{\boldsymbol{\theta}}}\in\mathcal R_k\mid {\textcolor{output}{\boldsymbol{y}}}) =\int_{\mathcal R_k} p({\textcolor{params}{\boldsymbol{\theta}}}\mid {\textcolor{output}{\boldsymbol{y}}})\,\dd {\textcolor{params}{\boldsymbol{\theta}}} \approx p({\textcolor{params}{\boldsymbol{\theta}}}_k\mid {\textcolor{output}{\boldsymbol{y}}})\,\Delta . \]

Use Bayes’ rule at each grid point:

\[ p({\textcolor{params}{\boldsymbol{\theta}}}_k\mid {\textcolor{output}{\boldsymbol{y}}}) =\frac{p({\textcolor{output}{\boldsymbol{y}}}\mid{\textcolor{params}{\boldsymbol{\theta}}}_k)p({\textcolor{params}{\boldsymbol{\theta}}}_k)} {\int p({\textcolor{output}{\boldsymbol{y}}}\mid{\textcolor{params}{\boldsymbol{\theta}}})p({\textcolor{params}{\boldsymbol{\theta}}})\dd {\textcolor{params}{\boldsymbol{\theta}}}}. \]

Define unnormalized terms \(\widetilde p_k=p({\textcolor{output}{\boldsymbol{y}}}\mid{\textcolor{params}{\boldsymbol{\theta}}}_k)p({\textcolor{params}{\boldsymbol{\theta}}}_k)\). Approximating the evidence integral by the same Riemann sum and normalizing gives

\[ p({\textcolor{params}{\boldsymbol{\theta}}}_k\mid {\textcolor{output}{\boldsymbol{y}}})\approx \frac{\widetilde p_k}{\sum_{j=1}^K \widetilde p_j\,\Delta}, \qquad p({\textcolor{params}{\boldsymbol{\theta}}}\in\mathcal R_k\mid {\textcolor{output}{\boldsymbol{y}}})\approx \frac{\widetilde p_k}{\sum_{j=1}^K \widetilde p_j} . \]

In particular, the marginal likelihood is approximated by the same Riemann sum: \(p({\textcolor{output}{\boldsymbol{y}}})\approx \sum_{k=1}^K \widetilde p_k\,\Delta\).
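
To make the recipe concrete, here is a minimal sketch of grid approximation for a one-dimensional Beta–Binomial model; the choice of model, grid size, and observed counts are illustrative assumptions.

```python
# Grid approximation in 1D: evaluate prior x likelihood on a grid,
# then normalize by the Riemann sum to approximate posterior and evidence.
import numpy as np
from scipy import stats

K = 200                               # number of regions / grid points (assumed)
theta = np.linspace(0.005, 0.995, K)  # grid points theta_k
delta = theta[1] - theta[0]           # common volume of each region

n, k = 20, 14                                  # k successes in n trials (assumed)
prior = stats.beta.pdf(theta, 2.0, 2.0)        # p(theta_k)
lik = stats.binom.pmf(k, n, theta)             # p(y | theta_k)
p_tilde = lik * prior                          # unnormalized terms

evidence = p_tilde.sum() * delta               # p(y) via the Riemann sum
posterior = p_tilde / (p_tilde.sum() * delta)  # posterior density at each grid point
mass = p_tilde / p_tilde.sum()                 # posterior mass of each region

print(evidence)
print((theta * mass).sum())                    # approximate posterior mean
```

Grid approximation is simple and accurate in one or two dimensions, but the number of grid points grows exponentially with the dimension of \({\textcolor{params}{\boldsymbol{\theta}}}\), which is why the sampling-based and parametric methods introduced earlier are needed in practice.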