Implement a single reverse sampling step of DDPM. Given a noisy image at timestep t, the predicted noise, and the noise schedule parameters, compute the denoised sample at timestep t-1.
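The function below expects the noise schedule as three aligned arrays indexed by t. Here is a minimal sketch of how they could be built. It assumes the linear schedule from the DDPM paper (beta rising from 1e-4 to 0.02 over T = 1000 steps), and the helper name make_linear_schedule is hypothetical:

```python
import numpy as np


def make_linear_schedule(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Hypothetical helper: build betas, alphas, alpha_bars for a linear DDPM schedule."""
    betas = np.linspace(beta_start, beta_end, T)  # beta_t: per-step noise variance
    alphas = 1.0 - betas                          # alpha_t = 1 - beta_t
    alpha_bars = np.cumprod(alphas)               # alpha_bar_t = product of alpha_s for s <= t
    return betas, alphas, alpha_bars


betas, alphas, alpha_bars = make_linear_schedule()
print(betas.shape, alphas.shape, alpha_bars.shape)  # (1000,) (1000,) (1000,)
```

Because alpha_bars is a cumulative product of values below 1, it decreases monotonically, so larger t means a noisier x_t.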
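import numpy as np


def ddpm_reverse_step(
    xt: np.ndarray,
    predicted_noise: np.ndarray,
    t: int,
    betas: np.ndarray,
    alphas: np.ndarray,
    alpha_bars: np.ndarray,
) -> np.ndarray:
    """Sample x_{t-1} from p(x_{t-1} | x_t) given the noise predicted at step t.

    t is a 0-based index into the schedule arrays, which are expected to
    satisfy alphas = 1 - betas and alpha_bars = np.cumprod(alphas).
    """
    beta_t = betas[t]
    alpha_t = alphas[t]
    alpha_bar_t = alpha_bars[t]

    # Mean of p(x_{t-1} | x_t) in the epsilon parameterization:
    # mu = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
    sqrt_one_minus_alpha_bar_t = np.sqrt(1.0 - alpha_bar_t)
    coeff1 = 1.0 / np.sqrt(alpha_t)
    coeff2 = beta_t / sqrt_one_minus_alpha_bar_t
    mean = coeff1 * (xt - coeff2 * predicted_noise)

    # Last step: return the mean without adding noise.
    if t == 0:
        return mean

    # Fixed variance sigma_t^2 = beta_t, so the standard deviation is sqrt(beta_t).
    sigma_t = np.sqrt(beta_t)
    noise = np.random.randn(*xt.shape)
    return mean + sigma_t * noise

- The reverse process models p(x_{t-1} | x_t) as a Gaussian with a learned mean and a fixed variance.
- The mean is mu = (1/sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * epsilon_predicted), where epsilon_predicted is the network's noise estimate.
- The variance is fixed to sigma_t^2 = beta_t, so the added noise is scaled by sigma_t = sqrt(beta_t). This is one of the two fixed choices in the DDPM paper; the other is the posterior variance (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t.
- At t=0, return the mean directly without adding noise, since this is the final clean sample.
- For t>0, sample from the Gaussian by adding standard normal noise, scaled by sigma_t, to the mean.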
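The epsilon-form mean can be checked numerically against the route the DDPM paper derives it from: predict x_0 from x_t and the noise, then take the mean of the forward-process posterior q(x_{t-1} | x_t, x_0). This is a small sketch, assuming the linear schedule and a random array standing in for the network's prediction:

```python
import numpy as np

# Assumed linear DDPM schedule (beta from 1e-4 to 0.02, T = 1000).
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

rng = np.random.default_rng(0)
t = 500  # any t >= 1, so that alpha_bars[t - 1] exists
xt = rng.standard_normal((4, 4))
eps = rng.standard_normal((4, 4))  # stand-in for the network's predicted noise

# Form 1: epsilon-parameterized mean, as computed in ddpm_reverse_step.
mean_eps = (xt - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])

# Form 2: predict x_0, then take the posterior mean of q(x_{t-1} | x_t, x_0).
x0_hat = (xt - np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alpha_bars[t])
abar_prev = alpha_bars[t - 1]
coef_x0 = np.sqrt(abar_prev) * betas[t] / (1.0 - alpha_bars[t])
coef_xt = np.sqrt(alphas[t]) * (1.0 - abar_prev) / (1.0 - alpha_bars[t])
mean_post = coef_x0 * x0_hat + coef_xt * xt

# Posterior variance, the alternative to sigma_t^2 = beta_t.
beta_tilde = (1.0 - abar_prev) / (1.0 - alpha_bars[t]) * betas[t]

print(np.allclose(mean_eps, mean_post))  # True
print(beta_tilde < betas[t])             # True
```

The two means agree algebraically for every t >= 1, because the x_t coefficients sum to 1/sqrt(alpha_t). That is why the step can skip computing x_0 explicitly. The posterior variance is always below beta_t because alpha_bar_t < alpha_bar_{t-1}.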