
Pre-Norm vs Post-Norm Transformer Block

#408 · Deep Learning · Medium

Solve on deep-ml.com

Problem

Implement both Pre-Norm and Post-Norm transformer blocks and compare them. Pre-Norm applies layer normalization to the input of each sub-layer (attention or FFN), before the residual addition; Post-Norm applies it after the residual addition. Pre-Norm has become the standard choice because it trains more stably, especially in deep models.
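
The two orderings can be written as update rules for one sub-layer F_l (attention or FFN) acting on the residual stream x_l, where each sub-layer has its own norm LN_l with parameters γ, β (so a block contains two such steps). LN matches the layer_norm function in the solution. Unrolling the Pre-Norm rule shows the identity term that gives the residual stream its direct gradient path:

```latex
\begin{aligned}
\mathrm{LN}(x) &= \gamma \odot \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta
  && \mu,\ \sigma^{2}\ \text{taken over the feature axis} \\
\text{Pre-Norm:}\quad x_{l+1} &= x_l + F_l\big(\mathrm{LN}_l(x_l)\big) \\
\text{Post-Norm:}\quad x_{l+1} &= \mathrm{LN}_l\big(x_l + F_l(x_l)\big) \\
\text{Pre-Norm, unrolled:}\quad x_L &= x_l + \sum_{k=l}^{L-1} F_k\big(\mathrm{LN}_k(x_k)\big)
  \;\Rightarrow\;
  \frac{\partial x_L}{\partial x_l} = I + \sum_{k=l}^{L-1} \frac{\partial F_k\big(\mathrm{LN}_k(x_k)\big)}{\partial x_l}
\end{aligned}
```

In the Post-Norm rule no such bare identity term survives, because every step's output passes through LN_l.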

Solution

import numpy as np

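def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Normalize each token over its feature (last) axis, then apply the learned scale and shift.
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)  # biased (population) variance, as in standard LayerNorm
    return gamma * (x - mean) / np.sqrt(var + eps) + beta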


def pre_norm_block(x: np.ndarray, attn_fn, ffn_fn, gamma1: np.ndarray, beta1: np.ndarray, gamma2: np.ndarray, beta2: np.ndarray) -> np.ndarray:
    # Pre-Norm: Norm -> SubLayer -> Residual
    # Attention sub-layer
    normed = layer_norm(x, gamma1, beta1)
    x = x + attn_fn(normed)

    # FFN sub-layer
    normed = layer_norm(x, gamma2, beta2)
    x = x + ffn_fn(normed)
    return x


def post_norm_block(x: np.ndarray, attn_fn, ffn_fn, gamma1: np.ndarray, beta1: np.ndarray, gamma2: np.ndarray, beta2: np.ndarray) -> np.ndarray:
    # Post-Norm: SubLayer -> Residual -> Norm
    # Attention sub-layer
    x = layer_norm(x + attn_fn(x), gamma1, beta1)

    # FFN sub-layer
    x = layer_norm(x + ffn_fn(x), gamma2, beta2)
    return x
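
The solution defines the two blocks but does not run the comparison. Below is a minimal sketch of one, assuming hypothetical toy sub-layers: a single-head self-attention without a causal mask and a two-layer ReLU FFN, both with random weights (attn_fn, ffn_fn, Wq, Wk, Wv, Wo, W1, W2 are illustrative names, not part of the problem). It repeats compact copies of the three solution functions so it runs on its own.

```python
import numpy as np

# Compact copies of the solution's functions, so this snippet runs on its own.
def layer_norm(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def pre_norm_block(x, attn_fn, ffn_fn, gamma1, beta1, gamma2, beta2):
    x = x + attn_fn(layer_norm(x, gamma1, beta1))
    return x + ffn_fn(layer_norm(x, gamma2, beta2))

def post_norm_block(x, attn_fn, ffn_fn, gamma1, beta1, gamma2, beta2):
    x = layer_norm(x + attn_fn(x), gamma1, beta1)
    return layer_norm(x + ffn_fn(x), gamma2, beta2)

# Hypothetical toy sub-layers with random weights (illustration only).
rng = np.random.default_rng(0)
n, d, d_ff = 6, 16, 64  # tokens, model width, FFN width
Wq, Wk, Wv, Wo = (rng.normal(0.0, d ** -0.5, (d, d)) for _ in range(4))
W1 = rng.normal(0.0, d ** -0.5, (d, d_ff))
W2 = rng.normal(0.0, d_ff ** -0.5, (d_ff, d))

def attn_fn(h):
    # Single-head self-attention over all n tokens (no causal mask).
    q, k, v = h @ Wq, h @ Wk, h @ Wv
    scores = q @ k.T / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v @ Wo

def ffn_fn(h):
    # Two-layer MLP with a ReLU in between.
    return np.maximum(h @ W1, 0.0) @ W2

gamma, beta = np.ones(d), np.zeros(d)
x = rng.normal(0.0, 3.0, (n, d))  # residual stream deliberately not unit-scale

pre = pre_norm_block(x, attn_fn, ffn_fn, gamma, beta, gamma, beta)
post = post_norm_block(x, attn_fn, ffn_fn, gamma, beta, gamma, beta)
print("post-norm per-token std:", post.std(axis=-1).round(3))
print("pre-norm  per-token std:", pre.std(axis=-1).round(3))

# With sub-layers that output zeros, Pre-Norm reduces to the identity because
# the residual path never touches a norm; Post-Norm still renormalizes x.
def zero(h):
    return np.zeros_like(h)

print(np.allclose(pre_norm_block(x, zero, zero, gamma, beta, gamma, beta), x))   # True
print(np.allclose(post_norm_block(x, zero, zero, gamma, beta, gamma, beta), x))  # False
```

With gamma = 1 and beta = 0, the Post-Norm output has per-token mean close to 0 and standard deviation close to 1, because a norm is its last operation. The Pre-Norm output keeps the scale of the incoming residual stream.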

Explanation

  1. Post-Norm (original Transformer): the residual connection adds the sub-layer output to the input, and layer norm is then applied to the sum. Gradients flowing back to earlier layers must pass through a norm at every layer, even along the residual path, which makes deep Post-Norm stacks prone to vanishing or exploding gradients.
  2. Pre-Norm (widely adopted from GPT-2 onward): layer norm is applied only to the sub-layer's input, and the residual connection bypasses the norm entirely. This leaves an unobstructed identity path for gradients through the residual stream. Because the stream itself is never normalized, Pre-Norm models usually add a final layer norm after the last block.
  3. Pre-Norm is often reported to train stably with little or no learning-rate warmup, whereas Post-Norm typically needs warmup; with careful tuning, Post-Norm can sometimes reach slightly better final performance.
  4. Most modern LLMs (e.g., GPT-2/GPT-3 and LLaMA) use Pre-Norm, often with RMSNorm in place of LayerNorm (as in LLaMA).
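
RMSNorm, mentioned in the last point, drops LayerNorm's mean subtraction and bias and only rescales by the root-mean-square of the features. Below is a minimal sketch assuming that common formulation; rms_norm and pre_rms_block are hypothetical names, and eps = 1e-6 is an assumed default rather than a value taken from any specific model.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    # Divide each token by the root-mean-square of its features, then scale.
    # Unlike LayerNorm there is no mean subtraction and no bias (beta).
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return gamma * x / rms

def pre_rms_block(x, attn_fn, ffn_fn, gamma1, gamma2):
    # Same Pre-Norm wiring as pre_norm_block, with RMSNorm in place of LayerNorm.
    x = x + attn_fn(rms_norm(x, gamma1))
    return x + ffn_fn(rms_norm(x, gamma2))

rng = np.random.default_rng(0)
d = 8
gamma = np.ones(d)
x = rng.normal(1.0, 2.0, (4, d))  # non-zero mean on purpose

y = rms_norm(x, gamma)
print(np.sqrt(np.mean(y ** 2, axis=-1)).round(4))  # per-token RMS, ~1.0 each

# A constant vector stays at roughly all ones under RMSNorm,
# whereas LayerNorm would map it to beta (all zeros here).
print(rms_norm(np.full(d, 2.0), gamma).round(4))
```

Skipping the mean and the bias makes RMSNorm slightly cheaper than LayerNorm while keeping the scale-normalizing effect that Pre-Norm relies on.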

Complexity

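  • Time: O(n d) per normalization for n tokens of width d, identical for Pre-Norm and Post-Norm since only the order of operations differs. The block total is dominated by the sub-layers: O(n^2 d) for the attention scores plus O(n d^2) for the projections and FFN (with FFN width proportional to d), so the n^2 term dominates only for long sequences.
  • Space: O(n d) for the normalized intermediates, on top of what attn_fn and ffn_fn allocate (O(n^2) if the attention weights are materialized).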
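
A back-of-the-envelope count shows when the O(n^2 d) attention term actually dominates. This is a sketch under stated assumptions: dense (unmasked) attention with d x d Q, K, V, O projections, a two-matrix FFN of width d_ff, and multiply-accumulates counted only; multi-head attention with head size d / h gives the same totals. block_cost is a hypothetical helper.

```python
def block_cost(n: int, d: int, d_ff: int) -> dict:
    """Rough multiply-accumulate counts for one transformer block.

    Norm cost is reported as elements processed, since each norm does a
    handful of operations per element.
    """
    return {
        "qkvo_projections": 4 * n * d * d,  # h @ Wq, Wk, Wv, Wo
        "attention_scores": 2 * n * n * d,  # Q @ K^T and weights @ V
        "ffn": 2 * n * d * d_ff,            # h @ W1 and relu(.) @ W2
        "layer_norms": 2 * n * d,           # two norms per block, O(n d) each
    }

for n in (512, 8192):
    c = block_cost(n, d=1024, d_ff=4096)
    dense = c["qkvo_projections"] + c["ffn"]
    share = c["layer_norms"] / sum(c.values())
    print(f"n={n}: scores dominate dense work? {c['attention_scores'] > dense}; norm share {share:.2e}")
```

Setting 2n^2 d equal to 4n d^2 + 2n d d_ff gives a crossover at n = 2d + d_ff (n = 6d when d_ff = 4d). Below that length the projections and FFN cost more than the score computation; in either regime the norms are a negligible O(n d).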