Latent Diffusion Encoding and Decoding

#402 · Deep Learning · Medium

Problem

Implement the encoding and decoding steps for latent diffusion models (LDMs). Instead of diffusing in pixel space, LDMs use a pretrained VAE to encode images into a lower-dimensional latent space, perform diffusion there, then decode back to pixel space.

Solution

import numpy as np

class LatentDiffusion:
    def __init__(self, encoder_fn, decoder_fn, scale_factor: float = 0.18215):
        self.encoder_fn = encoder_fn
        self.decoder_fn = decoder_fn
        self.scale_factor = scale_factor

    def encode(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, channels, height, width)
        # Encoder outputs mean and log_var of latent distribution
        mu, log_var = self.encoder_fn(x)

        # Reparameterization trick
        std = np.exp(0.5 * log_var)
        eps = np.random.randn(*mu.shape)
        z = mu + std * eps

        # Scale latents
        z = z * self.scale_factor
        return z

    def decode(self, z: np.ndarray) -> np.ndarray:
        # Unscale latents
        z = z / self.scale_factor
        return self.decoder_fn(z)

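    def forward_diffusion(self, z: np.ndarray, t: int, alpha_bars: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        # Closed-form sample of q(z_t | z_0):
        #   z_t = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * noise
        # alpha_bars holds the cumulative products of (1 - beta), indexed by timestep t
        noise = np.random.randn(*z.shape)
        sqrt_ab = np.sqrt(alpha_bars[t])
        sqrt_1_ab = np.sqrt(1 - alpha_bars[t])
        zt = sqrt_ab * z + sqrt_1_ab * noise
        # Return the noise as well: it is the regression target when training a noise predictor
        return zt, noise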
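
As a rough usage sketch, the same three steps can be exercised end to end without a real VAE. Everything named toy_* below is a hypothetical stand-in (8x8 average pooling for the encoder, nearest-neighbour upsampling for the decoder), and the linear beta schedule is an assumption chosen for illustration, not something the problem statement specifies. The functions mirror the class methods above but take an explicit random generator so the run is reproducible.

```python
import numpy as np

SCALE_FACTOR = 0.18215  # default scale factor used in the solution above


def toy_encoder(x):
    """Hypothetical stand-in for a VAE encoder: 8x8 average pooling.

    Returns (mu, log_var), each of shape (batch, channels, H/8, W/8).
    """
    b, c, h, w = x.shape
    mu = x.reshape(b, c, h // 8, 8, w // 8, 8).mean(axis=(3, 5))
    log_var = np.full_like(mu, -20.0)  # tiny variance, so z stays close to mu
    return mu, log_var


def toy_decoder(z):
    """Hypothetical stand-in for a VAE decoder: nearest-neighbour 8x upsampling."""
    return z.repeat(8, axis=2).repeat(8, axis=3)


def encode(x, rng):
    mu, log_var = toy_encoder(x)
    # Reparameterization: z = mu + sigma * eps, then scale
    z = mu + np.exp(0.5 * log_var) * rng.standard_normal(mu.shape)
    return z * SCALE_FACTOR


def decode(z):
    # Undo the scaling before handing latents to the decoder
    return toy_decoder(z / SCALE_FACTOR)


def forward_diffusion(z, t, alpha_bars, rng):
    noise = rng.standard_normal(z.shape)
    zt = np.sqrt(alpha_bars[t]) * z + np.sqrt(1.0 - alpha_bars[t]) * noise
    return zt, noise


rng = np.random.default_rng(0)

# Assumed schedule: 1000 linearly spaced betas, alpha_bar = cumulative product of (1 - beta)
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bars = np.cumprod(1.0 - betas)

x = rng.standard_normal((2, 3, 64, 64))   # fake "images": (batch, channels, H, W)
z0 = encode(x, rng)                        # latents, 8x smaller per side
t = 500
zt, noise = forward_diffusion(z0, t, alpha_bars, rng)

# With the true noise in hand, the closed form can be inverted exactly.
# A trained model would supply a predicted noise here instead.
z0_hat = (zt - np.sqrt(1.0 - alpha_bars[t]) * noise) / np.sqrt(alpha_bars[t])
x_hat = decode(z0_hat)

print(z0.shape, zt.shape, x_hat.shape)
```

The toy encoder keeps the channel count unchanged, whereas a real LDM encoder typically changes it (3 image channels to 4 latent channels in the example quoted in this solution); only the spatial downsampling is imitated here.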

Explanation

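  1. Encode: the pretrained VAE encoder maps high-resolution images to a lower-dimensional latent space (e.g., a 512x512 RGB image to a 64x64x4 latent). The encoder returns the mean and log-variance of a Gaussian over latents, and the reparameterization trick draws a sample as z = mu + sigma * eps, with eps from a standard normal.
  2. Scale factor: latents are multiplied by a constant (0.18215 in Stable Diffusion) so that their variance is approximately one, comparable to the unit-variance Gaussian noise that the diffusion process mixes in. Decoding divides by the same constant before calling the decoder.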
  3. Diffusion: the forward and reverse diffusion processes operate entirely in latent space, which is computationally much cheaper.
  4. Decode: after reverse diffusion completes, the VAE decoder maps latents back to pixel space.
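
To make the scale-factor step concrete, one common way to arrive at such a constant is to take the reciprocal of the standard deviation of unscaled latents over a sample of encoded data. The sketch below assumes that approach; estimate_scale_factor is a hypothetical helper, and the latents are synthetic Gaussian draws whose spread (5.5) was picked only so the result lands near the constant quoted above.

```python
import numpy as np


def estimate_scale_factor(latents):
    """Hypothetical helper: reciprocal of the overall latent standard deviation."""
    return 1.0 / latents.std()


rng = np.random.default_rng(0)

# Synthetic stand-in for unscaled VAE latents, shape (batch, 4, 64, 64)
raw_latents = rng.normal(loc=0.0, scale=5.5, size=(16, 4, 64, 64))

scale = estimate_scale_factor(raw_latents)
scaled = raw_latents * scale

print(f"estimated scale factor: {scale:.4f}")
print(f"std after scaling: {scaled.std():.4f}")
```

After multiplying by the estimated factor, the standard deviation of the latents is one by construction, which is the property step 2 relies on.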

Complexity

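  • Time: O(d_latent) per forward-diffusion step, since the update is elementwise over the latent (far fewer elements than in pixel space), plus encoder/decoder costs
  • Space: O(d_latent) for latent representations (e.g., 512x512 to 64x64 is 8x smaller per side, so 64x fewer spatial positions)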
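
The size reduction can be checked with a few lines of arithmetic. The sketch assumes a 3-channel (RGB) 512x512 input and the 64x64x4 latent from the example above; other resolutions or channel counts give different ratios.

```python
import numpy as np

pixel_shape = (3, 512, 512)   # (channels, height, width), RGB assumed
latent_shape = (4, 64, 64)    # latent shape from the example above

pixel_elems = int(np.prod(pixel_shape))
latent_elems = int(np.prod(latent_shape))

# Downsampling per side, and the resulting drop in spatial positions
side_factor = pixel_shape[1] // latent_shape[1]
spatial_factor = (pixel_shape[1] * pixel_shape[2]) // (latent_shape[1] * latent_shape[2])

# Reduction in total element count, which is what the O(d_latent) terms scale with
element_factor = pixel_elems / latent_elems

print(pixel_elems, latent_elems)                     # 786432 16384
print(side_factor, spatial_factor, element_factor)   # 8 64 48.0
```

Because the latent has 4 channels instead of 3, the total element count shrinks by 48x rather than 64x, even though there are 64x fewer spatial positions.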