Variational Inference: ELBO Computation

#206 · Probability · Hard

Problem

Implement the Evidence Lower Bound (ELBO) computation for variational inference. Given samples from a variational posterior q(z|x), a likelihood model p(x|z), and a prior p(z), compute the ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z)).
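
The standard identity below, added here for context and not part of the original problem statement, shows why this quantity is a lower bound. Expanding the log evidence under q gives:

```latex
\begin{aligned}
\log p(x) &= \mathbb{E}_{q(z\mid x)}\!\left[\log \frac{p(x, z)}{q(z\mid x)}\right]
           + \mathrm{KL}\big(q(z\mid x)\,\|\,p(z\mid x)\big) \\
\mathbb{E}_{q(z\mid x)}\!\left[\log \frac{p(x, z)}{q(z\mid x)}\right]
          &= \mathbb{E}_{q(z\mid x)}\big[\log p(x\mid z)\big]
           - \mathrm{KL}\big(q(z\mid x)\,\|\,p(z)\big) = \mathrm{ELBO}
\end{aligned}
```

The KL term in the first line is non-negative, so ELBO <= log p(x), with equality exactly when q(z|x) equals the true posterior p(z|x).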

Solution

import numpy as np

def gaussian_log_prob(x, mean, log_var):
    # Elementwise log N(x; mean, exp(log_var)) for a diagonal Gaussian
    var = np.exp(log_var)
    return -0.5 * (np.log(2 * np.pi) + log_var + (x - mean) ** 2 / var)

def kl_divergence_gaussians(mu_q, logvar_q, mu_p=0.0, logvar_p=0.0):
    # KL(q || p) for diagonal Gaussians
    var_q = np.exp(logvar_q)
    var_p = np.exp(logvar_p)
    kl = 0.5 * (logvar_p - logvar_q + var_q / var_p +
                 (mu_q - mu_p) ** 2 / var_p - 1.0)
    return np.sum(kl)

def compute_elbo(x, z_samples, mu_q, logvar_q,
                 decoder_fn, mu_p=0.0, logvar_p=0.0):
    x = np.asarray(x, dtype=float)

    # Reconstruction term: Monte Carlo estimate of E_q[log p(x|z)]
    n_samples = len(z_samples)
    recon = 0.0
    for z in z_samples:
        x_recon = decoder_fn(z)
        # Gaussian likelihood with unit variance (log_var = 0); the full log density
        # keeps the -0.5 * log(2 * pi) normalizer so the ELBO really bounds log p(x)
        recon += np.sum(gaussian_log_prob(x, x_recon, 0.0))
    recon /= n_samples

    # KL term: closed form for diagonal Gaussians
    kl = kl_divergence_gaussians(mu_q, logvar_q, mu_p, logvar_p)

    elbo = recon - kl
    return float(elbo), float(recon), float(kl)

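def reparameterize(mu, logvar, n_samples=1):
    # z = mu + std * eps with eps ~ N(0, I); asarray lets scalar inputs work too
    mu = np.asarray(mu, dtype=float)
    std = np.exp(0.5 * np.asarray(logvar, dtype=float))
    samples = []
    for _ in range(n_samples):
        eps = np.random.randn(*mu.shape)
        samples.append(mu + std * eps)
    return samples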
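
A vectorized variant of the same computation can be sketched as follows, assuming a decoder that accepts a whole (S, D) batch of latents at once. The names `elbo_vectorized` and the linear decoder weights `W` are hypothetical and used only for illustration; they are not part of the original solution. The sketch draws all S reparameterized samples in one call and includes the Gaussian normalizing constant in log p(x|z).

```python
import numpy as np


def elbo_vectorized(x, mu_q, logvar_q, decoder_fn, n_samples=64,
                    mu_p=0.0, logvar_p=0.0, rng=None):
    """Monte Carlo ELBO for a diagonal-Gaussian q and a unit-variance Gaussian likelihood.

    Assumption: decoder_fn maps an (S, D) batch of latents to an (S, D_x) batch of means.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x, dtype=float)
    mu_q = np.asarray(mu_q, dtype=float)
    logvar_q = np.asarray(logvar_q, dtype=float)

    # Reparameterization for all S samples at once: z has shape (S, D)
    eps = rng.standard_normal((n_samples,) + mu_q.shape)
    z = mu_q + np.exp(0.5 * logvar_q) * eps

    # log N(x; x_recon, I) per sample, including the -0.5 * log(2 * pi) terms
    x_recon = decoder_fn(z)                                   # (S, D_x)
    log_lik = -0.5 * np.sum(np.log(2 * np.pi) + (x - x_recon) ** 2, axis=-1)
    recon = log_lik.mean()

    # Closed-form KL(q || p) for diagonal Gaussians, summed over latent dims
    var_q, var_p = np.exp(logvar_q), np.exp(logvar_p)
    kl = 0.5 * np.sum(logvar_p - logvar_q + var_q / var_p
                      + (mu_q - mu_p) ** 2 / var_p - 1.0)
    return float(recon - kl), float(recon), float(kl)


# Hypothetical linear decoder x_recon = z @ W.T, for illustration only
rng = np.random.default_rng(0)
W = rng.standard_normal((3, 2))          # D_x = 3, D = 2
x = np.array([0.5, -1.0, 2.0])
mu_q = np.array([0.1, -0.3])
logvar_q = np.array([-1.0, -0.5])

elbo, recon, kl = elbo_vectorized(x, mu_q, logvar_q, lambda z: z @ W.T, rng=rng)
print(f"ELBO={elbo:.3f}  recon={recon:.3f}  KL={kl:.3f}")
```

Batching the samples replaces the Python loop with array operations; the trade-off is that decoder_fn must broadcast over the leading sample axis.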

Explanation

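  1. Reparameterization trick: Sample z = mu + std * epsilon, where std = exp(0.5 * logvar) and epsilon ~ N(0, I). Because all the randomness lives in epsilon, z is a deterministic, differentiable function of mu and logvar, so gradients of the ELBO can be propagated through the samples.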
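  2. Reconstruction term: Monte Carlo estimate of E_q[log p(x|z)], averaging log p(x|z) over the sampled z. With a unit-variance Gaussian likelihood, log p(x|z) = -0.5 * ||x - x_recon||^2 - (D_x / 2) * log(2 * pi), where D_x is the data dimension: half the negative squared error plus a constant. The constant does not change gradients, but it is needed for the ELBO to be a true lower bound on log p(x).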
  3. KL divergence: Analytically computed between two diagonal Gaussians: KL(q||p) = 0.5 * sum(logvar_p - logvar_q + var_q/var_p + (mu_q - mu_p)^2/var_p - 1).
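  4. ELBO = Reconstruction - KL. Maximizing it rewards a high expected log-likelihood of the data while penalizing q for drifting away from the prior. Because the ELBO never exceeds log p(x), raising it either increases the evidence or tightens the gap between the two.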
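
One way to sanity-check these steps is a one-dimensional conjugate model where every quantity has a closed form. The sketch below assumes prior z ~ N(0, 1) and likelihood x|z ~ N(z, 1), so the evidence is p(x) = N(x; 0, 2) and the exact posterior is N(x/2, 1/2); the helper names are hypothetical. With q equal to the exact posterior the ELBO should match log p(x), a different q should give a smaller value, and a reparameterized Monte Carlo estimate should land close to the analytic result.

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)


def log_normal(x, mean, var):
    # log N(x; mean, var), elementwise for scalars or arrays
    return -0.5 * (LOG_2PI + np.log(var) + (x - mean) ** 2 / var)


def kl_to_std_normal(m, s2):
    # KL(N(m, s2) || N(0, 1)): the step-3 formula with mu_p = 0, var_p = 1
    return 0.5 * (-np.log(s2) + s2 + m ** 2 - 1.0)


def elbo_analytic(x, m, s2):
    # For q = N(m, s2): E_q[log N(x; z, 1)] = -0.5 * (log 2pi + (x - m)^2 + s2)
    recon = -0.5 * (LOG_2PI + (x - m) ** 2 + s2)
    return recon - kl_to_std_normal(m, s2)


def elbo_monte_carlo(x, m, s2, n_samples, rng):
    # Reparameterized samples z = m + sqrt(s2) * eps, eps ~ N(0, 1)
    z = m + np.sqrt(s2) * rng.standard_normal(n_samples)
    recon = np.mean(log_normal(x, z, 1.0))
    return recon - kl_to_std_normal(m, s2)


x = 1.3
log_evidence = log_normal(x, 0.0, 2.0)      # log p(x), with p(x) = N(0, 2)
exact = elbo_analytic(x, x / 2, 0.5)        # q = true posterior N(x/2, 1/2)
worse = elbo_analytic(x, 0.0, 1.0)          # q = prior, a poorer approximation
mc = elbo_monte_carlo(x, x / 2, 0.5, 200_000, np.random.default_rng(0))

print(log_evidence, exact, worse, mc)
```

The closed-form reconstruction term is specific to this toy model; in general E_q[log p(x|z)] has no closed form, which is why the solution estimates it by sampling.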

Complexity

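  • Time: O(S * (D + D_x)) plus the cost of S decoder calls, where S is the number of samples, D is the latent dimension, and D_x is the data dimension; the closed-form KL adds O(D)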
  • Space: O(S * D) for storing samples