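Implement both Pre-Norm and Post-Norm transformer blocks and compare them. Pre-Norm applies layer normalization to the input of each sub-layer (attention or FFN) and adds the sub-layer output back onto the un-normalized residual stream: x + SubLayer(LN(x)). Post-Norm, as in the original Transformer, applies layer normalization after the residual addition: LN(x + SubLayer(x)). Pre-Norm is now the standard choice because it trains more stably, particularly in deep models.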
import numpy as np
def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize ``x`` over its last axis, then apply a learned affine transform.

    Each slice along the last axis is shifted to zero mean and divided by
    ``sqrt(variance + eps)``. The result is scaled by ``gamma`` and shifted
    by ``beta``.

    Args:
        x: Input array; mean and variance are taken over its last axis.
        gamma: Scale, broadcast against ``x``. Presumably shape ``(d,)`` with
            ``d == x.shape[-1]`` -- TODO confirm with callers.
        beta: Shift, broadcast against ``x`` in the same way as ``gamma``.
        eps: Added to the variance before the square root, so a constant
            slice (zero variance) does not cause a division by zero.

    Returns:
        ``gamma * (x - mean) / sqrt(var + eps) + beta``.
    """
    # keepdims=True leaves a size-1 axis so the statistics broadcast back over x.
    mu = np.mean(x, axis=-1, keepdims=True)
    # np.var uses ddof=0 by default, i.e. the biased (population) variance.
    sigma_sq = np.var(x, axis=-1, keepdims=True)
    centered = x - mu
    denom = np.sqrt(sigma_sq + eps)
    return gamma * centered / denom + beta
def pre_norm_block(x: np.ndarray, attn_fn, ffn_fn, gamma1: np.ndarray, beta1: np.ndarray, gamma2: np.ndarray, beta2: np.ndarray) -> np.ndarray:
    """Run one Pre-Norm transformer block: ``x + SubLayer(LN(x))`` per sub-layer.

    The attention sub-layer runs first, then the FFN sub-layer. Each sees a
    layer-normalized copy of the current activations, and its output is added
    back onto the un-normalized residual stream. The block's output itself is
    therefore not normalized.

    Args:
        x: Input activations. The caller's array is not modified in place.
        attn_fn: Attention sub-layer; called with one array. Its result must
            be addable to ``x`` (normally the same shape).
        ffn_fn: Feed-forward sub-layer, with the same calling contract as
            ``attn_fn``.
        gamma1, beta1: LayerNorm scale/shift applied before ``attn_fn``.
        gamma2, beta2: LayerNorm scale/shift applied before ``ffn_fn``.

    Returns:
        The activations after both residual updates.
    """
    stages = (
        (attn_fn, gamma1, beta1),  # attention sub-layer
        (ffn_fn, gamma2, beta2),  # FFN sub-layer
    )
    for sublayer, scale, shift in stages:
        # Only the sub-layer's input is normalized; the skip path carries raw x.
        # Rebinding (not +=) keeps the caller's array untouched.
        x = x + sublayer(layer_norm(x, scale, shift))
    return x
def post_norm_block(x: np.ndarray, attn_fn, ffn_fn, gamma1: np.ndarray, beta1: np.ndarray, gamma2: np.ndarray, beta2: np.ndarray) -> np.ndarray:
    """Run one Post-Norm transformer block: ``LN(x + SubLayer(x))`` per sub-layer.

    The attention sub-layer runs first, then the FFN sub-layer. Each receives
    the raw current activations, and layer normalization is applied to the
    residual sum. The block's output is therefore the result of the second
    LayerNorm.

    Args:
        x: Input activations. The caller's array is not modified in place.
        attn_fn: Attention sub-layer; called with one array. Its result must
            be addable to ``x`` (normally the same shape).
        ffn_fn: Feed-forward sub-layer, with the same calling contract as
            ``attn_fn``.
        gamma1, beta1: LayerNorm scale/shift applied after the attention
            residual.
        gamma2, beta2: LayerNorm scale/shift applied after the FFN residual.

    Returns:
        The normalized activations after both sub-layers.
    """
    stages = (
        (attn_fn, gamma1, beta1),  # attention sub-layer
        (ffn_fn, gamma2, beta2),  # FFN sub-layer
    )
    for sublayer, scale, shift in stages:
        # Normalize after the residual addition, so every sub-layer output
        # passes through the LayerNorm before reaching the next stage.
        residual_sum = x + sublayer(x)
        x = layer_norm(residual_sum, scale, shift)
    return x