#115 · Deep Learning · Medium
# Solve on deep-ml.com. Implement Batch Normalization for a 4D input tensor of
# shape (batch, channels, height, width), i.e. BCHW format. Normalize each channel
# across the batch and spatial dimensions, then apply the learnable scale (gamma)
# and shift (beta) parameters.
import numpy as np
def batch_norm_2d(
    x: np.ndarray,
    gamma: np.ndarray = None,
    beta: np.ndarray = None,
    eps: float = 1e-5,
) -> np.ndarray:
    """Apply batch normalization to a BCHW tensor.

    Each channel is normalized with the mean and biased (``ddof=0``)
    variance computed over the batch, height and width axes. The result is
    then optionally scaled by ``gamma`` and shifted by ``beta``.

    Args:
        x: Array-like of shape ``(B, C, H, W)``.
        gamma: Optional per-channel scale, any array-like with exactly ``C``
            elements. ``None`` means no scaling.
        beta: Optional per-channel shift, any array-like with exactly ``C``
            elements. ``None`` means no shift.
        eps: Small constant added to the variance before the square root,
            to avoid division by zero for constant channels.

    Returns:
        Array of shape ``(B, C, H, W)`` holding the normalized values.

    Raises:
        ValueError: If ``x`` is not 4-dimensional, or if ``gamma``/``beta``
            do not have exactly ``C`` elements.
    """
    x = np.asarray(x)
    if x.ndim != 4:
        raise ValueError(
            f"expected a 4D (B, C, H, W) input, got shape {x.shape}"
        )
    num_channels = x.shape[1]

    # Per-channel statistics over batch and spatial axes. keepdims=True
    # yields shape (1, C, 1, 1), so they broadcast against x directly.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    # Biased variance (NumPy's default ddof=0), matching the normalization
    # step of batch norm during training.
    var = x.var(axis=(0, 2, 3), keepdims=True)

    out = (x - mean) / np.sqrt(var + eps)

    # Reshape gamma and beta to (1, C, 1, 1) for broadcasting; reshape raises
    # ValueError if they do not have exactly C elements.
    if gamma is not None:
        out = out * np.asarray(gamma).reshape(1, num_channels, 1, 1)
    if beta is not None:
        out = out + np.asarray(beta).reshape(1, num_channels, 1, 1)
    return out