
Implement RMSNorm (Root Mean Square Layer Normalization)

#372 · Deep Learning · Easy

Source: deep-ml.com

Problem

Implement RMSNorm (Root Mean Square Layer Normalization), a simpler alternative to LayerNorm that normalizes by the root mean square of the activations without centering (no mean subtraction).
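Here is the definition for a single feature vector x = (x_1, ..., x_d). Epsilon sits inside the square root, as in the solution code. LayerNorm is shown for contrast, so the dropped terms are visible.

```latex
\mathrm{RMS}(\mathbf{x}) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2} + \epsilon},
\qquad
y_i = \gamma_i \, \frac{x_i}{\mathrm{RMS}(\mathbf{x})}

\text{LayerNorm:}\quad
y_i = \gamma_i \, \frac{x_i - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta_i,
\qquad
\mu = \frac{1}{d}\sum_{i=1}^{d} x_i,
\quad
\sigma^{2} = \frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^{2}
```

RMSNorm removes the mean \mu and the bias \beta. When a vector already has \mu = 0 and \beta = 0, the two formulas give the same result.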

Solution

import numpy as np

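def rms_norm(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # x shape: (..., d) where d is the feature dimension; gamma shape: (d,)
    # Root mean square over the feature axis; eps keeps the denominator away from zero
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    # Divide each vector by its own RMS (no mean subtraction, unlike LayerNorm)
    x_norm = x / rms
    # Learnable per-feature scale; gamma broadcasts over the leading dimensions
    return gamma * x_norm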
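This quick usage check repeats the function so the snippet runs on its own. The inputs are made up for illustration. With gamma set to ones, only the normalization is visible, and every output row should have an RMS of (almost exactly) 1.

```python
import numpy as np

def rms_norm(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # Same computation as the solution above, repeated so this snippet is standalone
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return gamma * (x / rms)

# Batch of 2 vectors with feature dimension d = 4 (illustrative values)
x = np.array([[2.0, 2.0, 2.0, 2.0],
              [1.0, -1.0, 1.0, -1.0]])
gamma = np.ones(4)  # identity scale

y = rms_norm(x, gamma)
print(np.round(y, 6))
# RMS of each output row, close to 1 up to the effect of eps
print(np.round(np.sqrt(np.mean(y ** 2, axis=-1)), 6))
```

The constant row [2, 2, 2, 2] maps to all ones. LayerNorm would first subtract the row mean and send the same row to all zeros (plus beta). This is the practical effect of skipping the centering step.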

Explanation

  1. Compute the root mean square of the input along the last dimension: rms = sqrt(mean(x^2) + eps).
  2. Normalize by dividing by the RMS value.
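  3. Scale element-wise by the learnable parameter gamma. It holds one weight per feature (shape (d,)) and broadcasts over the leading dimensions.
  4. Unlike LayerNorm, RMSNorm neither subtracts the mean nor adds a learnable bias (beta). Skipping the mean reduction and the bias makes it simpler and cheaper to compute, while in practice it typically reaches performance comparable to LayerNorm.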
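The sketch below makes the "no centering" point concrete. It uses a hand-written LayerNorm (our own helper, not from the original solution) with the same eps placement. The two layers disagree on general input. They coincide once every row has zero mean and the LayerNorm bias is zero.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-8):
    # RMSNorm: divide by the root mean square, no centering, no bias
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return gamma * (x / rms)

def layer_norm(x, gamma, beta, eps=1e-8):
    # LayerNorm (hand-written for comparison): center, divide by std, then scale and shift
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
d = 8
gamma = rng.normal(size=d)
beta = np.zeros(d)  # bias set to zero so only the centering differs

x = rng.normal(size=(3, d))
x_centered = x - x.mean(axis=-1, keepdims=True)  # force every row to zero mean

same_general = np.allclose(rms_norm(x, gamma), layer_norm(x, gamma, beta))
same_centered = np.allclose(rms_norm(x_centered, gamma), layer_norm(x_centered, gamma, beta))
print(same_general, same_centered)
```

For zero-mean rows, the LayerNorm variance equals mean(x^2), so both denominators are identical. The mean subtraction is therefore the only computational difference between the two layers (apart from beta).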

Complexity

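  • Time: O(n * d), where n is the number of vectors being normalized (the product of all leading dimensions of x, e.g. the batch size for 2-D input) and d is the feature dimension
  • Space: O(n * d) for the normalized output (NumPy also allocates same-sized temporaries such as x ** 2), plus O(n) for the per-vector RMS values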
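The straightforward solution creates intermediate arrays as large as the input, such as x ** 2 and x / rms. Below is a lower-memory sketch. The name rms_norm_lowmem is hypothetical and the approach is our own variant, not part of the original solution. It uses np.einsum to accumulate the sum of squares along the last axis, which should avoid materializing x ** 2. It then applies gamma in place on the single output buffer.

```python
import numpy as np

def rms_norm_lowmem(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # Sum of squares over the last axis; einsum contracts x with itself along d,
    # returning shape (...,) without building a full x ** 2 array first
    sq_sum = np.einsum("...d,...d->...", x, x)
    # Per-vector RMS, reshaped to (..., 1) so it broadcasts against x
    rms = np.sqrt(sq_sum / x.shape[-1] + eps)[..., None]
    # One output-sized allocation, then scale by gamma in place
    out = x / rms
    out *= gamma
    return out

# Compare against the direct formula on random data (illustrative)
rng = np.random.default_rng(1)
x = rng.normal(size=(4, 16))
gamma = rng.normal(size=16)
reference = gamma * (x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + 1e-8))
print(np.allclose(rms_norm_lowmem(x, gamma), reference))
```

The asymptotic costs do not change: time is still O(n * d) and the output is still O(n * d). Only the number of full-size temporaries drops.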