#400 · Deep Learning · Medium
# Solve on deep-ml.com. Implement the noise-prediction loss (the "simple" loss) used to train diffusion models: given a batch of clean data, sample random timesteps and noise, create noisy versions of the data, and compute the MSE between the predicted and the actual noise.
import numpy as np
def diffusion_training_loss(
    model_fn,
    x0: np.ndarray,
    alpha_bars: np.ndarray,
    num_timesteps: int,
    *,
    rng: "np.random.Generator | None" = None,
) -> float:
    """Compute the "simple" noise-prediction loss for one training batch.

    For each item in the batch:

    1. Sample a timestep ``t`` uniformly from ``[0, num_timesteps)``.
    2. Sample standard Gaussian noise ``epsilon`` shaped like the item.
    3. Build the noisy sample
       ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon``.
    4. Predict the noise with ``model_fn(x_t, t)``.

    The loss is the mean squared error between the predicted and the true
    noise, averaged over every element of the batch.

    Args:
        model_fn: Callable ``model_fn(x_t, t)``. ``x_t`` has the same shape
            as ``x0``; ``t`` is an integer array of shape ``(batch_size,)``.
            Must return a noise prediction with the same shape as ``x_t``.
        x0: Clean data. Axis 0 is the batch axis; any trailing shape is
            allowed (vectors, images, ...).
        alpha_bars: 1-D array of per-timestep ``alpha_bar`` values, indexed
            by timestep.
        num_timesteps: Timesteps are drawn from ``[0, num_timesteps)``.
            Must satisfy ``1 <= num_timesteps <= len(alpha_bars)``.
        rng: Optional ``numpy.random.Generator`` for reproducible sampling.
            When omitted, the legacy global ``np.random`` state is used, so
            ``np.random.seed`` keeps controlling the draws.

    Returns:
        The scalar MSE loss as a Python ``float``.

    Raises:
        ValueError: If ``x0`` has no non-empty batch axis, ``alpha_bars`` is
            not 1-D, ``num_timesteps`` is out of range, or the model's
            prediction shape differs from ``x0``'s shape.
    """
    x0 = np.asarray(x0)
    alpha_bars = np.asarray(alpha_bars)

    # Validate before drawing any randomness so a bad call never consumes RNG
    # state. It also turns an intermittent IndexError (only when a large t
    # happens to be sampled) into a deterministic error.
    if x0.ndim == 0 or x0.shape[0] == 0:
        raise ValueError("x0 must have a non-empty leading batch axis")
    if alpha_bars.ndim != 1:
        raise ValueError(
            f"alpha_bars must be 1-D, got shape {alpha_bars.shape}"
        )
    if not 1 <= num_timesteps <= alpha_bars.shape[0]:
        raise ValueError(
            f"num_timesteps must be in [1, {alpha_bars.shape[0]}], "
            f"got {num_timesteps}"
        )

    batch_size = x0.shape[0]

    # Draw timesteps first, then noise. This is the same order as the
    # legacy API, so seeded global-RNG results are reproducible.
    if rng is None:
        t = np.random.randint(0, num_timesteps, size=batch_size)
        noise = np.random.randn(*x0.shape)
    else:
        t = rng.integers(0, num_timesteps, size=batch_size)
        noise = rng.standard_normal(x0.shape)

    # Reshape the per-item coefficients to (batch, 1, ..., 1) so they
    # broadcast across every non-batch axis of x0.
    coef_shape = (batch_size,) + (1,) * (x0.ndim - 1)
    alpha_bar_t = alpha_bars[t].reshape(coef_shape)
    xt = np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * noise

    predicted_noise = np.asarray(model_fn(xt, t))
    # A mismatched prediction would silently broadcast into a meaningless loss.
    if predicted_noise.shape != noise.shape:
        raise ValueError(
            f"model_fn returned shape {predicted_noise.shape}, "
            f"expected {noise.shape}"
        )

    return float(np.mean((predicted_noise - noise) ** 2))