Implement the reconstruction loss (the denoising score-matching loss) for a diffusion model. Given the original data, the noisy data, the predicted noise, and the actual noise that was added, compute the mean squared error between the predicted and the actual noise.
import numpy as np
def diffusion_reconstruction_loss(noise_true: np.ndarray,
                                  noise_pred: np.ndarray) -> float:
    """Return the mean squared error between true and predicted noise.

    Args:
        noise_true: Noise that was actually added to the data.
        noise_pred: The model's noise prediction; must have exactly the same
            shape as ``noise_true``.

    Returns:
        The MSE averaged over every element, as a Python float.

    Raises:
        ValueError: If the two arrays do not have the same shape.
    """
    noise_true = np.asarray(noise_true)
    noise_pred = np.asarray(noise_pred)
    # Reject mismatched shapes explicitly: NumPy would otherwise broadcast,
    # e.g. (N,) against (N, 1) into an (N, N) error matrix, producing a
    # plausible-looking but wrong loss.
    if noise_true.shape != noise_pred.shape:
        raise ValueError(
            "noise_true and noise_pred must have the same shape, "
            f"got {noise_true.shape} and {noise_pred.shape}"
        )
    return float(np.mean((noise_true - noise_pred) ** 2))
def weighted_diffusion_loss(noise_true: np.ndarray,
                            noise_pred: np.ndarray,
                            timesteps: np.ndarray,
                            max_timesteps: int = 1000) -> float:
    """Return the timestep-weighted MSE between true and predicted noise.

    Each sample's squared error is averaged over all non-leading axes, then
    multiplied by ``w(t) = 1 / (1 + t / max_timesteps)``. That weight is 1 at
    ``t = 0`` and falls to 0.5 at ``t = max_timesteps``, so smaller timesteps
    get larger weight. The weighted per-sample losses are averaged over the
    batch.

    Args:
        noise_true: Noise that was actually added; the leading axis indexes
            samples.
        noise_pred: The model's noise prediction, same shape as
            ``noise_true``.
        timesteps: One timestep per sample (any shape holding ``batch``
            values, e.g. ``(batch,)`` or ``(batch, 1)``), or a single value
            shared by every sample.
        max_timesteps: Normalising constant for the timesteps; must be
            positive.

    Returns:
        The weighted mean loss as a Python float.

    Raises:
        ValueError: If the noise arrays differ in shape, if the number of
            timesteps matches neither 1 nor the number of samples, or if
            ``max_timesteps`` is not positive.
    """
    noise_true = np.asarray(noise_true)
    noise_pred = np.asarray(noise_pred)
    if noise_true.shape != noise_pred.shape:
        raise ValueError(
            "noise_true and noise_pred must have the same shape, "
            f"got {noise_true.shape} and {noise_pred.shape}"
        )
    if max_timesteps <= 0:
        raise ValueError(f"max_timesteps must be positive, got {max_timesteps}")

    squared_error = (noise_true - noise_pred) ** 2
    # Reduce every axis except the leading one -> one loss value per sample.
    per_sample_loss = np.mean(squared_error,
                              axis=tuple(range(1, squared_error.ndim)))

    # Flatten so (batch,), (batch, 1), ... all align element-wise with
    # per_sample_loss instead of broadcasting into a (batch, batch) matrix.
    t = np.asarray(timesteps, dtype=float).reshape(-1)
    if t.size not in (1, per_sample_loss.size):
        raise ValueError(
            f"expected 1 or {per_sample_loss.size} timesteps, got {t.size}"
        )

    weights = 1.0 / (1.0 + t / max_timesteps)
    return float(np.mean(weights * per_sample_loss))