Diffusion Reconstruction Loss

#302 · Deep Learning · Medium

Problem

Implement the reconstruction loss (denoising score matching loss) for a diffusion model. Given the original data, the noisy data, the predicted noise, and the actual noise added, compute the mean squared error between the predicted noise and the actual noise. The loss itself depends only on those two noise arrays.

Solution

import numpy as np

def diffusion_reconstruction_loss(noise_true: np.ndarray,
                                   noise_pred: np.ndarray) -> float:
    """Simple MSE loss between true and predicted noise."""
    return float(np.mean((noise_true - noise_pred) ** 2))

def weighted_diffusion_loss(noise_true: np.ndarray,
                            noise_pred: np.ndarray,
                            timesteps: np.ndarray,
                            max_timesteps: int = 1000) -> float:
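    """Timestep-weighted MSE: weight 1 / (1 + t / max_timesteps) shrinks as t grows."""
    # One weight per sample, shape (n,): 1.0 at t = 0, 0.5 at t = max_timesteps.
    weights = 1.0 / (1.0 + timesteps / max_timesteps)
    # Average the squared error over every non-batch axis, leaving shape (n,).
    feature_axes = tuple(range(1, noise_true.ndim))
    per_sample_loss = np.mean((noise_true - noise_pred) ** 2, axis=feature_axes)
    return float(np.mean(weights * per_sample_loss))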
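
As a sanity check, the two losses can be worked through by hand on a tiny batch. The sketch below recomputes them inline, step by step, rather than calling the functions; the toy values (2 samples, 2 features, timesteps 0 and 1000) are made up for illustration.

```python
import numpy as np

# Hypothetical toy batch: 2 samples, 2 features each.
noise_true = np.array([[1.0, 2.0], [3.0, 4.0]])
noise_pred = np.array([[1.0, 1.0], [1.0, 4.0]])
timesteps = np.array([0, 1000])
max_timesteps = 1000

# Element-wise squared error: [[0, 1], [4, 0]]
sq_err = (noise_true - noise_pred) ** 2

# Simple loss: mean over all four elements, (0 + 1 + 4 + 0) / 4 = 1.25
simple_loss = float(np.mean(sq_err))

# Weighted loss: per-sample means [0.5, 2.0], weights [1.0, 0.5]
per_sample = sq_err.mean(axis=1)
weights = 1.0 / (1.0 + timesteps / max_timesteps)
weighted_loss = float(np.mean(weights * per_sample))  # (0.5 + 1.0) / 2 = 0.75

print(simple_loss, weighted_loss)  # 1.25 0.75
```

The second sample has the larger error but sits at the last timestep, so its contribution is halved before averaging.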

Explanation

  1. In diffusion models, the training objective is to predict the noise that was added to the original data at a given timestep.
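  2. Simple loss: L = E[||epsilon - epsilon_hat||^2] where epsilon is the true noise and epsilon_hat is the model's prediction. The code averages the squared error over every element, which differs from the squared norm only by the constant factor d (the number of elements per sample).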
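  3. Weighted loss: each sample's MSE is scaled by a weight that depends on its timestep. Here the weight is 1 / (1 + t / max_timesteps), which falls from 1 at t = 0 to 0.5 at t = max_timesteps, so early timesteps (small noise) count more. The motivation is that they may affect final image quality more.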
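  4. Under the noise-prediction parameterization, the per-timestep terms of the variational lower bound on the data log-likelihood reduce to weighted MSEs between epsilon and epsilon_hat. The simple loss drops those weights, so it is a reweighted form of the bound rather than the bound itself.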
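
To show where the true noise and the noisy data in the problem statement might come from, here is a minimal sketch of a DDPM-style forward process. Everything in it is an assumption, not part of the original solution: the linear beta schedule, the helper name `make_training_pair`, and the zero-valued stand-in for a model's prediction.

```python
import numpy as np

def make_training_pair(x0, t, alpha_bar, rng):
    """Return (x_t, noise) with x_t = sqrt(a_t) * x0 + sqrt(1 - a_t) * noise.

    Hypothetical helper; a_t is the cumulative product alpha_bar[t].
    """
    noise = rng.standard_normal(x0.shape)
    # Reshape the per-sample coefficient so it broadcasts over feature axes.
    a = alpha_bar[t].reshape(-1, *([1] * (x0.ndim - 1)))
    x_t = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise
    return x_t, noise

rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 0.02, 1000)  # assumed linear noise schedule
alpha_bar = np.cumprod(1.0 - betas)

x0 = rng.standard_normal((4, 8))       # toy "original data"
t = rng.integers(0, 1000, size=4)      # one timestep per sample
x_t, noise_true = make_training_pair(x0, t, alpha_bar, rng)

# Stand-in for a network's output; a real model would map (x_t, t) to noise.
noise_pred = np.zeros_like(noise_true)

# Same computation as the simple loss: MSE between true and predicted noise.
loss = float(np.mean((noise_true - noise_pred) ** 2))
print(loss)
```

Note that `x0` and `x_t` are only needed to produce the model's input and the regression target; once `noise_true` and `noise_pred` exist, the loss uses nothing else.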

Complexity

  • Time: O(n * d) where n is batch size and d is data dimensionality
  • Space: O(n * d) for intermediate difference array