Implement a single DDIM (Denoising Diffusion Implicit Models) sampling step. Unlike DDPM, which injects fresh random noise at every step, DDIM with eta = 0 provides a deterministic mapping from noise to data, enabling faster sampling with fewer steps.
import numpy as np
def ddim_step(
    xt: np.ndarray,
    predicted_noise: np.ndarray,
    t: int,
    t_prev: int,
    alpha_bars: np.ndarray,
    eta: float = 0.0
) -> np.ndarray:
    """Perform one DDIM reverse step from timestep ``t`` to ``t_prev``.

    The step has four stages:

    1. Predict x_0 from the current noisy sample and the predicted noise,
       using the relationship
       ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps``.
    2. Compute the step standard deviation ``sigma``, scaled by ``eta``.
       When ``eta == 0`` the process is fully deterministic (DDIM); when
       ``eta == 1`` it matches DDPM.
    3. Compute the direction pointing to x_t.
    4. Combine:
       ``x_{t-1} = sqrt(alpha_bar_{t-1}) * x_0_pred + direction + sigma * noise``.

    Args:
        xt: Current noisy sample x_t.
        predicted_noise: Noise estimate for ``xt``; used elementwise with
            ``xt``, so the two must broadcast against each other.
        t: Index of the current timestep into ``alpha_bars``.
        t_prev: Index of the target timestep into ``alpha_bars``. A negative
            value means "step to the clean sample" and uses a cumulative
            alpha of 1.0.
        alpha_bars: Cumulative alpha products, indexed by timestep.
        eta: Stochasticity scale. Random noise is only added when
            ``eta > 0``; it is drawn from NumPy's global random state.

    Returns:
        The sample at timestep ``t_prev``.

    Notes:
        The x_0 estimate is always clipped to [-1, 1] (presumably the data
        is normalised to that range -- confirm against the caller).
        ``alpha_bars[t]`` must be non-zero, because the x_0 estimate divides
        by its square root.
    """
    alpha_bar_t = alpha_bars[t]
    alpha_bar_prev = alpha_bars[t_prev] if t_prev >= 0 else 1.0

    # Stage 1: invert the forward-process relationship to estimate x_0.
    x0_pred = (xt - np.sqrt(1 - alpha_bar_t) * predicted_noise) / np.sqrt(alpha_bar_t)
    x0_pred = np.clip(x0_pred, -1.0, 1.0)

    # Stage 2: step standard deviation. For eta == 0 it is exactly zero, so
    # the formula is skipped: evaluating it when alpha_bar_t == 1 (or the
    # schedule is not monotone) yields nan/inf, and 0 * nan would otherwise
    # poison the deterministic path.
    if eta == 0:
        sigma_t = 0.0
    else:
        sigma_t = (
            eta
            * np.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t))
            * np.sqrt(1 - alpha_bar_t / alpha_bar_prev)
        )

    # Stage 3: direction pointing to x_t. The radicand is mathematically
    # non-negative for 0 <= eta <= 1 but may round to a tiny negative value,
    # so it is clamped at zero before taking the square root.
    dir_scale = np.sqrt(np.maximum(1 - alpha_bar_prev - sigma_t ** 2, 0.0))
    dir_xt = dir_scale * predicted_noise

    # Stage 4: combine the deterministic parts, then add noise if requested.
    x_prev = np.sqrt(alpha_bar_prev) * x0_pred + dir_xt
    if eta > 0:
        noise = np.random.randn(*xt.shape)
        x_prev = x_prev + sigma_t * noise
    return x_prev