These notes are based on CMU Prof. Bhiksha Raj’s lecture video on generative models and variational autoencoders and his lecture slides, specifically the section about expectation maximization and variational inference. Prof. Raj gives a great intuitive explanation for the principles behind EM (expectation maximization), so I highly recommend watching the lecture video. But the overall topic is complex and there are many moving pieces, so here I will try to condense down to the key principles behind how this works. I hope these notes will be useful for both myself and others in the future.

[2/2/2024: initial post, 1/5/2026: many edits]

Image space and latent space

Let \(x\) be a vector in a space of images, and \(z\) be a vector in a latent space of image classifications.

For example, if each image has \(n^2=n \cdot n\) pixels and each pixel has 3 numbers representing RGB values, then \(x\) could be a vector of \(3n^2\) elements, where each element is a number.

The conditional probability distribution \(P(x \mid z)\) is a probability distribution over \(x\) whose support consists of those values of \(x\) consistent with the given latent classification \(z\). The peak of this distribution is for \(x\) that strongly corresponds to \(z\); the tail of the distribution is for \(x\) that weakly corresponds to \(z\).

For example, suppose \(z_0\) represents the category “rose”. Then an \(x_1\) that is an image of a rose would strongly satisfy \(z_0\), an \(x_2\) that is an image of a begonia would weakly satisfy \(z_0\), and an \(x_3\) that is an image of a cat would not satisfy \(z_0\) at all. This means that \(P(x_1 \mid z_0) > P(x_2 \mid z_0) > P(x_3 \mid z_0)\).

\(P(x,z)\) is the joint probability density of \(x\) and \(z\). If we sample \((x,z) \sim P(x,z)\), then each \((x,z)\) sample would be an image together with its latent classification. The image and its classification would be randomly chosen, most frequently sampling the most probable image/classification pairs.

If we instead want to generate images with a specific classification (e.g., “rose”), we need to fix \(z\) and randomly sample \(x \sim P(x \mid z)\).

Confusion: do we need to know \(P(z)\)?

We know that the conditional probability is given by \(P(x \mid z) = P(x,z)/P(z)\). So it appears that in order to sample from the conditional probability, we not only need to know the joint density \(P(x,z)\), but also need to know \(P(z)\). But what if \(P(z)\) is difficult to evaluate (as it usually is)?

Answer: \(P(x \mid z)\) is a probability density over \(x\), so it must integrate to \(1\) over all values of \(x\). For a fixed \(z\), the \(P(z)\) in the denominator is a constant that is only there to ensure this normalization. If we use a sampling method that does not require a normalized density (e.g., MCMC), then we don’t need this denominator \(P(z)\), and a model for either \(P(x,z)\) or \(P(x \mid z)\) is sufficient. (But for a deeper dive into the subtleties here, see this post.)
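To make the "no normalization needed" point concrete, here is a minimal sketch using a basic random-walk Metropolis sampler. The toy joint density \(P(x,z) = \mathcal{N}(z;0,1)\,\mathcal{N}(x;z,1)\), the function names, and the step size are all my own illustrative assumptions, not something from the lecture. The sampler only ever evaluates ratios of the joint at a fixed \(z\), so \(P(z)\) never has to be computed.

```python
import math
import random


def log_joint(x, z):
    """Log of a toy joint density P(x, z) = N(z; 0, 1) * N(x; z, 1) (an assumption)."""
    log_p_z = -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)
    log_p_x_given_z = -0.5 * (x - z) ** 2 - 0.5 * math.log(2.0 * math.pi)
    return log_p_z + log_p_x_given_z


def sample_x_given_z(z, n_samples=20000, step=1.0, seed=0):
    """Random-walk Metropolis over x with z held fixed, using only the joint."""
    rng = random.Random(seed)
    x = 0.0
    samples = []
    for _ in range(n_samples):
        proposal = x + rng.gauss(0.0, step)
        # The acceptance ratio P(x'|z) / P(x|z) equals P(x',z) / P(x,z):
        # the P(z) factors cancel, so the joint alone is enough.
        log_ratio = log_joint(proposal, z) - log_joint(x, z)
        if rng.random() < math.exp(min(0.0, log_ratio)):
            x = proposal
        samples.append(x)
    return samples


samples = sample_x_given_z(z=2.0)
sample_mean = sum(samples) / len(samples)
print(f"sample mean of x given z=2: {sample_mean:.2f}")
```

For this particular toy model the conditional is known in closed form, \(P(x \mid z) = \mathcal{N}(x;z,1)\), so the sample mean should land near the chosen \(z\); that is only a sanity check on the sketch, not a general recipe.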

So we can use a parametrized model \(P(x,z;\theta)\), designed so that any observed data distribution can be estimated with the appropriate settings for \(\theta\). To reiterate: this is our hypothesis for the probability density, fully specified once the parameter vector \(\theta\) is specified.

Also to clarify, the ground truth distribution is \(P(x,z)\) (which does not have the \(\theta\) parameters), but unfortunately we do not know the ground truth; the model to be learned is \(P(x,z;\theta)\), which is our estimate/hypothesis for the ground truth. We apply maximum likelihood estimation (MLE) to learn the optimal \(\theta\), by maximizing the sum of the log-likelihoods evaluated on the observed data.

Not having all the data is ok

If we have observations of vectors \(x\) and \(z\), then we can optimize \(\theta\) to maximize the sum of the log-likelihoods, where the log-likelihood is the log of the parametrized model \(P(x,z;\theta)\) evaluated at the observed values for \(x\) and \(z\).

But since we only observe \(x\) and don’t have observations of \(z\), we need to use the log-likelihood of \(P(x;\theta)\) and feed the observed \(x\) values into that. We get \(P(x;\theta)\) by integrating over \(z\):

\[\begin{align} P(x;\theta) &= \int dz\, P(x,z;\theta) \end{align}\]
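As a small numerical illustration of this marginalization, the sketch below integrates a toy joint over a one-dimensional \(z\) on a grid. The model \(P(x,z) = \mathcal{N}(z;0,1)\,\mathcal{N}(x;z,1)\), the grid limits, and the function names are assumptions chosen only because the marginal is then known exactly (\(\mathcal{N}(x;0,2)\)), which gives something to compare against. In a realistic model \(z\) is high-dimensional and this kind of grid integration is not feasible.

```python
import math


def normal_pdf(value, mean, var):
    """Density of a 1D normal distribution with the given mean and variance."""
    return math.exp(-0.5 * (value - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def joint(x, z):
    """Toy joint P(x, z) = N(z; 0, 1) * N(x; z, 1) (an illustrative assumption)."""
    return normal_pdf(z, 0.0, 1.0) * normal_pdf(x, z, 1.0)


def marginal_x(x, z_min=-10.0, z_max=10.0, n_points=4001):
    """Approximate P(x) = integral dz P(x, z) with the trapezoid rule."""
    dz = (z_max - z_min) / (n_points - 1)
    total = 0.0
    for k in range(n_points):
        z = z_min + k * dz
        weight = 0.5 if k in (0, n_points - 1) else 1.0
        total += weight * joint(x, z) * dz
    return total


numeric = marginal_x(1.0)
exact = normal_pdf(1.0, 0.0, 2.0)  # closed form for this toy model only
print(f"numeric: {numeric:.6f}, exact: {exact:.6f}")
```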

Learning the model

So we need to perform the following optimization:

\[\begin{align} \theta_{opt} & = \operatorname{argmax}_{\theta} \sum_i \log P(x_i;\theta) \\ & = \operatorname{argmax}_{\theta} \sum_i \log \int dz\, P(x_i,z;\theta), \end{align}\]

where the sum over \(i\) is the sum over the observed data \(x_i\).

The log of an integral may be difficult to maximize because taking a derivative with respect to \(\theta\) puts a copy of the integral in the denominator, and this integral depends on the specific \(x_i\). So we would need to evaluate this integral for each \(x_i\), and the problem is that the integral could be over a high-dimensional \(z\)-space. (BTW, it appears that for some problems this integral is actually tractable, so there is some subtlety here.) Since this is difficult to maximize directly in general, we instead work with the following inequality:

\[\begin{align} \sum_i \log \int dz\, P(x_i,z;\theta) &= \sum_i \log \int dz\, Q(z;x_i,\phi) \frac{P(x_i,z;\theta)}{Q(z;x_i,\phi)} \\ &\geq \sum_i \int dz\, Q(z;x_i,\phi) \log \frac{P(x_i,z;\theta)}{Q(z;x_i,\phi)} \tag{1}\label{eq:1} \\ &:= \sum_i J(x_i;\theta,\phi), \end{align}\]

where \(Q(z;x_i,\phi)\) is an arbitrary probability density over \(z\) (i.e., \(Q(z;x_i,\phi)\) is normalized) that is parametrized by \(\phi\), and the inequality follows from Jensen’s Inequality (the log is concave, so the log of an average is at least the average of the log). We would like to pick a \(\phi\) that results in a \(Q(z;x_i,\phi)\) that maximizes \(J\). (Note that we use a slightly different notation compared to the Raj lecture: \(z\) here instead of \(h\); \(x_i\) instead of \(o\).) \(J\) is called the ELBO, or evidence lower bound, because the sum of the ELBOs is a lower bound on the sum of the log-likelihoods (the “evidence”). We see that for a fixed \(\phi\), the summed “evidence” and the summed ELBOs are each a function of \(\theta\).
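The bound in Eq. (1) is easy to check numerically on a tiny example. The sketch below assumes a hypothetical model with a binary latent \(z \in \{0,1\}\), so the integral over \(z\) becomes a two-term sum; the prior, the component means, and the function names are illustrative choices of mine. It evaluates the gap between the log evidence and the ELBO for several arbitrary choices of \(Q\).

```python
import math

# Hypothetical toy model with a binary latent z in {0, 1}:
# P(z) = prior[z] and P(x | z) = Normal(x; means[z], 1).
prior = [0.3, 0.7]
means = [-1.0, 2.0]


def joint(x, z):
    """P(x, z) for the toy model above."""
    density = math.exp(-0.5 * (x - means[z]) ** 2) / math.sqrt(2.0 * math.pi)
    return prior[z] * density


def log_evidence(x):
    """log P(x), where the integral over z is a sum over its two values."""
    return math.log(sum(joint(x, z) for z in (0, 1)))


def elbo(x, q0):
    """ELBO for Q(z=0) = q0 and Q(z=1) = 1 - q0, with 0 < q0 < 1."""
    q = [q0, 1.0 - q0]
    return sum(q[z] * math.log(joint(x, z) / q[z]) for z in (0, 1))


x_obs = 0.0
q_values = (0.1, 0.3, 0.5, 0.7, 0.9)
gaps = [log_evidence(x_obs) - elbo(x_obs, q0) for q0 in q_values]
for q0, gap in zip(q_values, gaps):
    print(f"Q(z=0) = {q0:.1f}: log evidence minus ELBO = {gap:.4f}")
```

Every gap is non-negative, and how large it is depends on how well the chosen \(Q\) matches the model.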

We now have two knobs to tune: \(J\)’s parameters include \(\theta\), which are parameters (e.g., means and variances of various Gaussians) that specify the joint probability density model \(P(x,z;\theta)\); and \(\phi\), which are parameters that specify the \(Q(z;x,\phi)\) probability density over \(z\). We start with initial values for \(\theta\) and \(\phi\); that enables us to construct the ELBO \(J\); we use \(J\) and tweak the knobs for \(\phi\) and \(\theta\) to find an optimal \(\theta\) that results in the summed “evidence” being close to maximum; and that optimal \(\theta\) is then plugged into our parametrized model \(P(x,z;\theta)\) to give us the joint probability density of \(x\) and \(z\).

There are a lot of symbols flying around here, but if we zoom out, we see that the overall task is to adjust \(\theta\) and \(\phi\) to make the ELBO \(J\), or more precisely the summed ELBOs \(\sum_i J(x_i;\theta,\phi)\), as large as possible. And if we do that in the right way, as we’ll see in the next section, we would be able to ensure that \(P(x,z;\theta)\) would be an accurate maximum likelihood model of the ground truth \(P(x,z)\).

Who said “Greed is Good”?

When we have multiple adjustable parameters in an optimization problem, a common strategy is called “greedy” (also known as coordinate ascent or alternating maximization); here, that means we fix \(\theta\) and optimize \(\phi\), then fix \(\phi\) and optimize \(\theta\), and rinse and repeat until we are close enough. But this only enables us to find a value for \(\theta\) that maximizes the summed ELBOs; it’s not obvious why this value for \(\theta\) would also yield the maximum value for our summed “evidence” (the log of the integral of the parametrized model \(P(x,z;\theta)\)), which is our ultimate goal in maximum likelihood estimation. We know that the summed ELBOs are a lower bound for the summed evidence, but it’s possible that the \(\theta\) that produces the largest summed ELBOs actually leads to a non-maximal summed evidence. This may happen if, e.g., we start at \(\theta_1\) where the old summed evidence is far above the old summed ELBOs, and we end at a new \(\theta_2\) where the new summed ELBOs have slightly increased, but the new summed evidence is only slightly above the new summed ELBOs, so that the new summed evidence is actually lower than the old summed evidence. Intuitively, what happened was that \(\phi\) was chosen badly, which caused the summed ELBOs (as a function of \(\theta\)) to be very different from the summed evidence (as a function of \(\theta\)). What we need is for the summed ELBOs and the summed evidence to be very similar functions of \(\theta\), so that they share the same optimal value for \(\theta\).

But it turns out that this is mostly not a problem, as for any given value of \(\theta\), there exists a function \(Q_{opt}(z;x)\) (whose exact form depends on \(\theta\)) that converts Eq. \(\eqref{eq:1}\) from an inequality to an equality (which means that the evidence and the ELBO are equal, and also that the summed ELBOs are maximized over the choice of \(Q\)). Therefore, if \(Q(z;x,\phi)\)’s dependence on \(\phi\) is very “expressive” and it’s possible to adjust \(\phi\) to make \(Q(z;x,\phi)\) very close to \(Q_{opt}(z;x)\), then that is the \(\phi\) that approximately maximizes the summed ELBOs. (In other words, if we tune \(\phi\) to optimize the summed ELBOs, then that would give us a \(Q(z;x,\phi)\) that is very close to \(Q_{opt}(z;x)\).) Call that summed ELBOs value \(S_1\), and note that \(S_1\) is very close to the value of the summed evidence at the given value of \(\theta\); this means that we do not have the issue, from the prior paragraph, of having the summed ELBOs far below the summed evidence. And then if we fix \(\phi\) and tune \(\theta\) to further maximize the summed ELBOs, the resultant summed ELBOs value will be greater than or equal to \(S_1\). By alternating between tuning \(\phi\) and \(\theta\), we monotonically increase the summed ELBOs. And critically, each time we tune \(\phi\) we also ensure that the summed ELBOs, viewed as a function of \(\theta\), are only slightly smaller than the summed evidence near the current \(\theta\).

How do we know that such a \(Q_{opt}(z;x)\) exists for each given value of \(\theta\)? Because we can plug in \(Q_{opt}(z;x) = P(z \mid x;\theta)\) and some simple math shows that this leads to equality between the LHS and the RHS. Remember that if we fix the value of \(\theta\), then \(P(x,z;\theta)\) is just a joint probability density of the two vectors \(x\) and \(z\), which uniquely determines the conditional probability \(P(z \mid x;\theta)\). If we plug \(Q_{opt}(z;x) = P(z \mid x;\theta)\) into the RHS of Eq. \(\eqref{eq:1}\), we see that the log term becomes \(\log P(x_i;\theta)\), and since that is independent of \(z\) it can be factored out of the integral. The remaining integral is \(1\) because it’s just the normalization of a probability density, which means that the RHS is equal to the LHS.
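One way to spell out that “simple math” for a general \(Q\) is to substitute \(P(x_i,z;\theta) = P(z \mid x_i;\theta)\,P(x_i;\theta)\) into the ELBO. The \(\mathrm{KL}\) symbol below is the standard Kullback-Leibler divergence, which is notation I am adding here and which is not otherwise used in these notes:

\[\begin{align} J(x_i;\theta,\phi) &= \int dz\, Q(z;x_i,\phi) \log \frac{P(z \mid x_i;\theta)\, P(x_i;\theta)}{Q(z;x_i,\phi)} \\ &= \log P(x_i;\theta) \int dz\, Q(z;x_i,\phi) - \int dz\, Q(z;x_i,\phi) \log \frac{Q(z;x_i,\phi)}{P(z \mid x_i;\theta)} \\ &= \log P(x_i;\theta) - \mathrm{KL}\left(Q(z;x_i,\phi) \,\|\, P(z \mid x_i;\theta)\right). \end{align}\]

Since the KL divergence is never negative and is zero exactly when the two densities agree, the gap between the evidence and the ELBO closes when \(Q(z;x_i,\phi) = P(z \mid x_i;\theta)\), and tuning \(\phi\) to raise the ELBO at fixed \(\theta\) is the same as pushing \(Q\) toward that conditional density.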

Just fill in the blanks

And we get a very nice intuitive interpretation if we look more closely at the ELBO.

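Suppose we start with an initial value \(\theta_0\). We optimize \(\phi\) and arrive at \(\phi_0\), such that \(Q(z;x,\phi_0) \approx P(z \mid x;\theta_0)\). (Just to be clear: the reason we don’t just set \(Q(z;x,\phi)\) to be equal to \(P(z \mid x;\theta)\) is that we only know \(P(x,z;\theta)\); getting \(P(z \mid x;\theta) = P(x,z;\theta)/P(x;\theta)\) from it requires \(P(x;\theta)\), which is the same hard integral over \(z\) that we are trying to avoid.)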

Next, we fix \(\phi_0\) (meaning we fix \(Q(z;x,\phi_0) \approx P(z \mid x;\theta_0)\), which, to be clear, is not a function of \(\theta\)) and optimize the ELBO over \(\theta\):

\[\begin{align} \mathrm{argmax}_{\theta} \sum_i \int dz\, Q(z;x_i,\phi_0) \log \frac{P(x_i,z;\theta)}{Q(z;x_i,\phi_0)} &\approx \mathrm{argmax}_{\theta} \sum_i \int dz\, P(z \mid x_i;\theta_0) \log \frac{P(x_i,z;\theta)}{P(z \mid x_i;\theta_0)} \\ &= \mathrm{argmax}_{\theta} \sum_i \int dz\, P(z \mid x_i;\theta_0) \log P(x_i,z;\theta) , \end{align}\]

where we dropped \(P(z \mid x_i;\theta_0)\) from the denominator of the log term because it does not depend on \(\theta\). And now we see from the last line that we are really optimizing the log-likelihood, \(\log P(x_i,z;\theta)\), and “filling in the missing values of \(z\)” by weighting each possible \(z\) value with the conditional probability (assuming the model parametrized by \(\theta_0\)) of \(z\) given \(x_i\). In other words, we take \(\theta_0\) as our best current model, use that to estimate the distribution over \(z\) in order to calculate the log-likelihood, and use the resultant log-likelihoods to let us further optimize \(\theta\). This is a type of “bootstrap” or “fill in the blank”.
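Here is a minimal sketch of one such “fill in the blank” update on a toy problem where everything is computable by hand. The setup is my own assumption for illustration: each observation \(x_i\) is the number of heads in 10 flips of one of two coins, the latent \(z\) says which coin was used (each equally likely a priori), and \(\theta = (p_A, p_B)\) are the unknown head probabilities. Because \(z\) is discrete, the integral over \(z\) is a two-term sum and the posterior \(P(z \mid x_i;\theta_0)\) can be computed exactly, so no separate \(\phi\) is needed.

```python
import math

FLIPS_PER_TRIAL = 10
heads_per_trial = [5, 9, 8, 4, 7]  # illustrative observations x_i


def likelihood(heads, p):
    """P(x_i | z; theta) without the binomial coefficient, which is constant in theta."""
    return p ** heads * (1.0 - p) ** (FLIPS_PER_TRIAL - heads)


def log_evidence(p_a, p_b):
    """Sum of log P(x_i; theta), up to a theta-independent constant."""
    return sum(
        math.log(0.5 * likelihood(h, p_a) + 0.5 * likelihood(h, p_b))
        for h in heads_per_trial
    )


def em_step(p_a, p_b):
    """One update: weight each possible z by P(z | x_i; theta_0), then maximize."""
    # Posterior probability that coin A produced each trial under the current theta_0.
    # The equal prior of 0.5 cancels in this ratio.
    resp_a = [
        likelihood(h, p_a) / (likelihood(h, p_a) + likelihood(h, p_b))
        for h in heads_per_trial
    ]
    # Maximizing the weighted log-likelihood gives weighted head fractions.
    heads_a = sum(r * h for r, h in zip(resp_a, heads_per_trial))
    flips_a = sum(r * FLIPS_PER_TRIAL for r in resp_a)
    heads_b = sum((1.0 - r) * h for r, h in zip(resp_a, heads_per_trial))
    flips_b = sum((1.0 - r) * FLIPS_PER_TRIAL for r in resp_a)
    return heads_a / flips_a, heads_b / flips_b


old_p_a, old_p_b = 0.6, 0.5  # theta_0, an arbitrary starting guess
new_p_a, new_p_b = em_step(old_p_a, old_p_b)
print(f"updated theta: p_A = {new_p_a:.3f}, p_B = {new_p_b:.3f}")
print(f"evidence before: {log_evidence(old_p_a, old_p_b):.4f}")
print(f"evidence after:  {log_evidence(new_p_a, new_p_b):.4f}")
```

Each trial contributes fractionally to both coins, in proportion to how plausible each coin is under \(\theta_0\); that fractional bookkeeping is the weighting by \(P(z \mid x_i;\theta_0)\) in the last line of the equation above.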

Summary of using EM to learn a generative model

Whew, that was still pretty complicated! Here is a final summary:

  1. We focus on an example of a generative model: a joint probability density over \(x\), the space of images, and \(z\), the latent space of classifications or categories
  2. If we had the conditional probability \(P(x\mid z)\), then we could pick a \(z\) and generate samples of \(x\) that fit that category
  3. Unfortunately, we do not know this conditional probability. Instead, we create a parametrized model \(P(x,z;\theta)\), and use a training set of images \(x_i\) to learn the parameter \(\theta\) and the relationship between \(x\) and the latent \(z\). We do this via maximum likelihood estimation (MLE) by adjusting \(\theta\) to maximize the evidence (log-likelihoods).
  4. Unfortunately, MLE usually requires that we have observations for both \(x\) and \(z\), but we only have \(x\). So we have the problem of missing data. So we need to marginalize out \(z\) by integrating over it. Unfortunately, that results in an ugly log of an integral, which is difficult to handle.
  5. So instead we optimize the ELBO, which is a lower bound on the evidence, and we choose to use the ELBO because it enables nice math cleverness that makes things work out. And we have two sets of parameters to optimize over, \(\theta\) and \(\phi\). The name “variational inference” comes about because \(\phi\) is a set of knobs to adjust the parametrized function \(Q(\cdot)\), and this makes one think of the calculus of variations (from physics, Euler-Lagrange equations, etc., even though here we do not use the calculus of variations at all).
  6. And it turns out there is a clever “greedy” way to do the optimization, where we optimize alternately over the parameters \(\phi\) and \(\theta\).
  7. And it turns out that this is mathematically very similar to guessing the most likely values of \(z\) given the observed \(x\), and using that to fill in the missing data (see step 4 above). We do this iteratively. We have some estimate (by optimizing \(\phi\)) of the likely values of \(z\) given \(x\) (the “posterior”) and use that to fill in the missing \(z\)’s, then we use that to update (via MLE and optimizing \(\theta\)) our parametrized model for the joint density of \(x\) and \(z\), then we use that to update (by further optimizing \(\phi\)) our estimate of the likely values of \(z\) given \(x\), and so on.

EM is all about filling in the missing \(z\)’s in the latent space by “making it up” according to the current hypothesis for \(z\), running MLE, then rinse and repeat.
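The whole loop can be sketched end to end on a simple model. The example below is a hedged illustration rather than the method from the lecture: it assumes a one-dimensional mixture of two Gaussians, where the latent \(z\) is the component label, \(\theta\) is the set of mixture weights, means, and variances, and the posterior over \(z\) is available exactly (so the step of optimizing \(\phi\) reduces to computing that posterior). The synthetic data, starting values, and function names are all made up for the demonstration.

```python
import math
import random


def normal_pdf(x, mean, var):
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def log_evidence(data, weights, means, variances):
    """Sum over the data of log P(x_i; theta), with z summed out."""
    return sum(
        math.log(sum(w * normal_pdf(x, m, v) for w, m, v in zip(weights, means, variances)))
        for x in data
    )


def em_step(data, weights, means, variances):
    # "E" part: posterior P(z | x_i; theta) for every data point (the soft fill-in).
    resp = []
    for x in data:
        joint = [w * normal_pdf(x, m, v) for w, m, v in zip(weights, means, variances)]
        total = sum(joint)
        resp.append([j / total for j in joint])
    # "M" part: maximize the posterior-weighted log-likelihood in closed form.
    new_weights, new_means, new_variances = [], [], []
    for k in range(len(weights)):
        mass = sum(r[k] for r in resp)
        mean_k = sum(r[k] * x for r, x in zip(resp, data)) / mass
        var_k = sum(r[k] * (x - mean_k) ** 2 for r, x in zip(resp, data)) / mass
        new_weights.append(mass / len(data))
        new_means.append(mean_k)
        new_variances.append(max(var_k, 1e-6))  # guard against a collapsing component
    return new_weights, new_means, new_variances


# Synthetic observations: we keep only x and throw away which component made it.
rng = random.Random(0)
data = [rng.gauss(-2.0, 1.0) for _ in range(200)] + [rng.gauss(3.0, 1.0) for _ in range(200)]

weights, means, variances = [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0]  # arbitrary start
history = [log_evidence(data, weights, means, variances)]
for _ in range(50):
    weights, means, variances = em_step(data, weights, means, variances)
    history.append(log_evidence(data, weights, means, variances))

print("learned means:", [round(m, 2) for m in sorted(means)])
print("evidence went from", round(history[0], 1), "to", round(history[-1], 1))
```

Tracking `history` is a handy sanity check when experimenting with this kind of loop: with exact posteriors and exact maximization, the evidence should never go down from one iteration to the next.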

