Classifier-Free Guidance for Conditional Diffusion

#397 · Deep Learning · Medium

Problem

Implement classifier-free guidance (CFG) for conditional diffusion models. During inference, combine the conditional and unconditional noise predictions using a guidance scale to strengthen the influence of the conditioning signal.

Solution

import numpy as np

def classifier_free_guidance(
    noise_pred_uncond: np.ndarray,
    noise_pred_cond: np.ndarray,
    guidance_scale: float
) -> np.ndarray:
    # CFG formula: eps = eps_uncond + w * (eps_cond - eps_uncond)
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)


def cfg_training_step(
    model_fn,
    x_noisy: np.ndarray,
    t: np.ndarray,
    condition: np.ndarray,
    null_condition: np.ndarray,
    drop_prob: float = 0.1
) -> np.ndarray:
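    # Draw one independent drop decision per sample in the batch
    batch_size = x_noisy.shape[0]
    drop_mask = np.random.random(batch_size) < drop_prob

    # Copy so the caller's array is not modified, then overwrite the dropped
    # rows with the null condition (assumes null_condition has a leading
    # batch axis, so null_condition[0] is a single null embedding)
    cond = condition.copy()
    cond[drop_mask] = null_condition[0]

    # Predict the noise from the partially masked conditions
    return model_fn(x_noisy, t, cond)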

Explanation

  1. During training, each sample's conditioning signal is dropped (replaced with a null/empty embedding) with probability drop_prob, 0.1 by default here, so a single model learns both conditional and unconditional noise prediction.
  2. At inference, the model is run twice per denoising step: once with the condition and once with the null condition.
  3. The CFG formula combines the two predictions: eps = eps_uncond + w * (eps_cond - eps_uncond). For w>1 the result extrapolates beyond the conditional prediction, along the direction from eps_uncond to eps_cond.
  4. A guidance scale w=1 returns eps_cond, i.e. standard conditional generation. w>1 amplifies the effect of the condition (higher fidelity to the condition, less diversity). w=0 returns eps_uncond, i.e. unconditional generation.
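
The special cases of the guidance scale can be checked numerically. The sketch below restates the solution's formula and applies it to two made-up prediction vectors; the numbers are illustrative only and do not come from a real model.

```python
import numpy as np

def classifier_free_guidance(noise_pred_uncond, noise_pred_cond, guidance_scale):
    # Same formula as in the solution: eps = eps_uncond + w * (eps_cond - eps_uncond)
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

# Made-up predictions, chosen so the arithmetic is easy to follow
eps_uncond = np.array([0.0, 1.0, -2.0])
eps_cond = np.array([1.0, 3.0, -1.0])

w0 = classifier_free_guidance(eps_uncond, eps_cond, 0.0)  # equals eps_uncond
w1 = classifier_free_guidance(eps_uncond, eps_cond, 1.0)  # equals eps_cond
w3 = classifier_free_guidance(eps_uncond, eps_cond, 3.0)  # values 3, 7, 1: past eps_cond

print(w0, w1, w3)
```

With w=3 every component moves three times as far from eps_uncond as eps_cond does, which is the extrapolation described in step 3.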

Complexity

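  • Time: O(d) for the guidance combination, where d is the number of elements in one noise prediction. Inference also needs 2 model forward passes per step, which usually dominate the cost.
  • Space: O(d) for storing the two noise predictions and the combined result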
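
To make the cost split concrete, here is a minimal sketch of one guided prediction at inference: two calls to the model followed by the O(d) combination. The names toy_model_fn and guided_noise_prediction are hypothetical, the toy model is a stand-in with the same (x_noisy, t, cond) signature as model_fn in the solution, and null_condition is assumed to carry a leading batch axis as in cfg_training_step.

```python
import numpy as np

def toy_model_fn(x_noisy, t, cond):
    # Hypothetical stand-in for a trained denoiser; a real model would be a
    # neural network. This expression only exists so the example runs.
    return 0.1 * x_noisy + cond

def guided_noise_prediction(model_fn, x_noisy, t, condition, null_condition, guidance_scale):
    # Forward pass 1: conditional prediction
    eps_cond = model_fn(x_noisy, t, condition)
    # Forward pass 2: same inputs, but every sample gets the null condition
    null_batch = np.broadcast_to(null_condition[0], condition.shape)
    eps_uncond = model_fn(x_noisy, t, null_batch)
    # Elementwise O(d) combination of the two predictions
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

rng = np.random.default_rng(0)
x_noisy = rng.standard_normal((4, 3))    # batch of 4 samples, 3 features each
t = np.full(4, 10)                       # same timestep for every sample
condition = rng.standard_normal((4, 3))  # made-up condition embeddings
null_condition = np.zeros((1, 3))        # assumed null embedding

eps = guided_noise_prediction(
    toy_model_fn, x_noisy, t, condition, null_condition, guidance_scale=2.0
)
print(eps.shape)  # (4, 3)
```

The two forward passes could also be stacked into a single batch of twice the size; that changes how the work is scheduled, not how much of it there is.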