"""Encoding, decoding and forward-noising steps for latent diffusion models (LDMs).

Instead of diffusing in pixel space, LDMs use a pretrained VAE to encode
images into a lower-dimensional latent space, perform diffusion there, and
then decode the result back to pixel space.
"""
import numpy as np
class LatentDiffusion:
    """Latent-space wrapper around a VAE-style encoder/decoder pair.

    ``encode`` samples a latent from the Gaussian described by the encoder's
    ``(mu, log_var)`` output and multiplies it by ``scale_factor``;
    ``decode`` divides by the same factor before calling the decoder;
    ``forward_diffusion`` mixes a latent with Gaussian noise according to a
    schedule value looked up at timestep ``t``.
    """

    def __init__(
        self,
        encoder_fn,
        decoder_fn,
        scale_factor: float = 0.18215,
        *,
        rng: "np.random.Generator | None" = None,
    ):
        """Store the callables and the latent scaling factor.

        Args:
            encoder_fn: Callable taking the input array and returning a
                ``(mu, log_var)`` pair.
            decoder_fn: Callable taking an (unscaled) latent array and
                returning the decoded output.
            scale_factor: Multiplier applied to sampled latents in
                ``encode`` and divided out again in ``decode``.
            rng: Optional ``numpy.random.Generator`` used for all noise
                draws. When ``None`` (the default) the global
                ``np.random`` state is used, as before.
        """
        self.encoder_fn = encoder_fn
        self.decoder_fn = decoder_fn
        self.scale_factor = scale_factor
        self.rng = rng

    def _standard_normal(self, shape) -> np.ndarray:
        """Draw standard-normal noise of ``shape`` from the configured source."""
        if self.rng is None:
            # Legacy path: keeps results identical for callers that rely on
            # ``np.random.seed``.
            return np.random.randn(*shape)
        return self.rng.standard_normal(shape)

    def encode(self, x: np.ndarray) -> np.ndarray:
        """Sample a scaled latent for ``x``.

        Args:
            x: Input array, documented by the original code as
                ``(batch, channels, height, width)``; it is passed to
                ``encoder_fn`` unchanged.

        Returns:
            ``(mu + exp(0.5 * log_var) * eps) * scale_factor`` where ``eps``
            is standard-normal noise with the shape of ``mu``.
        """
        mu, log_var = self.encoder_fn(x)
        # Reparameterization trick: sample as mu + std * eps so the draw is a
        # deterministic function of (mu, log_var) and the noise.
        std = np.exp(0.5 * log_var)
        eps = self._standard_normal(mu.shape)
        return (mu + std * eps) * self.scale_factor

    def decode(self, z: np.ndarray) -> np.ndarray:
        """Undo the latent scaling and run the decoder.

        Args:
            z: Scaled latent, as produced by ``encode``.

        Returns:
            Whatever ``decoder_fn`` returns for ``z / scale_factor``.
        """
        return self.decoder_fn(z / self.scale_factor)

    def forward_diffusion(
        self,
        z: np.ndarray,
        t: "int | np.ndarray",
        alpha_bars: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Noise ``z`` to timestep ``t``.

        Computes ``sqrt(a) * z + sqrt(1 - a) * noise`` with
        ``a = alpha_bars[t]``.

        Args:
            z: Latent array to be noised.
            t: Either a single integer timestep applied to all of ``z``, or
                an integer array of per-sample timesteps whose entries line
                up with the leading axis (axes) of ``z``.
            alpha_bars: 1-D schedule indexed by timestep. Values must lie in
                ``[0, 1]`` for both square roots to be real (presumably the
                cumulative alpha schedule; confirm against the caller).

        Returns:
            ``(zt, noise)``: the noised latent and the noise that was added,
            both with the shape of ``z``.
        """
        noise = self._standard_normal(z.shape)
        alpha_bar_t = np.asarray(alpha_bars)[t]
        if np.ndim(alpha_bar_t) > 0:
            # Per-sample timesteps: append singleton axes so the coefficients
            # broadcast along the leading (batch) axis of z. Without this
            # NumPy would align them with the trailing axis instead.
            padding = (1,) * (z.ndim - alpha_bar_t.ndim)
            alpha_bar_t = alpha_bar_t.reshape(alpha_bar_t.shape + padding)
        zt = np.sqrt(alpha_bar_t) * z + np.sqrt(1 - alpha_bar_t) * noise
        return zt, noise