Exponential Moving Average (EMA) for Diffusion Model Weights

#401 · Deep Learning · Easy

Problem

Implement Exponential Moving Average (EMA) for diffusion model weights. EMA maintains a shadow copy of model parameters that is a smoothed version of the training parameters, which typically produces better samples.

Solution

import numpy as np

class EMA:
    """Keeps an exponentially smoothed (shadow) copy of model parameters."""

    def __init__(self, decay: float = 0.999):
        self.decay = decay
        self.shadow_params = {}

    def register(self, params: dict[str, np.ndarray]) -> None:
        # Copy so later in-place changes to the model do not alter the shadow.
        for name, param in params.items():
            self.shadow_params[name] = param.copy()

    def update(self, params: dict[str, np.ndarray]) -> None:
        # Call after each training step; every name must already be registered.
        for name, param in params.items():
            self.shadow_params[name] = (
                self.decay * self.shadow_params[name] + (1 - self.decay) * param
            )

    def get_params(self) -> dict[str, np.ndarray]:
        # Return copies so callers cannot mutate the stored shadow arrays.
        return {name: param.copy() for name, param in self.shadow_params.items()}

    def apply_shadow(self, params: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        # Overwrite params with the EMA weights and return the training weights.
        backup = {name: param.copy() for name, param in params.items()}
        for name in params:
            params[name] = self.shadow_params[name].copy()
        return backup

    def restore(self, params: dict[str, np.ndarray], backup: dict[str, np.ndarray]) -> None:
        # Undo apply_shadow by putting the backed-up training weights back.
        for name in params:
            params[name] = backup[name].copy()

Explanation

  1. Register: initialize shadow parameters as a copy of the model's initial weights.
  2. Update: after each training step, update each shadow parameter as shadow = decay * shadow + (1 - decay) * param. A high decay (e.g., 0.999 or 0.9999) means the shadow changes slowly.
  3. Apply shadow: swap in the EMA weights for inference/sampling, keeping a backup of training weights.
  4. Restore: swap training weights back after inference.
  5. Why it helps: EMA-smoothed weights average out noise from stochastic training steps and typically produce higher-quality diffusion samples than the raw training weights.
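
The steps above can be traced by hand with a small sketch. It is not the solution class itself, just the same four operations written inline on plain dictionaries. The parameter name "w", the decay of 0.9 (chosen low so the numbers stay readable), and the pretend optimizer steps that set the weights to 2.0 are all assumptions made for illustration.

```python
import numpy as np

decay = 0.9  # assumed value, deliberately low so the arithmetic is easy to follow

# Step 1 (register): the shadow starts as a copy of the initial weights.
params = {"w": np.array([1.0, 1.0])}  # "w" is a hypothetical parameter name
shadow = {name: p.copy() for name, p in params.items()}

# Step 2 (update): pretend three optimizer steps each leave w at 2.0.
for _ in range(3):
    params["w"] = np.array([2.0, 2.0])
    for name, p in params.items():
        shadow[name] = decay * shadow[name] + (1 - decay) * p

# Step 3 (apply shadow): back up the training weights, then swap in the EMA.
backup = {name: p.copy() for name, p in params.items()}
for name in params:
    params[name] = shadow[name].copy()
ema_weights_used_for_sampling = params["w"].copy()

# Step 4 (restore): put the training weights back before training continues.
for name in params:
    params[name] = backup[name].copy()

print(ema_weights_used_for_sampling)  # shadow lags behind the new value 2.0
print(params["w"])                    # training weights are back in place
```

With a constant target the shadow follows 2.0 - 0.9**t * (2.0 - 1.0), so after three updates it sits at 1.271 while the training weights are already at 2.0. A decay of 0.999 or 0.9999 shrinks that gap far more slowly, which is the smoothing described in step 2.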

Complexity

  • Time: O(P) per update where P is the total number of parameters
  • Space: O(P) for the shadow copy
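
To make P concrete, the following sketch counts the parameters and the bytes a shadow copy would occupy for a toy parameter dictionary. The layer names and shapes are hypothetical and only serve to show the bookkeeping.

```python
import numpy as np

# Hypothetical parameters for a single small convolution layer.
params = {
    "conv.weight": np.zeros((8, 3, 3, 3), dtype=np.float32),
    "conv.bias": np.zeros(8, dtype=np.float32),
}

# P: total number of scalar parameters touched by one EMA update.
total_params = sum(p.size for p in params.values())

# Memory held by one full copy of the parameters (the shadow).
shadow_bytes = sum(p.nbytes for p in params.values())

print(total_params)  # 216 weights + 8 biases = 224
print(shadow_bytes)  # 224 float32 values * 4 bytes = 896
```

While apply_shadow is in effect, the returned backup is a second full copy, so the extra memory during sampling is roughly two copies of the parameters rather than one; this is still O(P).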