Expectation Maximization Sketch

By Dontloo


Say we have observed data \(X\), a latent variable \(Z\), and a parameter \(\theta\), and we want to maximize the log-likelihood \(\log p(X|\theta)\). This is sometimes hard, for example because there is no closed-form solution, the gradient is difficult to compute, or \(\theta\) must satisfy complicated constraints.

If the joint log-likelihood \(\log p(X, Z|\theta)\) can be maximized more easily, we can turn to the Expectation Maximization (EM) algorithm for help. There are several ways to formulate the EM algorithm, and this post goes through a few of them.

Joint Log-likelihood

The basic idea is just to optimize the joint log-likelihood \(\log p(X, Z|\theta)\) instead of the data log-likelihood \(\log p(X|\theta)\). But since the true values of latent variables \(Z\) are unknown, we need to estimate a posterior distribution \(p(z|x, \theta)\) for each data point \(x\), then maximize the expected log-likelihood over the posterior
\[\sum_x E_{p_{z|x}}[\log p(x,z|\theta)].\]

The optimization alternates between two steps. The E-step computes the expectation under the current parameter \(\theta'\) \[\sum_x \sum_{z} p(z|x, \theta') \log p(x,z|\theta) = Q(\theta|\theta').\] The M-step finds the new parameter \(\theta\) that maximizes \(Q(\theta|\theta')\). It turns out that each iteration never decreases the data log-likelihood \(\log p(X|\theta)\), so the method typically converges to a local maximum (more precisely, a stationary point), as will be shown in later sections.
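To make the two steps concrete, here is a minimal sketch for a one-dimensional mixture of two Gaussians. The model, the synthetic data, the starting values and all function names are assumptions chosen for illustration; nothing here is specific to the derivation above beyond the E-step/M-step structure.

```python
import math
import random

def normal_pdf(x, mu, var):
    """Density of a univariate Gaussian N(x | mu, var)."""
    return math.exp(-(x - mu) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def log_likelihood(xs, weights, mus, variances):
    """Data log-likelihood: sum_x log sum_z p(x, z | theta)."""
    return sum(
        math.log(sum(w * normal_pdf(x, m, v)
                     for w, m, v in zip(weights, mus, variances)))
        for x in xs
    )

def e_step(xs, weights, mus, variances):
    """Posterior p(z = k | x, theta') for every data point x."""
    resp = []
    for x in xs:
        joint = [w * normal_pdf(x, m, v)
                 for w, m, v in zip(weights, mus, variances)]
        total = sum(joint)
        resp.append([j / total for j in joint])
    return resp

def m_step(xs, resp):
    """Closed-form maximizer of Q(theta | theta') for a Gaussian mixture."""
    n, k = len(xs), len(resp[0])
    weights, mus, variances = [], [], []
    for j in range(k):
        nj = sum(r[j] for r in resp)  # effective number of points in component j
        mu = sum(r[j] * x for r, x in zip(resp, xs)) / nj
        var = sum(r[j] * (x - mu) ** 2 for r, x in zip(resp, xs)) / nj
        weights.append(nj / n)
        mus.append(mu)
        variances.append(var)
    return weights, mus, variances

def run_em(xs, weights, mus, variances, n_iter=50):
    """Alternate E- and M-steps, recording the data log-likelihood."""
    history = [log_likelihood(xs, weights, mus, variances)]
    for _ in range(n_iter):
        resp = e_step(xs, weights, mus, variances)
        weights, mus, variances = m_step(xs, resp)
        history.append(log_likelihood(xs, weights, mus, variances))
    return (weights, mus, variances), history

# Synthetic data from two well-separated Gaussians (illustrative only).
rng = random.Random(0)
data = ([rng.gauss(-2.0, 1.0) for _ in range(200)]
        + [rng.gauss(3.0, 1.0) for _ in range(200)])
params, history = run_em(data, [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0])
```

The `history` list holds \(\log p(X|\theta)\) after each iteration and should be non-decreasing up to floating-point error. The sketch omits safeguards a practical implementation would need, such as working in log space and handling a component that collapses onto a single point.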

Evidence Lower Bound (ELBO)

One way to derive EM formally is via constructing the evidence lower bound of \(\log p(X|\theta)\) using Jensen’s inequality \[\log p(X|\theta) = \sum_x \log p(x|\theta)\] \[= \sum_x \log \sum_z p(x, z|\theta)\] \[= \sum_x \log \sum_z q_{z|x}(z) \frac{p(x, z|\theta)}{q_{z|x}(z)}\] \[\geq \sum_x \sum_z q_{z|x}(z) \log \frac{p(x, z|\theta)}{q_{z|x}(z)}\] where \(q_{z|x}(z)\) is an arbitrary distribution over the latent variable associated with data point \(x\).

At the E-step, we keep \(\theta\) fixed and find the \(q\) that makes the equality hold, i.e. the \(q\) that maximizes the bound. Since \(q\) must be a valid probability distribution, the problem becomes, for each data point \(x\), \[\max_{q_{z|x}} \sum_z q_{z|x}(z) \log \frac{p(x, z|\theta)}{q_{z|x}(z)}\]
\[\text{s.t.}\quad q_{z|x}(z)\geq 0, \quad \sum_z q_{z|x}(z) = 1.\] As the previous section suggests, the solution is \(q_{z|x}(z) = p(z|x, \theta)\). Specifically, if \(z\) is a discrete variable, this can be solved using Lagrange multipliers; see this tutorial by Justin Domke (my teacher ;).
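For the discrete case, the Lagrange-multiplier argument can be sketched as follows. The multiplier \(\lambda\) is notation introduced here, and the non-negativity constraint is ignored at first because the solution turns out to satisfy it automatically. Form the Lagrangian \[L(q, \lambda) = \sum_z q_{z|x}(z) \log p(x, z|\theta) - \sum_z q_{z|x}(z) \log q_{z|x}(z) + \lambda \Big(\sum_z q_{z|x}(z) - 1\Big)\] and set its derivative w.r.t. each \(q_{z|x}(z)\) to zero \[\log p(x, z|\theta) - \log q_{z|x}(z) - 1 + \lambda = 0 \quad\Rightarrow\quad q_{z|x}(z) = e^{\lambda - 1} p(x, z|\theta).\] The normalization constraint then fixes the constant \[e^{\lambda - 1} = \frac{1}{\sum_z p(x, z|\theta)} = \frac{1}{p(x|\theta)} \quad\Rightarrow\quad q_{z|x}(z) = \frac{p(x, z|\theta)}{p(x|\theta)} = p(z|x, \theta).\] Since the objective is concave in \(q\), this stationary point is the maximum.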

At the M-step we maximize over \(\theta\) while keeping \(q_{z|x}\) fixed. \[\sum_x \sum_z q_{z|x}(z) \log \frac{p(x, z|\theta)}{q_{z|x}(z)} \] \[= \sum_x \sum_z q_{z|x}(z) \log p(x, z|\theta) - \sum_x \sum_z q_{z|x}(z) \log q_{z|x}(z) \] \[= Q(\theta|\theta') + \sum_x H(q_{z|x}) \] With \(q_{z|x}(z) = p(z|x, \theta')\) from the E-step, the first term is exactly \(Q(\theta|\theta')\). The second term, the entropy \(H(q_{z|x})\), does not depend on \(\theta\) once \(q_{z|x}\) is fixed, so we only need to optimize \(Q(\theta|\theta')\), which is in line with the previous formulation.

So the M-step maximizes the lower bound w.r.t. \(\theta\), and the E-step sets a new, tight lower bound at the current value of \(\theta\).

Latent Distribution

Let’s now see how to obtain the lower bound from the data likelihood without using Jensen’s inequality. For simplicity, the derivation is given for a single data point \(x\); the first step uses \(\sum_z q_{z|x}(z) = 1\). \[\log p(x|\theta) = \sum_z q_{z|x}(z) \log p(x|\theta) \] \[= \sum_z q_{z|x}(z) \log \frac{p(x,z|\theta)}{p(z|x,\theta)} \] \[= \sum_z q_{z|x}(z) \log \frac{p(x,z|\theta)q_{z|x}(z)}{p(z|x,\theta)q_{z|x}(z)} \] \[= \sum_z q_{z|x}(z) \log \frac{p(x,z|\theta)}{q_{z|x}(z)} - \sum_z q_{z|x}(z) \log \frac{p(z|x,\theta)}{q_{z|x}(z)}\] \[= F(q_{z|x}, \theta) + D_{KL}(q_{z|x} \| p_{z|x})\] Here \(F(q_{z|x}, \theta)\) is the evidence lower bound and the remaining term is the KL divergence between the latent distribution \(q_{z|x}(z)\) and the posterior \(p(z|x,\theta)\).

We’ve formalized the lower bound as a function (functional) of two arguments, \(q_{z|x}\) and \(\theta\); EM essentially optimizes it via coordinate ascent.

In the E-step we optimize \(F(q_{z|x}, \theta)\) w.r.t. \(q_{z|x}\) while holding \(\theta\) fixed. Since \(\log p(x|\theta)\) does not depend on \(q_{z|x}\), \(F(q_{z|x}, \theta)\) is largest when \(D_{KL}(q_{z|x} \| p_{z|x})=0\), which again gives \(q_{z|x}(z) = p(z|x,\theta)\). In the M-step \(F(q_{z|x}, \theta)\) is maximized w.r.t. \(\theta\), exactly as in the section above.
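The decomposition is easy to check numerically. The sketch below uses a made-up joint table \(p(x, z|\theta)\) for one fixed \(x\) and three latent states; the numbers and function names are hypothetical and only serve to illustrate that \(\log p(x|\theta) = F + D_{KL}\) for an arbitrary \(q\), and that the bound becomes tight when \(q\) is the posterior.

```python
import math

# Hypothetical values of p(x, z | theta) for one fixed x and z in {0, 1, 2}.
joint = [0.10, 0.25, 0.05]
evidence = sum(joint)                      # p(x | theta)
posterior = [p / evidence for p in joint]  # p(z | x, theta)

def lower_bound(q):
    """F(q, theta) = sum_z q(z) log(p(x, z | theta) / q(z))."""
    return sum(qz * math.log(pz / qz) for qz, pz in zip(q, joint) if qz > 0)

def kl_divergence(q, p):
    """D_KL(q || p) = sum_z q(z) log(q(z) / p(z))."""
    return sum(qz * math.log(qz / pz) for qz, pz in zip(q, p) if qz > 0)

# An arbitrary latent distribution that differs from the posterior.
q_arbitrary = [0.6, 0.3, 0.1]
gap = kl_divergence(q_arbitrary, posterior)

# log p(x | theta) = F(q, theta) + D_KL(q || posterior) holds for any q.
decomposition_error = abs(
    math.log(evidence) - (lower_bound(q_arbitrary) + gap)
)

# The E-step choice q = posterior closes the gap, so F equals log p(x | theta).
tight_error = abs(math.log(evidence) - lower_bound(posterior))
```

Both `decomposition_error` and `tight_error` should be zero up to floating-point rounding, while `gap` is strictly positive for any \(q\) other than the posterior.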

KL Divergence

It turns out the lower bound \(F(q_{z|x}, \theta)\) above can also be written as a KL divergence. Let \(q(x,z) = q(z|x)p(x)\), where \(q(z|x) = q_{z|x}(z)\) and \(p(x)=\frac{1}{|X|}\sum_{i}\delta_i(x)\) is the empirical distribution, which places all its mass on the observed data \(X\), so that \[\sum_x p(x)f(x) = \frac{1}{|X|}\sum_{x \in X} f(x). \] Since \(q(z|x) = |X|\, q(x,z)\) on the observed data, the lower bound can be rewritten as \[\sum_x F(q_{z|x}, \theta) = \sum_x \sum_z q_{z|x}(z) \log \frac{p(x,z|\theta)}{q_{z|x}(z)}\] \[= |X| \sum_x \sum_z q(x,z) \log \frac{p(x,z|\theta)}{ |X|\, q(x,z) } \] \[= -|X| D_{KL}(q_{x,z} \| p_{x,z}) - |X| \log|X|. \] The last term is a constant, so maximizing the lower bound w.r.t. \(\theta\) in the M-step is the same as minimizing \(D_{KL}(q_{x,z} \| p_{x,z})\) w.r.t. \(p_{x,z}\).

Similarly for the \(D_{KL}(q_{z|x} \| p_{z|x})\) term, using \(p(z|x,\theta) = p(x,z|\theta)/p(x|\theta)\), it follows \[\sum_x D_{KL}(q_{z|x} \| p_{z|x}) = - \sum_x \sum_z q_{z|x}(z) \log \frac{p(z|x,\theta)}{ q_{z|x}(z) }\] \[= - |X| \sum_x \sum_z q(x,z) \log \frac{p(x, z|\theta)}{q(x,z)} - |X| \sum_x p(x) \log \frac{p(x)}{p(x|\theta)}\] \[= |X| D_{KL}(q_{x,z} \| p_{x,z}) - |X| D_{KL}(p_x \| p_{x|\theta}). \] The second term does not depend on \(q\), so the E-step is minimizing the same KL divergence \(D_{KL}(q_{x,z} \| p_{x,z})\), but w.r.t. \(q\).

Since \(q(x,z)\) is restricted to agree with the data, and \(p(x,z|\theta)\) must be a distribution under the specified model, they can be thought of as living on two manifolds in the space of all distributions, namely the data manifold and the model manifold. EM can therefore be viewed as minimizing the divergence \(D_{KL}(q_{x,z} \| p_{x,z})\) between the two manifolds via coordinate descent.

For more on the geometric view of EM, please refer to this paper; also see this question on SE.

Log-sum to Sum-log

Whichever view one takes, the practical advantage of EM comes from Jensen’s inequality, which moves the logarithm inside the summation \[\sum_x \log \sum_z q_{z|x}(z) \frac{p(x, z|\theta)}{q_{z|x}(z)} \geq \sum_x \sum_z q_{z|x}(z) \log \frac{p(x, z|\theta)}{q_{z|x}(z)}. \] If the joint distribution \(p(x, z|\theta)\) belongs to the exponential family, this turns a log-sum-exp operation into a weighted sum of the exponents (often sufficient statistics), which can be easier to optimize.
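As an illustration, assume a one-dimensional Gaussian mixture with mixing weights \(w_k\), means \(\mu_k\) and variances \(\sigma_k^2\) (a model and symbols introduced here, not taken from the text above). For a single data point \(x\), the data log-likelihood has the logarithm outside the sum \[\log p(x|\theta) = \log \sum_k w_k \, \mathcal{N}(x|\mu_k, \sigma_k^2),\] whereas the \(\theta\)-dependent part of the lower bound has it inside \[\sum_k q_{z|x}(k) \log p(x, z=k|\theta) = \sum_k q_{z|x}(k) \left[\log w_k - \frac{1}{2}\log(2\pi\sigma_k^2) - \frac{(x-\mu_k)^2}{2\sigma_k^2}\right].\] The second expression is a weighted sum of simple terms, so after summing over the data and setting the derivative w.r.t. \(\mu_k\) to zero we get the closed-form update \[\mu_k = \frac{\sum_x q_{z|x}(k)\, x}{\sum_x q_{z|x}(k)}.\]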

Alternatives for E-step and M-step

Sometimes we aren’t able to reach the optimal solution of the E-step or the M-step, for example because the computation or optimization is too difficult, because we trade accuracy for simplicity, or because of other restrictions on the distributions or parameters. In these cases, we can use alternative approaches that give a suboptimal solution.

For example, K-means is a special (limiting) case of EM for GMMs, where the latent distribution is restricted to be a delta function (hard assignment).
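The hard-assignment version can be sketched in a few lines. This is an illustrative one-dimensional implementation that assumes equal mixing weights and a shared variance, so that the most probable component is simply the nearest center; the data, starting centers and function names are made up for the example.

```python
import random

def hard_e_step(xs, centers):
    """Assign each point to its nearest center (a delta-function posterior)."""
    return [
        min(range(len(centers)), key=lambda k: (x - centers[k]) ** 2)
        for x in xs
    ]

def mean_m_step(xs, labels, centers):
    """Re-estimate each center as the mean of the points assigned to it."""
    new_centers = []
    for k, old in enumerate(centers):
        members = [x for x, z in zip(xs, labels) if z == k]
        # Keep the old center if no point was assigned to this cluster.
        new_centers.append(sum(members) / len(members) if members else old)
    return new_centers

def k_means(xs, centers, n_iter=20):
    """Alternate hard E-steps and M-steps for a fixed number of iterations."""
    for _ in range(n_iter):
        labels = hard_e_step(xs, centers)
        centers = mean_m_step(xs, labels, centers)
    return centers, hard_e_step(xs, centers)

# Synthetic data from two well-separated clusters (illustrative only).
rng = random.Random(1)
data = ([rng.gauss(-2.0, 0.5) for _ in range(100)]
        + [rng.gauss(3.0, 0.5) for _ in range(100)])
centers, labels = k_means(data, centers=[-1.0, 1.0])
```

Compared with soft EM, the only change is in the E-step: the full posterior over components is replaced by its single most probable value, and the M-step mean update then averages only the points assigned to each cluster.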

In LDA, a prior distribution is placed on the parameter, which makes the parameter another latent variable, and the posterior over the latent variables becomes difficult to compute. So variational methods are used for approximation: specifically, the latent distribution \(q\) is characterized by a variational model with parameter \(\phi\). Then in the E-step we optimize \(q\) w.r.t. \(\phi\) and in the M-step we optimize \(p\) w.r.t. \(\theta\). For parameters that cannot be solved for in closed form, gradient-based optimization is applied.