Introduction to Approximate Inference

Advanced Statistical Inference

Simone Rossi

EURECOM

Inference

\[ \require{physics} \definecolor{input}{rgb}{0.42, 0.55, 0.74} \definecolor{params}{rgb}{0.51,0.70,0.40} \definecolor{output}{rgb}{0.843, 0.608, 0} \definecolor{vparams}{rgb}{0.58, 0, 0.83} \definecolor{noise}{rgb}{0.0, 0.48, 0.65} \definecolor{latent}{rgb}{0.8, 0.0, 0.8} \]


Bayesian inference

Bayesian inference allows us to “transform” a prior distribution over the parameters into a posterior distribution after observing the data

Prior distribution \(p({\textcolor{params}{\boldsymbol{w}}})\)

Figure 1

Posterior distribution \(p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}})\)

Figure 2

Bayes’ rule:

\[ p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}}) = \frac{p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{w}}}, {\textcolor{input}{\boldsymbol{X}}}) p({\textcolor{params}{\boldsymbol{w}}})}{p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{input}{\boldsymbol{X}}})} \]

  • Prior: \(p({\textcolor{params}{\boldsymbol{w}}})\)
    • Encodes our beliefs about the parameters before observing the data
  • Likelihood: \(p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{w}}}, {\textcolor{input}{\boldsymbol{X}}})\)
    • Encodes our model of the data
  • Posterior: \(p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}})\)
    • Encodes our beliefs about the parameters after observing the data (i.e. conditioned on the data)
  • Evidence: \(p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{input}{\boldsymbol{X}}})\)
    • Normalizing constant, ensures that \(\int p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}}) \dd {\textcolor{params}{\boldsymbol{w}}}= 1\)

Bayesian linear regression (review)

We model each observation as a noisy realization of a linear combination of the features

As before, we assume a Gaussian likelihood

\[ p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{w}}}, {\textcolor{input}{\boldsymbol{X}}}) = {\mathcal{N}}({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{input}{\boldsymbol{X}}}{\textcolor{params}{\boldsymbol{w}}}, \sigma^2 {\boldsymbol{I}}) \]

For the prior, we use a Gaussian distribution over the model parameters

\[ p({\textcolor{params}{\boldsymbol{w}}}) = {\mathcal{N}}({\textcolor{params}{\boldsymbol{w}}}\mid {\boldsymbol{0}}, {\boldsymbol{S}}) \]

In practice, we often use a diagonal covariance matrix \({\boldsymbol{S}}= \sigma_{\textcolor{params}{\boldsymbol{w}}}^2 {\boldsymbol{I}}\)

When can we compute the posterior?

Definition

A prior is conjugate to a likelihood if the posterior is in the same family as the prior.

Only a few conjugate priors exist, but they are very useful.

Examples:

  • Gaussian likelihood and Gaussian prior \(\Rightarrow\) Gaussian posterior
  • Binomial likelihood and Beta prior \(\Rightarrow\) Beta posterior

A full table of conjugate priors is available on Wikipedia
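As an illustration of the Beta example above, here is a minimal sketch of a conjugate update. The function name `beta_binomial_update` and the numbers are hypothetical, chosen only for illustration: with a Beta prior and binomial counts, the posterior is again a Beta distribution whose parameters are the prior parameters plus the observed counts.

```python
def beta_binomial_update(alpha, beta, successes, failures):
    """Posterior Beta parameters after observing binomial counts.

    Prior: Beta(alpha, beta). Data: `successes` ones and `failures` zeros.
    Posterior: Beta(alpha + successes, beta + failures).
    """
    return alpha + successes, beta + failures


# Hypothetical example: Beta(2, 2) prior, 7 successes and 3 failures
alpha_post, beta_post = beta_binomial_update(2.0, 2.0, successes=7, failures=3)

# Mean of a Beta(a, b) distribution is a / (a + b)
posterior_mean = alpha_post / (alpha_post + beta_post)
print(alpha_post, beta_post, posterior_mean)
```

No integral is computed anywhere: conjugacy reduces inference to updating two numbers.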

Why is this useful?

\[ p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}}) = \frac{p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{w}}}, {\textcolor{input}{\boldsymbol{X}}}) p({\textcolor{params}{\boldsymbol{w}}})}{p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{input}{\boldsymbol{X}}})} \]

  • Generally the posterior is intractable to compute
    • We don’t know the form of the posterior \(p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}})\)
    • The evidence \(p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{input}{\boldsymbol{X}}})\) is an integral
      • without closed form solution
      • high-dimensional and computationally intractable to compute numerically
  • Analytical solution thanks to conjugacy:
    • We know the form of the posterior
    • We know the form of the normalization constant
    • We don’t need to compute the evidence, just some algebra to get the posterior

From the likelihood and prior, we can write the posterior up to a normalizing constant as

\[ \begin{aligned} p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{w}}}, {\textcolor{input}{\boldsymbol{X}}})p({\textcolor{params}{\boldsymbol{w}}}) &\propto \exp\left(-\frac 1 {2\sigma^2} \norm{{\textcolor{output}{\boldsymbol{y}}}- {\textcolor{input}{\boldsymbol{X}}}{\textcolor{params}{\boldsymbol{w}}}}_2^2 - \frac 1 2 {\textcolor{params}{\boldsymbol{w}}}^\top {\boldsymbol{S}}^{-1} {\textcolor{params}{\boldsymbol{w}}}\right) \\ &\propto \exp\left(-\frac 1 2 \left( {\textcolor{params}{\boldsymbol{w}}}^\top \left(\frac 1 {\sigma^2} {\textcolor{input}{\boldsymbol{X}}}^\top {\textcolor{input}{\boldsymbol{X}}}+ {\boldsymbol{S}}^{-1}\right) {\textcolor{params}{\boldsymbol{w}}}- \frac 2 {\sigma^2}{\textcolor{params}{\boldsymbol{w}}}^\top {\textcolor{input}{\boldsymbol{X}}}^\top {\textcolor{output}{\boldsymbol{y}}}\right)\right) \end{aligned} \]
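The second line can be checked by expanding the squared norm. The term \({\textcolor{output}{\boldsymbol{y}}}^\top {\textcolor{output}{\boldsymbol{y}}}\) does not depend on \({\textcolor{params}{\boldsymbol{w}}}\), so it is absorbed into the proportionality constant:

\[ \norm{{\textcolor{output}{\boldsymbol{y}}}- {\textcolor{input}{\boldsymbol{X}}}{\textcolor{params}{\boldsymbol{w}}}}_2^2 = {\textcolor{output}{\boldsymbol{y}}}^\top {\textcolor{output}{\boldsymbol{y}}}- 2 {\textcolor{params}{\boldsymbol{w}}}^\top {\textcolor{input}{\boldsymbol{X}}}^\top {\textcolor{output}{\boldsymbol{y}}}+ {\textcolor{params}{\boldsymbol{w}}}^\top {\textcolor{input}{\boldsymbol{X}}}^\top {\textcolor{input}{\boldsymbol{X}}}{\textcolor{params}{\boldsymbol{w}}} \]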

From conjugacy, we know that the posterior is Gaussian

\[ p({\textcolor{params}{\boldsymbol{w}}}\mid {\textcolor{output}{\boldsymbol{y}}}, {\textcolor{input}{\boldsymbol{X}}}) \propto \exp\left(-\frac 1 2 \left( {\textcolor{params}{\boldsymbol{w}}}^\top {\boldsymbol{\Sigma}}^{-1} {\textcolor{params}{\boldsymbol{w}}}- 2 {\textcolor{params}{\boldsymbol{w}}}^\top {\boldsymbol{\Sigma}}^{-1} {\boldsymbol{\mu}}\right)\right) \]

We can identify the posterior mean and covariance

Posterior covariance \[ {\boldsymbol{\Sigma}}= \left(\frac 1 {\sigma^2} {\textcolor{input}{\boldsymbol{X}}}^\top {\textcolor{input}{\boldsymbol{X}}}+ {\boldsymbol{S}}^{-1}\right)^{-1} \]

Posterior mean \[ {\boldsymbol{\mu}}= \frac 1 {\sigma^2} {\boldsymbol{\Sigma}}{\textcolor{input}{\boldsymbol{X}}}^\top {\textcolor{output}{\boldsymbol{y}}} \]
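The two formulas above translate directly into a few lines of NumPy. The following is a minimal sketch, not a reference implementation: the function name `blr_posterior`, the synthetic data, and the choice \({\boldsymbol{S}}= {\boldsymbol{I}}\) are assumptions made for illustration.

```python
import numpy as np


def blr_posterior(X, y, noise_var, S):
    """Posterior mean and covariance for Bayesian linear regression.

    Likelihood: N(y | X w, noise_var * I). Prior: N(w | 0, S).
    """
    # Posterior precision: (1 / sigma^2) X^T X + S^{-1}
    precision = X.T @ X / noise_var + np.linalg.inv(S)
    Sigma = np.linalg.inv(precision)
    # Posterior mean: (1 / sigma^2) Sigma X^T y
    mu = Sigma @ X.T @ y / noise_var
    return mu, Sigma


# Synthetic data (hypothetical values, for illustration only)
rng = np.random.default_rng(0)
N, D = 50, 3
w_true = np.array([1.0, -2.0, 0.5])
noise_var = 0.1
X = rng.normal(size=(N, D))
y = X @ w_true + rng.normal(scale=np.sqrt(noise_var), size=N)

mu, Sigma = blr_posterior(X, y, noise_var, S=np.eye(D))
print(mu)     # posterior mean, close to w_true for this amount of data
print(Sigma)  # posterior covariance
```

Explicit matrix inverses are used here to mirror the formulas. For larger or ill-conditioned problems, solving the linear system (for example with `np.linalg.solve` or a Cholesky factorization) is usually preferable.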

Exact inference is rare

  • Exact inference is possible when the posterior distribution can be computed analytically

Example: Linear regression with Gaussian likelihood and Gaussian prior

  • This is the case for simple models with conjugate priors, …

  • … but most of the time, the posterior distribution is intractable

Examples: Logistic regression (binary classification), neural networks, …
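To see what intractable means in the logistic regression case, the sketch below evaluates the unnormalized log posterior pointwise. It assumes labels in \(\{0, 1\}\), a sigmoid link, and an isotropic Gaussian prior; the function name `log_joint` and the toy data are hypothetical. The product of a Bernoulli likelihood and a Gaussian prior is easy to evaluate at any given \({\textcolor{params}{\boldsymbol{w}}}\), but it is not conjugate, so its normalizing integral has no closed form.

```python
import numpy as np


def log_joint(w, X, y, prior_var=1.0):
    """Unnormalized log posterior log p(y | w, X) + log p(w), up to a constant."""
    logits = X @ w
    # Bernoulli log likelihood with a sigmoid link, in a numerically stable form:
    # sum_i [ y_i * logit_i - log(1 + exp(logit_i)) ]
    log_lik = np.sum(y * logits - np.logaddexp(0.0, logits))
    # Isotropic Gaussian prior N(w | 0, prior_var * I), additive constants dropped
    log_prior = -0.5 * np.sum(w**2) / prior_var
    return log_lik + log_prior


# Toy data (hypothetical): first column is a bias feature
X = np.array([[1.0, 0.5], [1.0, -1.2], [1.0, 2.0], [1.0, 0.1]])
y = np.array([1, 0, 1, 0])

value_at_zero = log_joint(np.zeros(2), X, y)
print(value_at_zero)  # equals -4 * log(2) at w = 0
```

Approximate inference methods typically need only this kind of pointwise evaluation of the unnormalized posterior.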

Introduction to Approximate Inference

  • In this lecture, we will introduce the concept of approximate inference in the context of Bayesian models.

  • Approximate inference methods provide a way to approximate the posterior distribution when it is intractable

  • For the next two weeks, we will be model-agnostic and focus on the methods used to perform inference in complex and intractable models.

Why model-agnostic?

Solving a machine learning problem involves multiple steps:

  1. Modeling: Define a model that captures the underlying structure of the data

  2. Inference: Estimate the parameters of the model

  3. Prediction: Use the model to make predictions on new data

In these two weeks, we will focus on the inference step

Problem definition

In probabilistic models, all unknown quantities are treated as random variables

  • Observed quantities: \({\textcolor{output}{\boldsymbol{y}}},{\textcolor{input}{\boldsymbol{x}}}\) (data)

  • Unknown variables: \({\textcolor{params}{\boldsymbol{\theta}}}\) (parameters)

Given a likelihood \(p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{\theta}}})\) and a prior \(p({\textcolor{params}{\boldsymbol{\theta}}})\), the goal is to compute the posterior distribution \(p({\textcolor{params}{\boldsymbol{\theta}}}\mid {\textcolor{output}{\boldsymbol{y}}})\)

\[ p({\textcolor{params}{\boldsymbol{\theta}}}\mid {\textcolor{output}{\boldsymbol{y}}}) = \frac{p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{\theta}}}) p({\textcolor{params}{\boldsymbol{\theta}}})}{p({\textcolor{output}{\boldsymbol{y}}})} \]

  • Intractability: The denominator \(p({\textcolor{output}{\boldsymbol{y}}})\) involves an intractable integral

\[ p({\textcolor{output}{\boldsymbol{y}}}) = \int p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{\theta}}}) p({\textcolor{params}{\boldsymbol{\theta}}}) \dd {\textcolor{params}{\boldsymbol{\theta}}} \]

Approximate inference

Approximate inference methods provide a way to approximate distributions, such as the posterior, when exact computation is intractable

We will study two main classes of approximate inference methods

Sampling-based methods

  • Monte Carlo methods
  • Markov Chain Monte Carlo (MCMC)
  • Hamiltonian Monte Carlo (HMC)

Parametric methods

  • Variational inference
  • Laplace approximation

Grid approximation

Grid approximation: Evaluate the posterior on a grid of points

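  1. Divide the parameter space into a set of \(K\) regions \({\boldsymbol{r}}_1, ..., {\boldsymbol{r}}_K\) of volume \(\Delta\), each with a representative point \({\textcolor{params}{\boldsymbol{\theta}}}_k\) (e.g. its center)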
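  2. Probability mass in each region is \(p({\textcolor{params}{\boldsymbol{\theta}}}\in {\boldsymbol{r}}_k \mid {\textcolor{output}{\boldsymbol{y}}}) \approx p_k\) (so the posterior density at \({\textcolor{params}{\boldsymbol{\theta}}}_k\) is approximately \(p_k / \Delta\)), where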

\[ \begin{aligned} p_k &= \frac{\widetilde{p}_k}{\sum_{j=1}^K \widetilde{p}_j}\\ \widetilde{p}_k &= p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{\theta}}}_k) p({\textcolor{params}{\boldsymbol{\theta}}}_k) \end{aligned} \]

  3. Approximate the marginal as

\[ p({\textcolor{output}{\boldsymbol{y}}}) = \int p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{\theta}}}) p({\textcolor{params}{\boldsymbol{\theta}}}) \dd {\textcolor{params}{\boldsymbol{\theta}}}\approx \sum_{k=1}^K p({\textcolor{output}{\boldsymbol{y}}}\mid {\textcolor{params}{\boldsymbol{\theta}}}_k) p({\textcolor{params}{\boldsymbol{\theta}}}_k) \Delta \]

Example of grid approximation

Grid approximation for a Beta-Bernoulli model
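A minimal sketch of such a grid approximation, under assumed values: a Beta(2, 2) prior, 7 successes and 3 failures, and \(K = 1000\) equal-width cells on \([0, 1]\) evaluated at their midpoints. All names and numbers are hypothetical. Because this model is conjugate, the exact posterior is available and can be used to check the approximation.

```python
import numpy as np
from math import gamma

# Hypothetical data: 10 Bernoulli trials with 7 successes, Beta(2, 2) prior
successes, failures = 7, 3
a, b = 2.0, 2.0

K = 1000
delta = 1.0 / K                        # width (volume) of each cell
theta = (np.arange(K) + 0.5) * delta   # midpoints of the K cells


def beta_pdf(t, a, b):
    """Density of a Beta(a, b) distribution evaluated at t."""
    norm = gamma(a + b) / (gamma(a) * gamma(b))
    return norm * t ** (a - 1) * (1 - t) ** (b - 1)


# Unnormalized posterior at the grid points: likelihood times prior
p_tilde = theta**successes * (1 - theta) ** failures * beta_pdf(theta, a, b)

evidence = np.sum(p_tilde) * delta   # grid approximation of p(y)
mass = p_tilde / np.sum(p_tilde)     # probability mass per cell, sums to 1
density = mass / delta               # approximate posterior density

# Exact posterior from conjugacy: Beta(a + successes, b + failures)
exact = beta_pdf(theta, a + successes, b + failures)
max_error = np.max(np.abs(density - exact))
posterior_mean = np.sum(mass * theta)
print(evidence, max_error, posterior_mean)
```

The only model-specific ingredient is the pointwise evaluation of likelihood times prior, so the same recipe applies to non-conjugate one-dimensional models where no exact posterior is available for comparison.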

Scaling grid approximation

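Grid approximation is a simple method, but it is only practical for approximating distributions in one or two dimensions, because the number of grid points grows exponentially with the number of dimensions.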
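To make the cost concrete: a regular grid with \(G\) points per dimension requires \(G^D\) evaluations of the unnormalized posterior in \(D\) dimensions. The short sketch below, with a hypothetical helper `grid_evaluations` and an assumed resolution of 100 points per dimension, prints how quickly this number grows.

```python
def grid_evaluations(points_per_dim, num_dims):
    """Number of points in a regular grid with equal resolution in every dimension."""
    return points_per_dim**num_dims


for num_dims in (1, 2, 5, 10):
    print(num_dims, grid_evaluations(100, num_dims))
```

With 100 points per dimension, two dimensions already need 10,000 evaluations and ten dimensions need \(10^{20}\), which is far beyond what can be computed in practice.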

⏩ Next: Monte Carlo Methods