A step-by-step explanation of how EM (expectation-maximization) is used for training a generative AI model
These notes are based on CMU Prof. Bhiksha Raj’s lecture video on generative models and variational autoencoders and his lecture slides, specifically the section about expectation maximization and variational inference. Prof. Raj gives a great intuitive explanation of the principles behind EM (expectation maximization), so I highly recommend watching the lecture video. But the overall topic is complex and there are many moving pieces, so here I will try to condense it down to the key principles behind how this works. I hope these notes will be useful for both myself and others in the future.
Image space and latent space
Let \(x\) be a vector in a space of images, and \(z\) be a vector in a latent space of image classifications.
For example, if each image has \(n^2=n \cdot n\) pixels and each pixel has 3 numbers representing RGB values, then \(x\) could be a vector of \(3n^2\) numbers.
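To make this concrete, here is a tiny sketch (with a made-up image size) of how an \(n \times n\) RGB image becomes a single vector of \(3n^2\) numbers:

```python
import numpy as np

n = 4                            # hypothetical image side length
image = np.random.rand(n, n, 3)  # n x n pixels, 3 RGB channels, values in [0, 1)
x = image.reshape(-1)            # flatten to a vector in image space

print(x.shape)                   # (48,), i.e., 3 * n**2 numbers
```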
The conditional distribution \(P(x \mid z)\) is a probability distribution over \(x\), with nonzero probability on all \(x\) that satisfy the constraint \(z\). The peak of this distribution is at values of \(x\) that strongly satisfy \(z\); the tails are at values of \(x\) that only weakly satisfy \(z\).
For example, suppose \(z_0\) represents the category “rose”. Then an \(x_1\) which is an image of a rose would strongly satisfy \(z_0\), while an \(x_2\) that is a begonia would weakly satisfy \(z_0\), and an \(x_3\) that is a cat would not satisfy \(z_0\) at all. This means that \(P(x_1 \mid z_0) > P(x_2 \mid z_0) > P(x_3 \mid z_0)\).
\(P(x,z)\) is the joint probability density of \(x\) and \(z\). If we sample \((x,z) \sim P(x,z)\), then each \((x,z)\) sample would be an image together with its latent classification. The image and its classification would be randomly chosen, with the most probable image/classification pair most frequently sampled.
If we instead want to generate images with a specific classification (e.g., “rose”), we need to fix \(z\) and randomly sample \(x \sim P(x \mid z)\).
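To make the sampling story concrete, here is a minimal sketch with a toy stand-in for \(P(x,z)\): \(z\) is one of three made-up categories and \(x\) is a single scalar instead of an image. All the numbers are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

pz    = np.array([0.5, 0.3, 0.2])   # P(z): prior over 3 made-up categories
mu    = np.array([-2.0, 0.0, 3.0])  # mean of x for each category
sigma = np.array([0.5, 1.0, 0.7])   # std of x for each category

# Sample (x, z) ~ P(x, z): draw z ~ P(z), then x ~ P(x | z).
z = rng.choice(3, p=pz)
x = rng.normal(mu[z], sigma[z])
print("joint sample:", x, z)

# Generate "images" of one fixed category: fix z = z0 and sample x ~ P(x | z0).
z0 = 1
print("conditional samples:", rng.normal(mu[z0], sigma[z0], size=5))
```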
Confusion: do we need to know \(P(z)\)?
We know that the conditional probability is given by \(P(x \mid z) = P(x,z)/P(z)\). So it appears that in order to sample from the conditional probability, we not only need to know the joint density \(P(x,z)\), but also need to know \(P(z)\). But if we have already decided to generate rose images and have already chosen \(z=z_0\), what does \(P(z)\) mean? And what if \(P(z)\) is difficult to evaluate (as it usually is)?
A quick way to think about this is to realize that \(P(x \mid z)\) is a probability density of \(x\), and therefore it must normalize to \(1\) when we integrate over all values of \(x\). The \(P(z)\) in the denominator is only there to ensure that normalization. But if we use a sampling method that does not require normalization (e.g., MCMC), then we don’t care about the denominator \(P(z)\), and a model for either \(P(x,z)\) or \(P(x \mid z)\) is sufficient. (But for a deeper dive into the subtleties here, see this post.)
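As a rough sketch of that idea, here is a random-walk Metropolis sampler that draws \(x \sim P(x \mid z_0)\) while only ever evaluating the joint density at the fixed \(z_0\); the unknown \(P(z_0)\) cancels in the acceptance ratio, which is the whole point. The toy joint density and all numbers are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

def joint_density(x, z0):
    """Toy joint P(x, z0) = P(z0) * N(x; mu_z0, sigma_z0), used unnormalized."""
    pz    = np.array([0.5, 0.3, 0.2])
    mu    = np.array([-2.0, 0.0, 3.0])
    sigma = np.array([0.5, 1.0, 0.7])
    return pz[z0] * np.exp(-0.5 * ((x - mu[z0]) / sigma[z0]) ** 2) / sigma[z0]

z0, x = 1, 0.0                           # fixed category, initial state of the chain
samples = []
for _ in range(10_000):
    x_prop = x + rng.normal(0.0, 0.5)    # symmetric random-walk proposal
    # Accept with probability min(1, P(x_prop, z0) / P(x, z0)); note that the
    # normalizer P(z0) would cancel in this ratio, so it is never needed.
    if rng.random() < joint_density(x_prop, z0) / joint_density(x, z0):
        x = x_prop
    samples.append(x)

print("empirical mean/std:", np.mean(samples), np.std(samples))  # ~0 and ~1 for z0 = 1
```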
So we start with a parametrized model \(P(x,z;\theta)\), designed so that any observed data distribution can be modeled with the appropriate settings for \(\theta\). To reiterate: this is our hypothesis for the probability density, fully specified once the parameter vector \(\theta\) is specified.
Also to clarify: the ground truth distribution is \(P(x,z)\) (note that it does not have the \(\theta\) parameters), but unfortunately we do not know the ground truth; the model to be learned is \(P(x,z;\theta)\), which is our estimate/hypothesis for the ground truth. We apply maximum likelihood estimation (MLE) to learn the optimal \(\theta\), by maximizing the sum of log-likelihoods evaluated over the observed data.
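As a toy illustration of what a parametrized model looks like, the sketch below takes \(z\) to be discrete and \(x\) to be a scalar, with \(\theta\) collecting the weights, means, and standard deviations of a small Gaussian mixture. This particular parametrization is an assumption made purely for illustration:

```python
import numpy as np

def log_joint(x, z, theta):
    """log P(x, z; theta) = log pi_z + log N(x; mu_z, sigma_z)."""
    pi, mu, sigma = theta
    return (np.log(pi[z])
            - 0.5 * ((x - mu[z]) / sigma[z]) ** 2
            - np.log(sigma[z] * np.sqrt(2.0 * np.pi)))

# Each setting of theta picks out one fully specified joint density.
theta = (np.array([0.5, 0.5]),    # mixture weights pi_z
         np.array([-1.0, 2.0]),   # means mu_z
         np.array([1.0, 1.0]))    # stds sigma_z
print(log_joint(x=0.3, z=1, theta=theta))
```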
Not having all the data is ok
If we have observations of vectors \(x\) and \(z\), then we can optimize \(\theta\) to maximize the sum of the log-likelihoods, where the log-likelihood is the log of the parametrized model \(P(x,z;\theta)\) evaluated at the observed values for \(x\) and \(z\).
But since we only observe \(x\) and don’t have observations of \(z\), we need to use the log-likelihood of \(P(x;\theta)\) and feed the observed \(x\) values into that. We get \(P(x;\theta)\) by integrating over \(z\):
\[\begin{align} P(x;\theta) &= \int dz\, P(x,z;\theta) \end{align}\]
Learning the model
So we need to perform the following optimization:
\[\begin{align} \theta_{opt} & = \mathrm{argmax}_{\theta} \sum_i \log P(x_i;\theta) \\ & = \mathrm{argmax}_{\theta} \sum_i \log \int dz\, P(x_i,z;\theta), \end{align}\]where the sum over \(i\) is the sum over the observed data \(x_i\).
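For a toy model with a discrete latent \(z\) the integral over \(z\) is just a sum, so this objective can be evaluated directly with a log-sum-exp; the difficulty discussed next is mainly about continuous, high-dimensional \(z\). A minimal sketch, with illustrative parameters:

```python
import numpy as np
from scipy.special import logsumexp

def log_joint(x, theta):
    """log P(x, z; theta) for every discrete z at once (returns an array over z)."""
    pi, mu, sigma = theta
    return (np.log(pi) - 0.5 * ((x - mu) / sigma) ** 2
            - np.log(sigma * np.sqrt(2.0 * np.pi)))

def evidence(data, theta):
    """sum_i log P(x_i; theta) = sum_i log sum_z P(x_i, z; theta)."""
    return sum(logsumexp(log_joint(x, theta)) for x in data)

theta = (np.array([0.6, 0.4]), np.array([-1.0, 2.0]), np.array([1.0, 0.5]))
data = np.array([-1.2, -0.8, 2.1, 1.9, 0.1])
print(evidence(data, theta))
```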
The log of an integral may be difficult to maximize because taking a derivative would move a copy of the integral into the denominator, and this integral depends on the specific \(x_i\). So we would need to evaluate this integral for each \(x_i\), and the problem is that the integral could be over a high-dimensional \(z\)-space. (BTW, it appears that for some problems this integral is actually tractable, so there is some subtlety here.) But in general this is difficult to maximize directly, so we go with the following inequality:
\[\begin{align} \sum_i \log \int dz\, P(x_i,z;\theta) &= \sum_i \log \int dz\, Q(z;x_i,\phi) \frac{P(x_i,z;\theta)}{Q(z;x_i,\phi)} \\ &\geq \sum_i \int dz\, Q(z;x_i,\phi) \log \frac{P(x_i,z;\theta)}{Q(z;x_i,\phi)} \tag{1}\label{eq:1} \\ &:= \sum_i J(x_i;\theta,\phi), \end{align}\]where \(Q(z;x_i,\phi)\) is an arbitrary probability density of \(z\) (i.e., \(Q(z;x_i,\phi)\) is normalized), parametrized by \(\phi\), and the inequality follows from Jensen’s inequality. We would like to pick a \(\phi\) that results in a \(Q(z;x_i,\phi)\) that maximizes \(J\). (Note that we use slightly different notation compared to the Raj lecture: \(z\) here instead of \(h\); \(x_i\) instead of \(o\).) \(J\) is called the ELBO, or evidence lower bound, because it is a lower bound on the log-likelihood (the “evidence”).
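Here is a small numerical sketch of the ELBO for the same discrete-\(z\) toy model: \(Q\) is just a probability vector over \(z\) (so \(\phi\) is that vector itself), and Jensen’s inequality shows up as ELBO \(\leq\) evidence for any choice of \(Q\):

```python
import numpy as np
from scipy.special import logsumexp

def log_joint(x, theta):
    """log P(x, z; theta) for every discrete z at once."""
    pi, mu, sigma = theta
    return (np.log(pi) - 0.5 * ((x - mu) / sigma) ** 2
            - np.log(sigma * np.sqrt(2.0 * np.pi)))

def elbo(x, theta, q):
    """J(x; theta, phi) = sum_z Q(z) * log( P(x, z; theta) / Q(z) )."""
    return np.sum(q * (log_joint(x, theta) - np.log(q)))

theta = (np.array([0.6, 0.4]), np.array([-1.0, 2.0]), np.array([1.0, 0.5]))
x = 0.3
q = np.array([0.7, 0.3])             # an arbitrary normalized Q(z)

print("ELBO    :", elbo(x, theta, q))
print("evidence:", logsumexp(log_joint(x, theta)))   # always >= the ELBO
```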
We now have two knobs to tune: \(J\)’s parameters include \(\theta\), which are parameters (e.g., means and variances of various Gaussians) that specify the joint probability density model \(P(x,z;\theta)\); and \(\phi\), which are parameters that specify the \(Q(z;x,\phi)\) probability density of \(z\). We start with a specific parametrized model \(P(x,z;\theta)\), where the optimal \(\theta\) is not yet known; that enables us to construct the ELBO \(J\); we use \(J\) to find an optimal \(\theta\) that results in the “evidence” being close to maximum; and that optimal \(\theta\) is then plugged into our parametrized model \(P(x,z;\theta)\) for the joint probability density of \(x\) and \(z\).
There are a lot of symbols flying around here, but if we zoom out, we see that the overall task is to adjust \(\theta\) and \(\phi\) to make the ELBO \(J\), or more precisely the summed ELBO \(\sum_i J(x_i;\theta,\phi)\), as large as possible. And if we do that in the right way, as we’ll see in the next section, we would be able to ensure that \(P(x,z;\theta)\) would be an accurate maximum likelihood model of the ground truth \(P(x,z)\).
Who said “Greed is Good”?
When we have multiple adjustable parameters in an optimization problem, a common strategy is called “greedy”; here, that means we fix \(\theta\) and optimize \(\phi\), then fix \(\phi\) and optimize \(\theta\), and rinse and repeat until we are close enough. But this only enables us to find a value for \(\theta\) that maximizes the ELBO \(J\); it’s not obvious why this value for \(\theta\) would also yield the maximum value for our “evidence” (the log of the integral of the parametrized model \(P(x,z;\theta)\)), which is our ultimate goal in maximum likelihood estimation. We know that the ELBO is a lower bound for the evidence, but it’s possible that the \(\theta\) that produces the largest ELBO actually leads to a non-maximal evidence. This would happen if, e.g., we start at \(\theta_1\) where the old ELBO is far below the old evidence, and we end at a new \(\theta_2\) where the new ELBO has slightly increased, but the new evidence is only slightly above the new ELBO, so that the new evidence is actually lower than the old evidence.
But it turns out that this is mostly not a problem, as there exists a function \(Q_{opt}(z;x,\theta)\) that converts Eq. \(\eqref{eq:1}\) from an inequality to an equality (which means that the evidence and the ELBO are equal, and also that \(J\) is maximized). Therefore, if \(Q(z;x,\phi)\)’s dependence on \(\phi\) is very “expressive” and it’s possible to adjust \(\phi\) to make \(Q(z;x,\phi)\) very close to \(Q_{opt}(z;x,\theta)\), then that is the \(\phi\) that approximately maximizes \(J\). Call the resulting value of \(J\) \(J_1\), and note that \(J_1\) is very close to the value of the evidence (this means that we do not have the issue, from the prior paragraph, of having the ELBO far below the evidence). And then if we fix \(\phi\) and maximize \(J\) further over \(\theta\), the resultant \(J\) will be greater than or equal to \(J_1\), and more importantly the new evidence will be greater than or equal to the old evidence.
How do we know that such a \(Q_{opt}(z;x,\theta)\) exists? Because we can plug in \(Q_{opt}(z;x,\theta) = P(z \mid x;\theta)\) and show that this leads to equality between the LHS and the RHS. Remember that if we fix the value of \(\theta\), then \(P(x,z;\theta)\) is just a joint probability density of the two vectors \(x\) and \(z\), which uniquely determines the conditional probability \(P(z \mid x;\theta)\). If we plug \(Q_{opt}(z;x,\theta) = P(z \mid x;\theta)\) into the RHS of Eq. \(\eqref{eq:1}\), we see that the log term becomes \(\log P(x_i;\theta)\), and since that is independent of \(z\) it can be factored out of the integral, which means the integral is \(1\) because it’s just the normalization of a probability density, which means that the RHS is equal to the LHS.
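A quick numerical check of this, again in the discrete-\(z\) toy model: with the posterior \(P(z \mid x;\theta)\) plugged in as \(Q\), the ELBO equals the evidence (up to floating-point error):

```python
import numpy as np
from scipy.special import logsumexp

def log_joint(x, theta):
    """log P(x, z; theta) for every discrete z at once."""
    pi, mu, sigma = theta
    return (np.log(pi) - 0.5 * ((x - mu) / sigma) ** 2
            - np.log(sigma * np.sqrt(2.0 * np.pi)))

def posterior(x, theta):
    """P(z | x; theta) = P(x, z; theta) / sum_z P(x, z; theta)."""
    lj = log_joint(x, theta)
    return np.exp(lj - logsumexp(lj))

def elbo(x, theta, q):
    return np.sum(q * (log_joint(x, theta) - np.log(q)))

theta = (np.array([0.6, 0.4]), np.array([-1.0, 2.0]), np.array([1.0, 0.5]))
x = 0.3
q_opt = posterior(x, theta)

print("ELBO with Q = posterior:", elbo(x, theta, q_opt))
print("evidence               :", logsumexp(log_joint(x, theta)))  # the same value
```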
Just fill in the blanks
And we get a very nice intuitive interpretation if we look more closely at the ELBO.
Suppose we start with an initial value \(\theta_0\). We optimize \(\phi\) and arrive at \(\phi_0\), such that \(Q(z;x,\phi_0) \approx P(z \mid x;\theta_0)\). (Just to be clear: the reason we don’t simply set \(Q(z;x,\phi)\) equal to \(P(z \mid x;\theta)\) is that we only know the joint \(P(x,z;\theta)\); obtaining the conditional \(P(z \mid x;\theta)\) from it would require computing the normalization \(P(x;\theta) = \int dz\, P(x,z;\theta)\), which is exactly the integral we were trying to avoid.)
Next, we fix \(\phi_0\) (meaning we fix \(Q(z;x,\phi_0) \approx P(z \mid x;\theta_0)\), which is not a function of \(\theta\)) and optimize the ELBO over \(\theta\):
\[\begin{align} \mathrm{argmax}_{\theta} \sum_i \int dz\, Q(z;x_i,\phi_0) \log \frac{P(x_i,z;\theta)}{Q(z;x_i,\phi_0)} &\approx \mathrm{argmax}_{\theta} \sum_i \int dz\, P(z \mid x_i;\theta_0) \log \frac{P(x_i,z;\theta)}{P(z \mid x_i;\theta_0)} \\ &= \mathrm{argmax}_{\theta} \sum_i \int dz\, P(z \mid x_i;\theta_0) \log P(x_i,z;\theta) , \end{align}\]where we dropped \(P(z \mid x_i;\theta_0)\) from the denominator of the log term because it does not depend on \(\theta\), so removing it does not change the argmax. And now we see from the last line that we are really optimizing the log-likelihood of the full joint probability density \(P(x,z;\theta)\), and “filling in the missing values of \(z\)” by weighting each possible \(z\) value with the conditional probability (under the model parametrized by \(\theta_0\)) of \(z\) given \(x_i\).
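Here is what one such “fill in the blanks” update looks like for the toy mixture: the E-step computes the weights \(P(z \mid x_i;\theta_0)\), and the M-step maximizes the weighted log-likelihood, which for a Gaussian mixture has a closed form. These are the standard mixture-model updates, shown purely as an illustration:

```python
import numpy as np
from scipy.special import logsumexp

def log_joint(x, theta):
    """log P(x_i, z; theta) for all data points and all z; shape (N, K)."""
    pi, mu, sigma = theta
    return (np.log(pi) - 0.5 * ((x[:, None] - mu) / sigma) ** 2
            - np.log(sigma * np.sqrt(2.0 * np.pi)))

def e_step(x, theta):
    """Fill in the missing z's: r[i, z] = P(z | x_i; theta_0)."""
    lj = log_joint(x, theta)
    return np.exp(lj - logsumexp(lj, axis=1, keepdims=True))

def m_step(x, r):
    """theta maximizing sum_i sum_z r[i, z] * log P(x_i, z; theta)."""
    nk = r.sum(axis=0)                  # effective count of points per category
    pi = nk / len(x)
    mu = (r * x[:, None]).sum(axis=0) / nk
    sigma = np.sqrt((r * (x[:, None] - mu) ** 2).sum(axis=0) / nk)
    return pi, mu, sigma

x = np.array([-1.3, -0.9, -1.1, 2.2, 1.8, 2.0])   # observed data, z's missing
theta0 = (np.array([0.5, 0.5]), np.array([-2.0, 1.0]), np.array([1.0, 1.0]))
theta1 = m_step(x, e_step(x, theta0))             # one EM update of theta
print(theta1)
```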
Summary of using EM to learn a generative model
Whew, that was still pretty complicated! Here is a final summary:
- We focus on an example of a generative model: a joint probability density over \(x\), the space of images, and \(z\), the latent space of classifications or categories
- If we had the conditional probability \(P(x\mid z)\), then we could pick a \(z\) and generate samples of \(x\) that fit that category
- Unfortunately, we do not know this conditional probability. Instead, we create a parametrized model \(P(x,z;\theta)\), and use a training set of images \(x_i\) to learn the parameters \(\theta\) and the relationship between \(x\) and the latent \(z\). We do this via maximum likelihood estimation (MLE), by adjusting \(\theta\) to maximize the evidence (the log-likelihood summed over the training data).
- Unfortunately, MLE usually requires that we have observations for both \(x\) and \(z\), but we only have \(x\): we have a missing-data problem. So we need to marginalize out \(z\) by integrating over it. Unfortunately, that results in an ugly log of an integral, which is difficult to handle.
- So instead we optimize the ELBO, which is a lower bound on the evidence, and we chose to use the ELBO because it enables nice math cleverness that makes things work out. And we have two sets of parameters to optimize over, \(\theta\) and \(\phi\). The name “variational inference” comes about because \(\phi\) is a set of knobs to adjust the function \(Q(\cdot)\), and this makes one think of the calculus of variations (from physics, Euler-Lagrange equations, etc., even though here we do not use the calculus of variations at all).
- And it turns out there is a clever “greedy” way to do the optimization, where we optimize alternately over the parameters \(\phi\) and \(\theta\).
- And it turns out that this is mathematically very similar to guessing the most likely values of \(z\) given the observed \(x\), and using that to fill in the missing data (see the missing-data point above). We do this iteratively: we have some estimate (by optimizing \(\phi\)) of the likely values of \(z\) given \(x\) and use that to fill in the missing \(z\)’s; then we use that to update (via MLE and optimizing \(\theta\)) our parametrized model for the joint density of \(x\) and \(z\); then we use that to update (by further optimizing \(\phi\)) our estimate of the likely values of \(z\) given \(x\); and so on (a minimal end-to-end sketch of this loop is given below).
EM is all about filling in the missing \(z\)’s in the latent space by “making them up” according to the current hypothesis for \(z\), running MLE, then rinse and repeat.
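To close the loop, here is a minimal end-to-end EM sketch on synthetic data for the same toy setup (discrete latent \(z\), scalar \(x\)): alternate the E-step (fill in the \(z\)’s with \(P(z \mid x_i;\theta)\)) and the M-step (weighted MLE), and watch the evidence climb. Everything about this model is an illustrative assumption, not the specific setup from the lecture:

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)

def log_joint(x, theta):
    """log P(x_i, z; theta) for all data points and all z; shape (N, K)."""
    pi, mu, sigma = theta
    return (np.log(pi) - 0.5 * ((x[:, None] - mu) / sigma) ** 2
            - np.log(sigma * np.sqrt(2.0 * np.pi)))

def em_step(x, theta):
    # E-step: responsibilities P(z | x_i; theta), i.e., the filled-in z's.
    lj = log_joint(x, theta)
    r = np.exp(lj - logsumexp(lj, axis=1, keepdims=True))
    # M-step: maximize the responsibility-weighted log-likelihood.
    nk = r.sum(axis=0)
    pi = nk / len(x)
    mu = (r * x[:, None]).sum(axis=0) / nk
    sigma = np.sqrt((r * (x[:, None] - mu) ** 2).sum(axis=0) / nk)
    return pi, mu, sigma

# Synthetic observations: we keep only x and throw away the true z's (missing data).
true_z = rng.choice(2, p=[0.4, 0.6], size=500)
x = rng.normal(np.array([-2.0, 1.5])[true_z], np.array([0.7, 1.0])[true_z])

theta = (np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0]))
for it in range(50):
    theta = em_step(x, theta)
    evidence = logsumexp(log_joint(x, theta), axis=1).sum()
    if it % 10 == 0:
        print(f"iter {it:2d}  evidence {evidence:.2f}")
print("learned theta:", theta)
```

Each iteration cannot decrease the evidence, which is exactly the guarantee argued for in the “greedy” section above.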