Implement Instance Normalization. Unlike Batch Normalization, which normalizes across the batch, Instance Normalization normalizes each channel of each sample independently across its spatial dimensions. It is commonly used in style transfer networks.
import numpy as np
def instance_norm(
    x: np.ndarray,
    gamma: "np.ndarray | None" = None,
    beta: "np.ndarray | None" = None,
    eps: float = 1e-5,
) -> np.ndarray:
    """Apply Instance Normalization to ``x``.

    Each (sample, channel) slice is normalized independently over its
    spatial axes to zero mean and unit variance. It is then optionally
    scaled by ``gamma`` and shifted by ``beta``, per channel.

    Args:
        x: Input of shape ``(N, C, *spatial)`` with at least one spatial
            axis, e.g. ``(N, C, L)``, ``(N, C, H, W)`` or ``(N, C, D, H, W)``.
        gamma: Optional per-channel scale with exactly ``C`` elements.
        beta: Optional per-channel shift with exactly ``C`` elements.
        eps: Small constant added to the variance to avoid division by zero.

    Returns:
        Normalized array with the same shape as ``x``.

    Raises:
        ValueError: If ``x`` has fewer than 3 dimensions, or if ``gamma`` /
            ``beta`` cannot be reshaped to one value per channel.
    """
    x = np.asarray(x)
    if x.ndim < 3:
        raise ValueError(
            "instance_norm expects input of shape (N, C, *spatial); "
            f"got shape {x.shape}"
        )
    num_channels = x.shape[1]
    spatial_axes = tuple(range(2, x.ndim))
    # Shape that broadcasts a length-C vector against (N, C, *spatial).
    param_shape = (1, num_channels) + (1,) * (x.ndim - 2)

    # Statistics are taken per sample AND per channel, never across the
    # batch; this is what distinguishes instance norm from batch norm.
    mean = np.mean(x, axis=spatial_axes, keepdims=True)  # (N, C, 1, ..., 1)
    var = np.var(x, axis=spatial_axes, keepdims=True)  # biased variance
    x_norm = (x - mean) / np.sqrt(var + eps)

    # Optional affine transform, one scale/shift per channel.
    if gamma is not None:
        x_norm = x_norm * np.asarray(gamma).reshape(param_shape)
    if beta is not None:
        x_norm = x_norm + np.asarray(beta).reshape(param_shape)
    return x_norm