
Implement DDPM Reverse Sampling Step

#396 · Deep Learning · Medium

Solve on deep-ml.com

Problem

Implement a single reverse sampling step of DDPM. Given a noisy sample x_t at timestep t, the noise predicted by the model, and the noise schedule parameters (betas, alphas = 1 - betas, and their cumulative product alpha_bars), compute the less noisy sample x_{t-1} at timestep t-1.
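The problem does not say how the schedule arrays are built. As an assumption, the sketch below uses the linear schedule from the original DDPM paper (beta rising from 1e-4 to 0.02 over T = 1000 steps) and derives alphas and alpha_bars from it. The helper name `make_linear_schedule` is hypothetical.

```python
import numpy as np

def make_linear_schedule(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Hypothetical helper: linear beta schedule plus the derived alphas and alpha_bars."""
    betas = np.linspace(beta_start, beta_end, T, dtype=np.float64)
    alphas = 1.0 - betas               # alpha_t = 1 - beta_t
    alpha_bars = np.cumprod(alphas)    # alpha_bar_t = product of alpha_s for s <= t
    return betas, alphas, alpha_bars

betas, alphas, alpha_bars = make_linear_schedule()
print(betas.shape, alpha_bars[0], alpha_bars[-1])
```

Because every alpha_t is below 1, alpha_bars decreases monotonically toward 0, so x_t at the last index is almost pure Gaussian noise.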

Solution

import numpy as np

def ddpm_reverse_step(
    xt: np.ndarray,
    predicted_noise: np.ndarray,
    t: int,
    betas: np.ndarray,
    alphas: np.ndarray,
    alpha_bars: np.ndarray
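) -> np.ndarray:
    """One DDPM reverse step: sample x_{t-1} ~ p_theta(x_{t-1} | x_t).

    Schedule arrays are indexed by t (0-based), with alphas = 1 - betas
    and alpha_bars = cumprod(alphas).
    """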
    beta_t = betas[t]
    alpha_t = alphas[t]
    alpha_bar_t = alpha_bars[t]

    # Standard deviation of the forward marginal q(x_t | x_0): sqrt(1 - alpha_bar_t)
    sqrt_one_minus_alpha_bar_t = np.sqrt(1.0 - alpha_bar_t)

    # Compute mean of p(x_{t-1} | x_t)
    coeff1 = 1.0 / np.sqrt(alpha_t)
    coeff2 = beta_t / sqrt_one_minus_alpha_bar_t
    mean = coeff1 * (xt - coeff2 * predicted_noise)

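    # Last step (t == 0): return the mean without injecting noise
    if t == 0:
        return mean

    # Fixed variance sigma_t^2 = beta_t, so the standard deviation is sqrt(beta_t)
    sigma_t = np.sqrt(beta_t)
    noise = np.random.randn(*xt.shape)  # z ~ N(0, I), drawn from the global NumPy RNG
    return mean + sigma_t * noise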
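To show how the single step is used, here is a minimal sketch of the full ancestral sampling loop: start from x_T ~ N(0, I) and apply the reverse step for t = T-1 down to 0. The step is restated compactly so the block runs on its own. `dummy_noise_predictor` is a hypothetical stand-in that returns zeros; a real sampler would call a trained network epsilon_theta(x_t, t) there.

```python
import numpy as np

def ddpm_reverse_step(xt, predicted_noise, t, betas, alphas, alpha_bars):
    # Same update as the solution: epsilon-parameterized mean, sigma_t^2 = beta_t
    mean = (xt - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * predicted_noise) / np.sqrt(alphas[t])
    if t == 0:
        return mean
    return mean + np.sqrt(betas[t]) * np.random.randn(*xt.shape)

def dummy_noise_predictor(x, t):
    # Hypothetical placeholder for a trained model epsilon_theta(x_t, t)
    return np.zeros_like(x)

def ddpm_sample(shape, betas, seed=0):
    np.random.seed(seed)  # the step draws from the global NumPy RNG
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    x = np.random.randn(*shape)  # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        eps = dummy_noise_predictor(x, t)
        x = ddpm_reverse_step(x, eps, t, betas, alphas, alpha_bars)
    return x

betas = np.linspace(1e-4, 0.02, 50)
sample = ddpm_sample((2, 3), betas)
print(sample.shape)
```

The loop makes T calls to the noise predictor, so a full sample costs O(T * d) for the update arithmetic plus T network evaluations.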

Explanation

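  1. The DDPM reverse process models p_theta(x_{t-1} | x_t) as a Gaussian whose mean is predicted through the network's noise estimate and whose variance is fixed by the schedule.
  2. The mean is computed from the current noisy sample and the predicted noise: mu = (1/sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * epsilon_predicted). This follows from substituting the estimate x_0 ≈ (x_t - sqrt(1 - alpha_bar_t) * epsilon_predicted) / sqrt(alpha_bar_t) into the mean of the forward-process posterior q(x_{t-1} | x_t, x_0).
  3. The variance is fixed to sigma_t^2 = beta_t, so the noise is scaled by sqrt(beta_t). This is one of two choices in the original DDPM paper; the other is the posterior variance beta_tilde_t = ((1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)) * beta_t, and the paper reports similar results for both.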
  4. At t=0, return the mean directly without adding noise since we want the final clean sample.
  5. For t>0, sample from the Gaussian by adding scaled random noise to the mean.
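The claim in step 2 can be checked numerically. The sketch below, assuming a linear schedule and random arrays in place of a real x_t and model output, computes the mean in the epsilon form used by the solution and again via an explicit x_0 estimate plugged into the posterior q(x_{t-1} | x_t, x_0). It also evaluates the alternative variance beta_tilde_t from step 3 (defined for t >= 1).

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

rng = np.random.default_rng(0)
t = 500
xt = rng.standard_normal((4, 4))
eps = rng.standard_normal((4, 4))  # stand-in for the model's predicted noise

# Epsilon-parameterized mean, as in the solution
mean_eps = (xt - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])

# Same mean via an explicit x_0 estimate and the posterior q(x_{t-1} | x_t, x_0)
x0_hat = (xt - np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alpha_bars[t])
coef_x0 = np.sqrt(alpha_bars[t - 1]) * betas[t] / (1.0 - alpha_bars[t])
coef_xt = np.sqrt(alphas[t]) * (1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t])
mean_post = coef_x0 * x0_hat + coef_xt * xt

# Alternative fixed variance: the posterior variance beta_tilde_t
beta_tilde_t = (1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t]) * betas[t]

print(np.allclose(mean_eps, mean_post), beta_tilde_t < betas[t])
```

The two means agree because the x_t coefficients sum to (beta_t + alpha_t * (1 - alpha_bar_{t-1})) / ((1 - alpha_bar_t) * sqrt(alpha_t)) = 1 / sqrt(alpha_t). Since beta_tilde_t is always smaller than beta_t, switching variances only shrinks the injected noise.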

Complexity

  • Time: O(d) per step, where d is the number of elements in x_t (the cost of the noise-prediction network is not included)
  • Space: O(d) for the noise and output