Implement Batch Normalization for BCHW Input

#115 · Deep Learning · Medium

Problem

Implement Batch Normalization for a 4D input tensor of shape (batch, channels, height, width) (BCHW format). Normalize each channel across the batch and spatial dimensions, then apply learnable scale (gamma) and shift (beta) parameters.
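Written out, the operation described above can be stated as follows. The symbols (\mu_c, \sigma_c^2, \hat{x}, y, \epsilon) are notation introduced here for illustration and do not appear in the problem statement; the variance is the biased (population) variance, which matches what np.var computes by default.

```latex
\begin{align*}
\mu_c &= \frac{1}{BHW} \sum_{b=1}^{B} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{b,c,h,w} \\
\sigma_c^2 &= \frac{1}{BHW} \sum_{b=1}^{B} \sum_{h=1}^{H} \sum_{w=1}^{W} \left( x_{b,c,h,w} - \mu_c \right)^2 \\
\hat{x}_{b,c,h,w} &= \frac{x_{b,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} \\
y_{b,c,h,w} &= \gamma_c \, \hat{x}_{b,c,h,w} + \beta_c
\end{align*}
```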

Solution

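from typing import Optional

import numpy as np

def batch_norm_2d(
    x: np.ndarray,
    gamma: Optional[np.ndarray] = None,
    beta: Optional[np.ndarray] = None,
    eps: float = 1e-5
) -> np.ndarray:
    """Batch norm for a (B, C, H, W) array using the statistics of the current batch."""
    C = x.shape[1]  # number of channels; gamma and beta hold one value per channel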

    # Compute mean and variance per channel (over batch and spatial dims)
    mean = np.mean(x, axis=(0, 2, 3), keepdims=True)   # (1, C, 1, 1)
    var = np.var(x, axis=(0, 2, 3), keepdims=True)      # (1, C, 1, 1)

    # Normalize
    x_norm = (x - mean) / np.sqrt(var + eps)

    # Scale and shift
    if gamma is not None:
        gamma = gamma.reshape(1, C, 1, 1)
        x_norm = x_norm * gamma
    if beta is not None:
        beta = beta.reshape(1, C, 1, 1)
        x_norm = x_norm + beta

    return x_norm
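
As a quick sanity check of the reduction over axes (0, 2, 3), the sketch below repeats the normalization step inline and compares it with an equivalent "one row per channel" view of the same data. The input shape, random seed, and variable names are illustrative assumptions, not part of the original solution.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(4, 3, 2, 2))  # (B, C, H, W), arbitrary test data
eps = 1e-5

# Per-channel statistics, reduced over batch and spatial axes as in the solution
mean = x.mean(axis=(0, 2, 3), keepdims=True)  # (1, C, 1, 1)
var = x.var(axis=(0, 2, 3), keepdims=True)    # (1, C, 1, 1)
x_norm = (x - mean) / np.sqrt(var + eps)

# Equivalent view: move channels first and flatten, so each row holds
# all B*H*W values that belong to one channel
flat = x.transpose(1, 0, 2, 3).reshape(x.shape[1], -1)  # (C, B*H*W)
flat_mean = flat.mean(axis=1)
flat_var = flat.var(axis=1)

# After normalization each channel should have mean ~0 and variance ~1
channel_means = x_norm.mean(axis=(0, 2, 3))
channel_vars = x_norm.var(axis=(0, 2, 3))
```

The two sets of statistics agree, and the normalized variance falls very slightly below 1 because eps is added under the square root.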

Explanation

  1. Per-channel statistics: Compute mean and variance for each channel across the batch dimension and both spatial dimensions (H, W). This gives one mean and one variance per channel.
  2. Normalize: Subtract the channel mean and divide by the channel standard deviation, computed as sqrt(var + eps). The epsilon keeps the division numerically stable when a channel's variance is close to zero.
  3. Affine transform: Apply learnable gamma (scale) and beta (shift) parameters per channel, reshaped to (1, C, 1, 1) for broadcasting.
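  4. Why batch norm: It was originally motivated as a way to reduce internal covariate shift. In practice it allows higher learning rates and acts as a mild regularizer. During inference, running averages of the mean and variance collected during training are used instead of batch statistics; the function above covers only the batch-statistics (training) case.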
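
The running averages mentioned in the last point are not part of the solution above. A minimal sketch of how they might be tracked is shown below, under several assumptions: the function name batch_norm_2d_running, the training flag, and the momentum parameter are hypothetical, the update is a plain exponential moving average that weights the new batch value by momentum, and the biased batch variance is used for the running estimate (some frameworks use the unbiased one). The affine gamma/beta step is left out to keep it short.

```python
import numpy as np

def batch_norm_2d_running(x, running_mean, running_var, training,
                          momentum=0.1, eps=1e-5):
    """Hypothetical helper: normalize (B, C, H, W) input and track running stats."""
    if training:
        # Use the statistics of the current batch
        mean = x.mean(axis=(0, 2, 3))  # (C,)
        var = x.var(axis=(0, 2, 3))    # (C,), biased variance (an assumption)
        # Exponential moving average; returns new arrays, inputs are not mutated
        running_mean = (1 - momentum) * running_mean + momentum * mean
        running_var = (1 - momentum) * running_var + momentum * var
    else:
        # Inference: use the stored estimates instead of batch statistics
        mean, var = running_mean, running_var

    # Reshape to (1, C, 1, 1) so the statistics broadcast over B, H, W
    mean = mean.reshape(1, -1, 1, 1)
    var = var.reshape(1, -1, 1, 1)
    x_norm = (x - mean) / np.sqrt(var + eps)
    return x_norm, running_mean, running_var

# Usage: one training step updates the running stats, an eval step only reads them
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 3, 4, 4))
rm, rv = np.zeros(3), np.ones(3)
y_train, rm, rv = batch_norm_2d_running(x, rm, rv, training=True)
y_eval, rm_eval, rv_eval = batch_norm_2d_running(x, rm, rv, training=False)
```

After a single training step the running estimates are still far from the batch statistics, so y_eval differs from y_train; the two converge only as the averages accumulate over many batches.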

Complexity

  • Time: O(B * C * H * W) for computing statistics and normalizing
  • Space: O(B * C * H * W) for the normalized output