Implement Group Normalization. Unlike Batch Normalization, which normalizes across the batch dimension, Group Normalization divides the channels into groups and normalizes within each group independently. This makes it effective for small batch sizes.
import numpy as np
def group_norm(
    x: np.ndarray,
    num_groups: int,
    gamma: np.ndarray = None,
    beta: np.ndarray = None,
    eps: float = 1e-5,
) -> np.ndarray:
    """Apply Group Normalization to a channels-first array.

    The ``C`` channels are split into ``num_groups`` groups, each with
    ``C // num_groups`` channels. For every sample, each group is
    normalized independently to zero mean and unit variance, with the
    statistics taken over the group's channels and all spatial positions.

    Args:
        x: Input of shape ``(N, C, *spatial)``; any number of trailing
            spatial dimensions (including none) is accepted.
        num_groups: Number of channel groups; must be positive and divide
            ``C`` evenly.
        gamma: Optional per-channel scale holding ``C`` values. Skipped
            when ``None``.
        beta: Optional per-channel shift holding ``C`` values. Skipped
            when ``None``.
        eps: Small constant added to the variance before the square root
            to avoid division by zero.

    Returns:
        The normalized (and optionally scaled/shifted) array, with the
        same shape as ``x``.

    Raises:
        ValueError: If ``x`` has fewer than two dimensions, if
            ``num_groups`` is not positive, or if ``C`` is not divisible
            by ``num_groups``.
    """
    x = np.asarray(x)
    if x.ndim < 2:
        raise ValueError("x must have at least 2 dimensions: (N, C, *spatial)")

    batch, channels, *spatial = x.shape

    # Explicit exceptions instead of `assert`: asserts are stripped under
    # `python -O`, which would let bad input fall through to an opaque
    # reshape error further down.
    if num_groups <= 0:
        raise ValueError("num_groups must be a positive integer")
    if channels % num_groups != 0:
        raise ValueError("Channels must be divisible by num_groups")

    channels_per_group = channels // num_groups

    # View the data as (N, num_groups, channels_per_group, *spatial) so each
    # group's elements sit on the trailing axes.
    grouped = x.reshape((batch, num_groups, channels_per_group, *spatial))

    # Statistics over everything except the (sample, group) axes.
    # np.var defaults to the biased estimator (ddof=0).
    reduce_axes = tuple(range(2, grouped.ndim))
    mean = grouped.mean(axis=reduce_axes, keepdims=True)
    var = grouped.var(axis=reduce_axes, keepdims=True)

    x_norm = ((grouped - mean) / np.sqrt(var + eps)).reshape(x.shape)

    # Per-channel affine parameters broadcast as (1, C, 1, ..., 1).
    if gamma is not None or beta is not None:
        affine_shape = (1, channels) + (1,) * len(spatial)
        if gamma is not None:
            x_norm = x_norm * np.asarray(gamma).reshape(affine_shape)
        if beta is not None:
            x_norm = x_norm + np.asarray(beta).reshape(affine_shape)

    return x_norm