Yao Lirong's Blog

Variational Inference

2024/09/09

Probabilistic Latent Variable Models

The two general forms of probabilistic models are:

  • p(x): a typical probability distribution. In this model, we call x the query.
  • p(y ∣ x): a conditional probability distribution. In this model, we call x the evidence and y the query.

Latent variable models are models that have variables other than the query and the evidence.

  • p(x) = ∑_z p(x ∣ z) p(z)

    A classic latent variable model of p(x) is the mixture model, where p(x) is actually a mixture of several different probability distributions. For example, in the following graph, z is a discrete variable representing which class a datapoint belongs to, shown here with different colors. p(x ∣ z) is each class’s probability distribution, which in this case can each be modeled by a Gaussian. p(x), when we observe it, is just a bunch of uncolored datapoints, and it is hard to fit a single well-defined distribution to it. However, we can see it is roughly spread across 3 clusters, so we introduce a latent variable representing the class and can now fit a Gaussian mixture model (a mixture of 3 Gaussians) well. (A minimal numpy sketch of this decomposition follows the list below.)

    Gaussian Mixture Model
  • p(y ∣ x) = ∑_z p(y ∣ x, z) p(z) or p(y ∣ x) = ∑_z p(y ∣ z) p(z ∣ x): the conditional probability is more flexible; you can decompose and model it using z however you like.

    An example of a latent conditional model is the mixture density network, which we use in RL’s imitation learning to handle multi-modal situations, where each mode requires a different distribution.
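
To make the mixture decomposition concrete, here is a minimal numpy sketch of a 1-D Gaussian mixture; the weights, means, and standard deviations are made-up values for illustration only.

```python
import numpy as np
from scipy.stats import norm

# A toy 1-D Gaussian mixture: z is the (latent) component index.
# All numbers here are made up for illustration.
weights = np.array([0.5, 0.3, 0.2])      # p(z)
means   = np.array([-4.0, 0.0, 3.0])     # parameters of p(x | z)
stds    = np.array([1.0, 0.7, 1.5])

def p_x(x):
    """Marginal density p(x) = sum_z p(x | z) p(z)."""
    return sum(w * norm.pdf(x, loc=m, scale=s)
               for w, m, s in zip(weights, means, stds))

def sample(n):
    """Ancestral sampling: draw z ~ p(z), then x ~ p(x | z)."""
    z = np.random.choice(len(weights), size=n, p=weights)
    return np.random.normal(means[z], stds[z])

print(p_x(0.0), sample(5))
```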

Latent Variable Models in General

When we use latent variable models, it means we want to decompose a complicated distribution into several simple / easy distributions. By complicated, we mean it is not possible to write it as a well-defined distribution. By simple / easy, we mean we can write it as a well-defined parametrized distribution, where the parameters can be complex, but the distribution itself is easy to write down (a Gaussian with just a mean and a sigma, a Bernoulli with just one parameter, etc.): p(x) = ∫p(x ∣ z) p(z) dz

  • p(z) is an “easy” prior we choose. For example a Gaussian, a categorical distribution, etc.
  • p(x ∣ z) should also be an easy distribution, like a Gaussian: p(x ∣ z) = 𝒩(μ_nn(z), σ_nn(z)). Even though the mapping from z to the actual parameters of the Gaussian can be complex (here we model it with a neural network), the distribution itself stays simple; this mapping is the learnable part. (See the sketch after this list.)
  • p(x) is complicated, not possible to write out as any well-defined distribution. Therefore, we decompose it into the two parts above that are easy to parametrize as a probability distribution and learn the parameters inside the distribution.
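
As a concrete sketch of this “easy distribution, complex parameters” idea, here is a small PyTorch module (layer sizes and dimensions are arbitrary): the output distribution is a plain Gaussian, but its mean and standard deviation come from a learnable network.

```python
import torch
import torch.nn as nn

class GaussianDecoder(nn.Module):
    """p(x | z) = N(mu_nn(z), sigma_nn(z)): the distribution is a plain
    Gaussian, but its parameters are produced by a learnable network."""
    def __init__(self, z_dim=8, x_dim=2, hidden=64):    # sizes are arbitrary
        super().__init__()
        self.body = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, x_dim)
        self.log_sigma = nn.Linear(hidden, x_dim)        # predict log sigma so sigma stays positive

    def forward(self, z):
        h = self.body(z)
        return self.mu(h), self.log_sigma(h).exp()

decoder = GaussianDecoder()
z = torch.randn(4, 8)                  # z drawn from the easy prior p(z) = N(0, I)
mu, sigma = decoder(z)
x = mu + sigma * torch.randn_like(mu)  # a sample x ~ p(x | z)
```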

Generative models are not the same thing as latent variable models. We usually model generative models as latent variable models because generative models are usually complex probability distributions, and we can make them easier to handle by introducing one or more latent variables.

How to Train a Latent Variable Model

Given dataset 𝒟 = {x1, x2, …, xN}, to fit a typical probabilistic model pθ(x), we use maximum likelihood estimation: $$ \theta \leftarrow \underset{\theta}{arg\!\max} \frac 1 N \sum_i \log p_\theta(x_i) $$ In the latent variable model setup, we can substitute the definition in, and the MLE would look like $$ \theta \leftarrow \underset{\theta}{arg\!\max} \frac 1 N \sum_i \log \left( \int p_\theta(x_i \mid z) p(z) dz \right) $$ pθ(x ∣ z) and p(z) are distributions of our choice, but this integral is still intractable when z is continuous. So now it’s time to do some math tricks.
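
To see why this integral is awkward, here is a numpy sketch that estimates log p(xi) by naive Monte Carlo over the prior, using a toy model chosen so the true answer is known (p(z) = 𝒩(0, 1), p(x|z) = 𝒩(z, 1), hence p(x) = 𝒩(0, √2)); the model and numbers are purely illustrative.

```python
import numpy as np
from scipy.stats import norm

# Toy model where the integral is actually known: p(z) = N(0,1), p(x|z) = N(z,1),
# so the true marginal is p(x) = N(0, sqrt(2)).  Purely illustrative.
x_i = 3.0

def naive_log_px(num_samples):
    z = np.random.randn(num_samples)                 # z ~ p(z)
    return np.log(np.mean(norm.pdf(x_i, loc=z)))     # log (1/K) sum_k p(x_i | z_k)

true_log_px = norm.logpdf(x_i, scale=np.sqrt(2.0))
print(true_log_px, [naive_log_px(100) for _ in range(3)])
# The estimates bounce around: most prior samples z land far from x_i and
# contribute almost nothing, so we need a huge K -- and in high dimensions this blows up.
```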

Variational Inference

Variational Approximation

First, we construct an easy / simple probability distribution qi(z) to approximate p(z|xi), the posterior distribution specific to datapoint xi. By easy we again mean it is easy to parametrize (a Gaussian, a Bernoulli, etc.).

We will show that by introducing this qi(z), we can construct a lower bound on log p(xi). What’s good about this lower bound? Later on, we will also prove it is sufficiently tight, so as we push up the value of this lower bound, we push up log p(xi), which is exactly what we want.

$$ \begin{align} \log p(x_{i}) &= \log\int_{z}p(x_{i}|z)p(z)\\ &= \log\int_{z}p(x_{i}|z)p(z) \frac{q_i(z)}{q_i(z)} \\ &= \log \mathbb E_{z\sim q_{i}(z)} \left[\frac{p(x_{i}|z)p(z)} {q_{i}(z)}\right] \\ &\geq \mathbb E_{z\sim q_{i}(z)} \left[\log\frac{p(x_{i}|z)p(z)}{q_{i}(z)}\right] &\text{\# Jensen's Inequality} \\ &= \mathbb E_{z\sim q_{i}(z)} \left[\log p(x_{i}|z)+\log p(z) \right] - \mathbb E_{z\sim q_{i}(z)} \left[ \log {q_{i}(z)}\right]\\ &= \mathbb E_{z\sim q_{i}(z)} \left[\log p(x_{i}|z)+\log p(z) \right] + \mathcal H_{z\sim q_{i}(z)} (q_i) = \mathcal L_i(p, q_i) \end{align} $$ Recall p(x) is a difficult probability distribution, so we decompose it into two easy distributions p(x|z) and p(z), and use an easy distribution qi(z) to approximate the posterior p(z|xi). Now the good thing is that everything here is tractable. For the first term, we can fix a qi(z) of our choice (recall qi is a distribution we constructed), sample some z, and evaluate the expression. For the second term, we notice it is just the entropy of a distribution, and for simple distributions (we constructed qi to be simple) it has a closed form (even if it doesn’t, you can simply sample and evaluate).
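
Sticking with the same toy model (p(z) = 𝒩(0,1), p(x|z) = 𝒩(z,1)), here is a numpy sketch that evaluates the two ELBO terms exactly as described: sample z from a chosen qi(z) = 𝒩(m, s) to estimate the first expectation, and use the closed-form Gaussian entropy for the second. The values of m and s are placeholders.

```python
import numpy as np
from scipy.stats import norm

# Toy model again: p(z) = N(0,1), p(x|z) = N(z,1), and q_i(z) = N(m, s).
x_i, m, s = 3.0, 1.0, 1.0        # q_i's parameters m, s are placeholders

z = m + s * np.random.randn(10_000)                       # z ~ q_i(z)
term1 = np.mean(norm.logpdf(x_i, loc=z) + norm.logpdf(z)) # E_q[log p(x_i|z) + log p(z)]
entropy = 0.5 * np.log(2 * np.pi * np.e * s**2)           # closed-form Gaussian entropy
elbo = term1 + entropy

print(elbo, "<=", norm.logpdf(x_i, scale=np.sqrt(2.0)))   # the ELBO sits below log p(x_i)
```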

We call the final lower bound we derived here the variational lower bound or evidence lower bound (ELBO). $$ \begin{align} \log p(x_{i}) &\geq \mathcal L_i(p, q_i) \\ &= \mathbb E_{z\sim q_{i}(z)} \left[\log p(x_{i}|z)+\log p(z) \right] + \mathcal H_{z\sim q_{i}(z)} (q_i) \end{align} $$

Effect of Pushing Up ELBO (Intuitively)

Assume our p(⋅) is fixed; what does pushing up the ELBO mean? Here we give an intuitive explanation. First, we look at the first term with the two logs combined. $$ \begin{align} &\mathbb E_{z\sim q_{i}(z)} \left[\log p(x_{i}|z)+\log p(z) \right] \\ = &\mathbb E_{z\sim q_{i}(z)} \left[\log p(x_{i},z) \right] \end{align} $$ To maximize this value, we want a distribution of z that puts its mass where p(xi, z) is largest. Therefore, we want z to be distributed mostly under the peak of p(xi, z). Since qi(z) is the distribution we currently have for z, we want qi(z) to sit mostly under the peak of p(xi, z). In the following graph, the y-axis is p(xi, z), the quantity we try to maximize, and the x-axis is our latent variable z. There is also a hidden axis, the probability mass (distribution) of z, which we project onto the y-axis. To maximize this first term, we spread z’s mass as much under the peak of p(xi, z) as possible, which gives the green part of the graph.

maximize ELBO

Now we take the second term, the entropy, into consideration. ℒi(p, qi) = 𝔼z ∼ qi(z)[log p(xi, z)] + ℋz ∼ qi(z)(qi) From our entropy post, we know entropy measures the expected code length of communicating the event described by a random variable. The more random this variable is, the more code words are required to communicate it. Therefore, the more spread out / uniform the distribution is, the higher the entropy. If we’re maximizing the entropy, we don’t want the distribution to be skinny. See the following graph.

entropy-example

When we consider both the entropy and the first term, we arrive at the probability distribution depicted in brown. Without the entropy term, z would want to sit entirely under the most likely point; with the entropy term added, z now tries to cover the peak. In conclusion (the equal sign “=” reads “in effect”): maximizing the evidence lower bound = covering most of the p(z|xi) distribution = making qi a better approximation of p(z|xi).

maximize ELBO

Effect of Pushing Up ELBO (Analytically)

Can we measure how good our approximation is? That is, can we measure the distance between p(z|xi) and qi? In fact, we have a nice, analytical way to look at it using KL divergence. For two arbitrary distributions p, q of x, the KL divergence of q from p (the distance from q to p; note KL divergence is not symmetric) is

$$ \begin{align} D_{\mathrm{KL}}(q\|p) &=E_{x\sim q(x)}\left[\log{\frac{q(x)}{p(x)}}\right]\\ &=E_{x \sim q(x)}[\log q(x)]-E_{x \sim q(x)}[\log p(x)]\\ &=-E_{x \sim q(x)}[\log p(x)]-H(q) \end{align} $$ Doesn’t this look similar to our evidence lower bound? Borrowing that explanation, minimizing the KL divergence = covering most of the p distribution = making q a better approximation of p.

KL-divergence
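
Here is a small numpy sketch that checks this identity for two 1-D Gaussians: estimate −E_{x∼q}[log p(x)] − H(q) by sampling and compare it with the standard closed form for the KL between two Gaussians; the parameters are arbitrary.

```python
import numpy as np
from scipy.stats import norm

# D_KL(q || p) for two 1-D Gaussians, estimated exactly as in the derivation:
# -E_{x~q}[log p(x)] - H(q).  Parameters are arbitrary.
mu_q, sigma_q, mu_p, sigma_p = 1.0, 0.5, 0.0, 1.0

x = mu_q + sigma_q * np.random.randn(100_000)             # x ~ q
kl_mc = (-np.mean(norm.logpdf(x, loc=mu_p, scale=sigma_p))
         - 0.5 * np.log(2 * np.pi * np.e * sigma_q**2))   # -E_q[log p] - H(q)

# Closed form for two Gaussians, for comparison
kl_exact = (np.log(sigma_p / sigma_q)
            + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2) - 0.5)
print(kl_mc, kl_exact)
```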

Having understood the definition of KL divergence, let’s use it to measure the distance between qi(z) and p(z|xi), the distribution we want qi to approximate: $$ \begin{align} D_{KL}(q_{i}(z)\|p(z \mid x_{i})) &= E_{z\sim q_{i}(z)}\left[\log{\frac{q_{i}(z)}{p(z|x_{i})}}\right]\\ &= E_{z\sim q_{i}(z)}\left[\log{\frac{q_{i}(z)p(x_{i})}{p(x_{i},z)}}\right]\\ &= -E_{z\sim q_{i}(z)}\left[\log p(x_{i}|z)+\log p(z)\right] + E_{z\sim q_{i}(z)}\log q_i(z) + E_{z\sim q_{i}(z)}\log p(x_{i})\\ &= -E_{z\sim q_{i}(z)}\left[\log p(x_{i}|z)+\log p(z)\right] - \mathcal H(q_i) + E_{z\sim q_{i}(z)}\log p(x_{i})\\ &= -\mathcal L(p, q_i) + \log p(x_i)\\ \log p(x_i) &= \mathcal L(p, q_i) + D_{KL}(q_{i}(z)\|p(z \mid x_{i})) \end{align} $$ Therefore, making qi a good approximation of p(z|xi) = driving the KL divergence, which is always non-negative, to 0 = making the evidence lower bound a tight bound on, or even equal to, log p(xi), the ultimate thing we want to optimize.

Looking at our optimization objective here: ℒ(p, qi) = log p(xi) − DKL(qi(z)∥p(z ∣ xi))

  • When we optimize w.r.t. q: note the first term log p(xi) is independent of q, so its value stays the same. We are in effect optimizing against the KL divergence only, making the distance between our approximation qi and p(z|xi) smaller. The best / extreme case is we have DKL = 0, so ℒ = log p(xi).

  • When we optimize w.r.t. p: Recall our ultimate goal is to make log p(xi) bigger, so we make a better model in theory. Only in theory because we don’t know whether the bound is tight or not.

The Learning Algorithm?

Therefore, when we optimize ℒ(p, qi) w.r.t. q, we make the bound tighter (we make ℒ a better approximation of log p(xi)); when we optimize ℒ(p, qi) w.r.t. p, we make a better model in general.

By alternating these two steps, we have the actual learning algorithm. Let’s review: which parts are learnable in these two distributions?

  • In our latent variable model setup, we decompose the complicated distribution p(x) into two easy distributions p(x|z) and p(z), where the mapping from z to the actual parameters of this p(x|z) distribution needs to be modeled by a complex network. Therefore, the only distribution in the p part with learnable parameters is p(x|z). We denote it with pθ(x|z).

  • In our ELBO setup, we also introduced a simple qi(z) for each datapoint xi to approximate the posterior p(z|xi). To optimize w.r.t. q, we optimize the parameters of each such distribution. If qi(z) = 𝒩(μi, σi), we optimize each μi, σi. (We can optimize the entropy value for sure, but I’m not entirely sure how you would take the gradient of the expectation term 𝔼z ∼ qi(z)[log p(z)].)

Therefore, we have our learning algorithm: $$ \begin{align} &\text{for each $x_i$ in $\{x_1, \dots, x_N\}$: }\\ &\hspace{4mm} \text{sample $z \sim q_i(z)$}\\ &\hspace{4mm} \text{optimize against $p$:}\\ &\hspace{4mm} \hspace{4mm} \nabla_\theta \mathcal L (p_\theta, q_i) = \nabla_\theta \log p_\theta(x_i|z) \\ &\hspace{4mm} \hspace{4mm} \theta \leftarrow \theta + \alpha \nabla_\theta \mathcal L (p, q_i) \\ &\hspace{4mm} \text{optimize against $q$:}\\ &\hspace{4mm} \hspace{4mm} \nabla_{\mu_i, \sigma_i} \mathcal L (p_\theta, q_i) = \nabla_{\mu_i, \sigma_i} \left[\mathbb E_{z\sim q_{i}(z)} \left[\log p(x_{i}|z)+\log p(z) \right] + \mathcal H_{z\sim q_{i}(z)} (q_i) \right] \\ &\hspace{4mm} \hspace{4mm} (\mu_i, \sigma_i) \leftarrow (\mu_i, \sigma_i) + \alpha \nabla_{\mu_i, \sigma_i} \mathcal L (p, q_i) \\ \end{align} $$

There’s a problem with optimizing qi though. Note we have a separate q for each data point i, which means if we have N data points, we will have to store N × (|μi| + |σi|) parameters assuming we chose qi to be Gaussian. In machine learning, the number of data points N is usually in millions, making this model unwieldily big. It’s true in inference time we do not use q at all (we’ll see why this is true in the last chapter about VAE), but in training time, we still need them so it’s necessary to keep all these parameters.

Therefore, instead of having a separate qi(⋅) to approximate each data point’s p(⋅|xi) specifically, we use a learnable model qϕ(⋅|xi) to approximate p(⋅|xi). This learnable network takes in a datapoint xi and predicts the corresponding μi, σi. We can then sample z from this predicted distribution.

Amortized

By making q a learnable network qϕ instead, the model size no longer depends on the number of datapoints. Therefore, it is amortized.

The variational lower bound becomes: ℒ(pθ(xi|z), qϕ(z|xi)) = 𝔼z ∼ qϕ(z|xi)[log pθ(xi|z) + log p(z)] + ℋ(qϕ(z|xi)) The learning algorithm naturally becomes: $$ \begin{align} &\text{for each $x_i$ in $\{x_1, \dots, x_N\}$: }\\ &\hspace{4mm} \text{sample $z \sim q_\phi(z|x_i)$}\\ &\hspace{4mm} \text{optimize against $p$:}\\ &\hspace{4mm} \hspace{4mm} \nabla_\theta \mathcal L (p_\theta, q_\phi) = \nabla_\theta \log p_\theta(x_i|z) \\ &\hspace{4mm} \hspace{4mm} \theta \leftarrow \theta + \alpha \nabla_\theta \mathcal L (p_\theta, q_\phi) \\ &\hspace{4mm} \text{optimize against $q$:}\\ &\hspace{4mm} \hspace{4mm} \phi \leftarrow \phi + \alpha \nabla_\phi \mathcal L (p_\theta, q_\phi) \\ \end{align} $$
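
A minimal PyTorch sketch of such an amortized encoder (layer sizes and dimensions are arbitrary): a single network maps any datapoint xi to its posterior parameters μ, σ, so nothing per-datapoint is stored.

```python
import torch
import torch.nn as nn

class GaussianEncoder(nn.Module):
    """Amortized q_phi(z | x) = N(mu_phi(x), sigma_phi(x)): one network shared by
    all datapoints instead of per-datapoint (mu_i, sigma_i).  Sizes are arbitrary."""
    def __init__(self, x_dim=2, z_dim=8, hidden=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, z_dim)
        self.log_sigma = nn.Linear(hidden, z_dim)

    def forward(self, x):
        h = self.body(x)
        return self.mu(h), self.log_sigma(h).exp()

encoder = GaussianEncoder()
x = torch.randn(4, 2)            # a (fake) batch of datapoints
mu, sigma = encoder(x)           # per-datapoint posterior parameters, predicted rather than stored
```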

Gradient Over Expectation (Policy Gradient)

The question now boils down to how we calculate this gradient ∇ϕℒ(pθ, qϕ).

The second term, the entropy, is easy. We purposefully chose q to be a simple distribution, so there is usually a closed form for its entropy and we just have to look it up.

The meat is in the first part. How do we take the gradient w.r.t. the parameter ϕ of the distribution qϕ inside the expectation? Note the term inside the expectation is independent of ϕ, so we can rewrite it as R(xi, z) = log pθ(xi|z) + log p(z) and call the whole thing J.
∇ϕJ(ϕ) = ∇ϕ𝔼z ∼ qϕ(z|xi)[R(xi, z)] We chose these names purposefully because we encountered something similar back in the policy gradient part of reinforcement learning LINK???. Say we have a trajectory τ, sampled from the state transition function with learnable policy πθ; the final expected value we can get from starting state s0 can be written as follows, where R(τ) is a reward function returning the reward of this trajectory. J(θ) = Vπθ(s0) = 𝔼τ ∼ Ps0πθ[R(τ)] We can take the gradient of this value function V w.r.t. our policy πθ, so this is called the policy gradient. If you’re unfamiliar with the RL setup, you just have to know that we can derive the following gradient and approximate it by sampling M trajectories. $$ \nabla_\theta J(\theta) = \mathbb E_{\tau \sim P_{s_0}^{\pi_\theta}} \left[ \nabla_\theta \log \pi_\theta(\tau) R(\tau) \right] \approx \frac 1 M \sum_j^M \nabla_\theta \log \pi_\theta(\tau_j) R(\tau_j) $$ Plugging in our $q$ and $\phi$, $$ \nabla_\phi J(\phi) = \mathbb E_{z \sim q_\phi(z|x_i)} \left[ \nabla_\phi \log q_\phi(z|x_i) R(x_i, z) \right] \approx \frac 1 M \sum_j^M \nabla_\phi \log q_\phi(z_j|x_i) R(x_i, z_j) $$
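
For concreteness, here is a numpy sketch of this score-function estimator on the same toy model as above (p(z) = 𝒩(0,1), p(x|z) = 𝒩(z,1)), with qϕ(z) = 𝒩(μ, σ) and the Gaussian score ∇ log qϕ written out by hand; all numbers are placeholders.

```python
import numpy as np
from scipy.stats import norm

# Score-function ("policy gradient") estimate of grad_phi E_{z~q_phi}[R(x_i, z)]
# for q_phi(z) = N(mu, sigma).  Toy model as before: p(z) = N(0,1), p(x|z) = N(z,1).
x_i, mu, sigma, M = 3.0, 0.0, 1.0, 1000

def R(z):                               # R(x_i, z) = log p(x_i|z) + log p(z)
    return norm.logpdf(x_i, loc=z) + norm.logpdf(z)

z = mu + sigma * np.random.randn(M)     # z_j ~ q_phi(z)
# Gradients of log q_phi(z_j) w.r.t. mu and sigma (hand-derived for a Gaussian)
dlogq_dmu = (z - mu) / sigma**2
dlogq_dsigma = ((z - mu)**2 - sigma**2) / sigma**3
grad_mu = np.mean(dlogq_dmu * R(z))
grad_sigma = np.mean(dlogq_dsigma * R(z))
print(grad_mu, grad_sigma)              # noisy; variance shrinks only like 1/M
```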

Reparametrization Trick

We have our full learning algorithm and it’s ready to go now. However, there is a tiny improvement we can do.

We defined our qϕ to be a normal distribution 𝒩(μϕ, σϕ). Observe that any normal distribution can be written as a function of the unit normal distribution. Therefore, a sample z is in effect: z ∼ 𝒩(μϕ, σϕ) ⇔ z = μϕ + ϵσϕ, ϵ ∼ 𝒩(0, 1) Let’s rewrite our expectation term to sample an ϵ from the unit normal distribution instead. By decomposing z into these two parts, we separate out the stochastic part and change z from a sample of a stochastic distribution into a deterministic function z(ϕ, ϵ) parametrized by ϕ and a random variable ϵ that is independent of ϕ. ϵ carries the stochasticity alone; our learnable parameter ϕ now only parametrizes a deterministic quantity. ∇ϕJ(ϕ) = ∇ϕ𝔼ϵ ∼ 𝒩(0, 1)[R(xi, μϕ + ϵσϕ)] Aside from these theoretical benefits, mathematically, we no longer have to take the gradient of an expectation over a parametrized distribution. Instead, the gradient can go straight inside the expectation, like how we usually interchange gradient and expectation (think of the discrete case: the expectation is just a big sum, so we can do it). ∇ϕJ(ϕ) = 𝔼ϵ ∼ 𝒩(0, 1)[∇ϕR(xi, μϕ + ϵσϕ)] Further, to approximate this expectation, we just sample some ϵ from this normal distribution. $$ \nabla_\phi J(\phi) \approx \frac 1 M \sum_j^M \nabla_\phi R(x_i, \mu_\phi + \epsilon_j \sigma_\phi) $$
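
Here is a small PyTorch sketch of the reparametrized estimator on the same toy model as above, letting autograd differentiate through R(xi, z) = log p(xi|z) + log p(z); the sample count and initial values are arbitrary.

```python
import torch
from torch.distributions import Normal

# Reparametrized estimate of the same gradient: z = mu + eps * sigma with eps ~ N(0,1),
# so gradients flow through R(x_i, z) itself.  Same toy model as above.
x_i = torch.tensor(3.0)
mu = torch.tensor(0.0, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)    # parametrize log sigma for positivity

eps = torch.randn(1000)
z = mu + eps * log_sigma.exp()                       # reparametrized samples z(phi, eps)
R = Normal(z, 1.0).log_prob(x_i) + Normal(0.0, 1.0).log_prob(z)   # log p(x_i|z) + log p(z)
R.mean().backward()                                  # autograd differentiates straight through R
print(mu.grad, log_sigma.grad)                       # typically much less noisy than the score-function version
```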

With reparametrization, we achieve lower variance than the policy gradient because we are using the derivative of R itself. (Unfortunately the lecturer didn’t provide a quantitative analysis of this and I don’t know how to prove it.) Previously, we only took the derivative of the probability distribution. Why didn’t we use the derivative of R back in RL with the policy gradient? It’s not that we don’t want to, but that we can’t: we cannot use reparametrization in RL because in RL we usually cannot take the derivative of the reward R.

| Method | Formula | Approximation | Benefit | Deficit |
| --- | --- | --- | --- | --- |
| Policy Gradient | ∇ϕ𝔼z ∼ qϕ(z ∣ xi)[R(xi, z)] | $\frac 1 M \sum_j^M \nabla_\phi[\log q_\phi(z_j \mid x_i)] R(x_i,z_j)$ | works with both discrete and continuous latent variable z | high variance, requires multiple samples & small learning rates |
| Reparametrization | 𝔼ϵ ∼ 𝒩(0, 1)[∇ϕR(xi, μϕ + ϵσϕ)] | $\frac 1 M \sum_j^M \nabla_\phi R(x_i, \mu_\phi + \epsilon_j \sigma_\phi)$ | low variance, simple to implement (we’ll see soon) | only works with a continuous latent variable z, which has to be modeled with a Gaussian |

In fact, you can forget about the policy gradient method and simply take it for granted that you cannot back propagate through a sampled value to compute ∇ϕ𝔼z ∼ qϕ(z|xi)[R(xi, z)], so you have to find some way to make z deterministic, which is exactly what the reparametrization trick does.

reparametrization-trick

Left is without the “reparameterization trick”, and right is with it. Red shows sampling operations that are non-differentiable. Blue shows loss layers. We forward the network by going up and back propagate it by going down. The forward behavior of these networks is identical, but back propagation can be applied only to the right network. Figure copied from Carl Doersch: Tutorial on Variational Autoencoders

Looking at ℒ Directly

$$ \begin{align} \mathcal L_i = \mathcal L \left( p_\theta(x_i | z), q_\phi(z | x_i) \right) &= \mathbb E_{z\sim q_\phi(z | x_i)} \left[\log p_\theta(x_{i}|z)+\log p(z) \right] + \mathcal H (q_\phi(z|x_i))\\ &= \mathbb E_{z\sim q_\phi(z | x_i)} \left[\log p_\theta(x_{i}|z) \right] + \mathbb E_{z\sim q_\phi(z | x_i)} \left[\log p(z) \right] + \mathcal H (q_\phi(z|x_i))\\ &= \mathbb E_{z\sim q_\phi(z | x_i)} \left[\log p_\theta(x_{i}|z)\right] - D_{KL}(q_\phi(z | x_i)\|p(z)) \\ &= \mathbb E_{\epsilon \sim \mathcal N(0,1)} \left[\log p_\theta(x_{i}| \mu_\phi + \epsilon \sigma_\phi)\right] - D_{KL}(q_\phi(z | x_i)\|p(z)) \\ &\approx \frac 1 M \sum_j^M \log p_\theta(x_{i}| \mu_\phi + \epsilon_j \sigma_\phi) - D_{KL}(q_\phi(z | x_i)\|p(z)) \\ \end{align} $$

For the first term, we can just evaluate it. For the second KL term, since we chose both distributions to be easy (in this case Gaussian), there often is a nice analytical form for it.
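
Since we chose qϕ(z|xi) to be a diagonal Gaussian and p(z) to be a standard normal, that KL term has the well-known closed form 0.5 Σ(σ² + μ² − 1 − log σ²); here is a small PyTorch sketch of it (the function name is mine).

```python
import torch

def kl_to_standard_normal(mu, log_sigma):
    """D_KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions.
    Standard closed form: 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2)."""
    return 0.5 * torch.sum(log_sigma.exp()**2 + mu**2 - 1.0 - 2.0 * log_sigma, dim=-1)

mu = torch.zeros(4, 8)           # with q equal to the prior, the KL is exactly zero
log_sigma = torch.zeros(4, 8)
print(kl_to_standard_normal(mu, log_sigma))   # tensor of zeros
```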

Therefore, we can go ahead and maximize the variational lower bound ℒ. We can also draw the following computational graph for the log term and conclude that we can back propagate through this graph without any problem. On the other hand, if we didn’t do the reparametrization trick, we would get stuck at z: you cannot back propagate through z, a sampled value instead of a variable, and we would have to seek help from the policy gradient. With reparametrization, we decompose z into two variables μϕ, σϕ that we can back propagate through and one stochastic value ϵ that we do not care about.

computational-graph

Variational Autoencoder

Setup and Interpretation

What we have gone through constitutes the full pipeline of a variational autoencoder.

In a variational autoencoder, we have observed variable x and latent variable z:

  • encoder qϕ(z|x) = 𝒩(μϕ(x), σϕ(x))
  • decoder pθ(x|z) = 𝒩(μθ(z), σθ(z))

In training, given an observed sample xi, we encode it to a latent variable zi using qϕ, then try to decode it back with the decoder pθ. We maximize the variational lower bound during the process. For all N samples, the training objective looks like (where ϵ is a sampled value): $$ \max_{\phi,\theta} \frac 1 N \sum_i^N \log p_\theta\left(x_{i}| \mu_\phi(x_i) + \epsilon \sigma_\phi(x_i)\right) - D_{KL}(q_\phi(z | x_i)\|p(z)) \\ $$ At inference (generation) time, we sample a z from our prior p(z), then decode it using pθ: z ∼ p(z), x ∼ pθ(x|z)
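
To make the whole pipeline concrete, here is a sketch of one training step and of generation, reusing the hypothetical GaussianEncoder, GaussianDecoder, and kl_to_standard_normal from the earlier sketches; shapes, sizes, and hyperparameters are arbitrary.

```python
import torch

# One VAE training step, built from the sketch modules defined above
# (GaussianEncoder, GaussianDecoder, kl_to_standard_normal).
encoder = GaussianEncoder(x_dim=2, z_dim=8)
decoder = GaussianDecoder(z_dim=8, x_dim=2)
opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

x = torch.randn(32, 2)                              # a (fake) minibatch

mu_z, sigma_z = encoder(x)                          # q_phi(z | x)
z = mu_z + torch.randn_like(sigma_z) * sigma_z      # reparametrized sample
mu_x, sigma_x = decoder(z)                          # p_theta(x | z)

recon = torch.distributions.Normal(mu_x, sigma_x).log_prob(x).sum(dim=-1)  # log p_theta(x | z)
kl = kl_to_standard_normal(mu_z, sigma_z.log())     # D_KL(q_phi(z|x) || p(z))
loss = -(recon - kl).mean()                         # maximize the ELBO = minimize its negative

opt.zero_grad()
loss.backward()
opt.step()

# Generation: sample z from the prior and decode.
with torch.no_grad():
    z_new = torch.randn(5, 8)
    x_new, _ = decoder(z_new)
```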

Why does the variational autoencoder work? We talked about many benefits of maximizing this variational lower bound in the previous chapter. Let’s look at it again in this encoder-decoder setup. ℒi = 𝔼z ∼ qϕ(z|xi)[log pθ(xi|z)] − DKL(qϕ(z|xi)∥p(z))

  • The first log pθ term maximizes the probability of our observed image x given a sample z, so it trains the decoder pθ to reconstruct the image x as accurately as possible.
  • The second KL term pushes the encoding of an image to stay close to the prior, which makes sure that at inference / generation time we can directly sample from the prior.

Comparison with Auto-Encoder

vae-and-ae

The VAE’s decoder is trained to convert random points in the embedding space (generated by perturbing the input encodings) to sensible outputs. By contrast, the decoder for the deterministic autoencoder only ever gets as inputs the exact encodings of the training set, so it does not know what to do with random inputs that are outside what it was trained on. So a standard autoencoder cannot create new samples.

The reason the VAE is better at sampling is that it embeds images into Gaussians in latent space, whereas the AE embeds images into points, which are like delta functions. The advantage of using a latent distribution is that it encourages local smoothness, since a given image may map to multiple nearby places, depending on the stochastic sampling. By contrast, in an AE, the latent space is typically not smooth, so images from different classes often end up next to each other. Figure copied from Probabilistic Machine Learning: An Introduction - Figure 20.26

We can leverage this smoothness to perform image interpolation in latent space.
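
A short sketch of what that interpolation could look like, reusing the hypothetical encoder and decoder modules from the earlier sketches and two made-up datapoints:

```python
import torch

# Latent-space interpolation sketch, reusing the hypothetical encoder/decoder above.
x_a, x_b = torch.randn(2), torch.randn(2)        # stand-ins for two real datapoints
with torch.no_grad():
    mu_a, _ = encoder(x_a.unsqueeze(0))          # use the posterior means as the two endpoints
    mu_b, _ = encoder(x_b.unsqueeze(0))
    for t in torch.linspace(0.0, 1.0, steps=8):
        z_t = (1 - t) * mu_a + t * mu_b          # walk along the line between the two codes
        x_t, _ = decoder(z_t)                    # decoded interpolation between the two inputs
```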

Reference

Most content of this blog post comes from Berkeley CS 285 (Sergey Levine): Lecture 18, Variational Inference, which I think he organized based on An Introduction to Variational Autoencoders (2.1-2.7 and 2.9.1), or, in more depth, the author’s PhD thesis Variational Inference and Deep Learning: A New Synthesis. I found this wonderful tutorial through Probabilistic Machine Learning: Advanced Topics.

Some graphs come from Probabilistic Machine Learning: An Introduction itself and from Carl Doersch: Tutorial on Variational Autoencoders, which is referenced in that book.

Note though the Probabilistic Machine Learning book itself is a horrible book with extremely confusing explanations.
