Noise Prediction Loss for Diffusion Training

#400 · Deep Learning · Medium

Problem

Implement the noise prediction loss (simple loss) used for training diffusion models. Given a batch of clean data, sample random timesteps and noise, create noisy versions, and compute the MSE between the predicted and actual noise.

Solution

import numpy as np

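def diffusion_training_loss(
    model_fn,
    x0: np.ndarray,
    alpha_bars: np.ndarray,
    num_timesteps: int
) -> float:
    """Return the simple noise-prediction MSE loss for one batch.

    model_fn(xt, t) must return predicted noise with the same shape as xt.
    x0 has shape (batch, ...); alpha_bars holds the cumulative products
    alpha_bar_t and is indexed by timesteps in [0, num_timesteps).
    """
    batch_size = x0.shape[0]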

    # Sample random timesteps
    t = np.random.randint(0, num_timesteps, size=batch_size)

    # Sample noise
    noise = np.random.randn(*x0.shape)

    # Create noisy samples
    sqrt_alpha_bar = np.sqrt(alpha_bars[t])
    sqrt_one_minus_alpha_bar = np.sqrt(1.0 - alpha_bars[t])

    # Reshape for broadcasting
    while sqrt_alpha_bar.ndim < x0.ndim:
        sqrt_alpha_bar = sqrt_alpha_bar[..., np.newaxis]
        sqrt_one_minus_alpha_bar = sqrt_one_minus_alpha_bar[..., np.newaxis]

    xt = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise

    # Predict noise
    predicted_noise = model_fn(xt, t)

    # Simple MSE loss
    loss = np.mean((predicted_noise - noise) ** 2)
    return float(loss)
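
A quick way to sanity-check a loss like this is to feed it two extreme models. The sketch below is illustrative only: `zero_model` and `oracle_model` are hypothetical helpers, the linear beta schedule and tensor shape are assumptions, and the function is restated in condensed form just so the block runs on its own. A model that always predicts zero should score close to 1 (the variance of standard Gaussian noise), and a model that inverts the forward process using the known `x0` should score essentially 0.

```python
import numpy as np

def diffusion_training_loss(model_fn, x0, alpha_bars, num_timesteps):
    # Condensed restatement of the solution so this block is self-contained.
    t = np.random.randint(0, num_timesteps, size=x0.shape[0])
    noise = np.random.randn(*x0.shape)
    shape = (-1,) + (1,) * (x0.ndim - 1)  # (batch, 1, 1, ...) for broadcasting
    a = np.sqrt(alpha_bars[t]).reshape(shape)
    s = np.sqrt(1.0 - alpha_bars[t]).reshape(shape)
    xt = a * x0 + s * noise
    return float(np.mean((model_fn(xt, t) - noise) ** 2))

# Assumed linear beta schedule; the values are illustrative.
num_timesteps = 1000
betas = np.linspace(1e-4, 0.02, num_timesteps)
alpha_bars = np.cumprod(1.0 - betas)

np.random.seed(0)
x0 = np.random.randn(64, 3, 8, 8)  # hypothetical batch of 64 small "images"

def zero_model(xt, t):
    # Hypothetical baseline: always predicts zero noise.
    return np.zeros_like(xt)

def oracle_model(xt, t):
    # Hypothetical oracle: recovers the noise exactly because it knows x0.
    shape = (-1,) + (1,) * (xt.ndim - 1)
    a = np.sqrt(alpha_bars[t]).reshape(shape)
    s = np.sqrt(1.0 - alpha_bars[t]).reshape(shape)
    return (xt - a * x0) / s

zero_loss = diffusion_training_loss(zero_model, x0, alpha_bars, num_timesteps)
oracle_loss = diffusion_training_loss(oracle_model, x0, alpha_bars, num_timesteps)
print(zero_loss, oracle_loss)
```

A real network lands somewhere between these two values, so a training loss that stays near 1 suggests the model is not yet using `xt` or `t` at all.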

Explanation

  1. Sample a random timestep t for each item in the batch, uniformly from the integers in [0, T), where T is num_timesteps.
  2. Sample Gaussian noise epsilon of the same shape as the data.
  3. Create the noisy version using the forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon.
  4. The model predicts the noise from x_t and t.
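  5. The loss is the MSE between predicted and actual noise. This is the "simple" objective from the DDPM paper (Ho et al., 2020): it drops the timestep-dependent weighting of the full variational bound, and the authors reported better sample quality when training with it.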
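
The forward process in step 3 depends entirely on `alpha_bars`, which the problem takes as given. As a minimal sketch, assuming a linear beta schedule with 1000 steps (the names `betas`, `signal`, and `noise_scale` are illustrative, not part of the problem), the coefficients can be built and inspected like this:

```python
import numpy as np

T = 1000                               # assumed number of timesteps
betas = np.linspace(1e-4, 0.02, T)     # assumed linear noise schedule
alpha_bars = np.cumprod(1.0 - betas)   # alpha_bar_t = product of (1 - beta_s) for s <= t

signal = np.sqrt(alpha_bars)           # weight on x_0 in x_t
noise_scale = np.sqrt(1.0 - alpha_bars)  # weight on epsilon in x_t

# Empirical check at one timestep: with unit-variance data, x_t keeps unit variance.
rng = np.random.default_rng(0)
x0 = rng.standard_normal(100_000)
eps = rng.standard_normal(100_000)
t = 500
xt = signal[t] * x0 + noise_scale[t] * eps
xt_var = xt.var()

print(signal[0], signal[-1])   # signal weight starts near 1 and decays toward 0
print(xt_var)                  # close to 1 for unit-variance x0
```

Because the squared coefficients sum to 1 at every timestep, early timesteps are almost pure data and late timesteps are almost pure noise, which is why sampling t uniformly exposes the model to every noise level.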

Complexity

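  • Time: O(B * d) for sampling and mixing, where B is the batch size and d is the number of elements per sample, plus the cost of the model forward pass
  • Space: O(B * d) for the noise, the noisy samples, and the predicted noise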