Implement Group Normalization

#126 · Deep Learning · Medium

Problem

Implement Group Normalization. Unlike Batch Normalization, which normalizes across the batch dimension, Group Normalization divides the channels into groups and normalizes within each group independently. Because the statistics are computed separately for each sample, they do not depend on the batch size, which makes the method effective for small batches.
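The batch-independence claim can be checked numerically. The following is a minimal sketch, not part of the original solution: `normalize_groups` and `normalize_batch` are hypothetical helpers written only for this comparison, and both assume a 4D input of shape (N, C, H, W) with no affine parameters.

```python
import numpy as np

def normalize_groups(x, num_groups, eps=1e-5):
    """Hypothetical helper: group-normalize x of shape (N, C, H, W), no affine."""
    n = x.shape[0]
    # Flatten each group's channels and pixels into one axis per (sample, group)
    g = x.reshape(n, num_groups, -1)
    mean = g.mean(axis=2, keepdims=True)
    var = g.var(axis=2, keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(x.shape)

def normalize_batch(x, eps=1e-5):
    """Hypothetical helper: batch-normalize x per channel over (N, H, W), no affine."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
batch = rng.normal(size=(4, 6, 3, 3))

# Normalize the first sample as part of the batch and on its own
gn_in_batch = normalize_groups(batch, num_groups=2)[:1]
gn_alone = normalize_groups(batch[:1], num_groups=2)
bn_in_batch = normalize_batch(batch)[:1]
bn_alone = normalize_batch(batch[:1])

gn_matches = np.allclose(gn_in_batch, gn_alone)  # group norm ignores the other samples
bn_matches = np.allclose(bn_in_batch, bn_alone)  # batch norm statistics change with the batch
```

With this random input, the group-normalized sample is the same whether or not the rest of the batch is present, while the batch-normalized sample is not.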

Solution

from typing import Optional

import numpy as np

def group_norm(x: np.ndarray, num_groups: int, gamma: Optional[np.ndarray] = None, beta: Optional[np.ndarray] = None, eps: float = 1e-5) -> np.ndarray:
    """Apply Group Normalization to x of shape (N, C, *spatial).

    gamma and beta, if given, must each hold one value per channel (C values).
    """
    N, C, *spatial = x.shape
    assert C % num_groups == 0, "Channels must be divisible by num_groups"
    channels_per_group = C // num_groups

    # Reshape: (N, num_groups, channels_per_group, *spatial)
    new_shape = (N, num_groups, channels_per_group) + tuple(spatial)
    x_reshaped = x.reshape(new_shape)

    # Compute mean and (biased) variance over (channels_per_group, *spatial) axes
    axes = tuple(range(2, len(new_shape)))
    mean = np.mean(x_reshaped, axis=axes, keepdims=True)
    var = np.var(x_reshaped, axis=axes, keepdims=True)

    # Normalize; eps guards against division by zero when a group is constant
    x_norm = (x_reshaped - mean) / np.sqrt(var + eps)

    # Reshape back to original
    x_norm = x_norm.reshape(x.shape)

    # Apply optional learnable affine transform, broadcast per channel
    affine_shape = [1, C] + [1] * len(spatial)
    if gamma is not None:
        x_norm = x_norm * gamma.reshape(affine_shape)
    if beta is not None:
        x_norm = x_norm + beta.reshape(affine_shape)

    return x_norm

Explanation

  1. Reshape the input so that channels are split into num_groups groups, each with C // num_groups channels.
  2. Compute the mean and variance within each group (across the channels-per-group and spatial dimensions).
  3. Normalize each group to zero mean and unit variance.
  4. Reshape back to the original shape and apply optional learnable scale (gamma) and shift (beta) parameters.
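The four steps above can be traced on a small input. This is an illustrative sketch under assumed values: the shape (2, 4, 3, 3), the two groups, and the `gamma` and `beta` vectors are arbitrary choices made for the example, not part of the problem statement.

```python
import numpy as np

rng = np.random.default_rng(42)
N, C, H, W, num_groups = 2, 4, 3, 3, 2
x = rng.normal(loc=5.0, scale=3.0, size=(N, C, H, W))

# Step 1: split channels into groups -> (N, num_groups, C // num_groups, H, W)
grouped = x.reshape(N, num_groups, C // num_groups, H, W)

# Step 2: per-group statistics over the channels-in-group and spatial axes
mean = grouped.mean(axis=(2, 3, 4), keepdims=True)  # shape (N, num_groups, 1, 1, 1)
var = grouped.var(axis=(2, 3, 4), keepdims=True)

# Step 3: normalize each group
normed = (grouped - mean) / np.sqrt(var + 1e-5)

# Step 4: restore the original shape, then apply per-channel scale and shift
gamma = np.array([1.0, 2.0, 0.5, 1.0])  # assumed values, one per channel
beta = np.array([0.0, 1.0, -1.0, 0.0])  # assumed values, one per channel
out = normed.reshape(N, C, H, W) * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)

# Before the affine step, every (sample, group) pair has mean ~0 and variance ~1
group_means = normed.mean(axis=(2, 3, 4))
group_vars = normed.var(axis=(2, 3, 4))
```

Note that the zero-mean, unit-variance property holds for `normed`, not for `out`: once `gamma` and `beta` are applied, each channel is rescaled and shifted independently.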

Complexity

  • Time: O(N * C * H * W), where the spatial dimensions are H x W
  • Space: O(N * C * H * W) for the normalized output