QK-Norm (Query-Key Normalization)

#407 · Deep Learning · Medium · deep-ml.com

Problem

Implement QK-Norm (Query-Key Normalization), which applies layer normalization or L2 normalization to the query and key vectors before computing attention scores. This prevents attention logits from growing too large, improving training stability especially at scale.
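To see the problem QK-Norm addresses, the sketch below scales random Q and K matrices by a growing factor, standing in for projections whose outputs grow during training, and compares the largest attention logit with and without L2 QK-Norm. `max_logit` is a hypothetical helper for this illustration only, and the fixed multiplier sqrt(d_k) on the normalized path is an assumption that mirrors the solution's default for the L2 case rather than a learned temperature.

```python
import numpy as np


def max_logit(Q: np.ndarray, K: np.ndarray, normalize: bool) -> float:
    """Largest |logit| for a single head, with or without L2 QK-Norm (hypothetical helper)."""
    d_k = Q.shape[-1]
    if normalize:
        # L2-normalize each row so every dot product becomes a cosine similarity
        Q = Q / np.linalg.norm(Q, axis=-1, keepdims=True)
        K = K / np.linalg.norm(K, axis=-1, keepdims=True)
        scale = np.sqrt(d_k)  # assumed fixed multiplier, not learned here
    else:
        scale = 1.0 / np.sqrt(d_k)  # standard scaled dot-product attention
    return float(np.max(np.abs(Q @ K.T * scale)))


rng = np.random.default_rng(0)
Q = rng.standard_normal((8, 64))
K = rng.standard_normal((8, 64))

for growth in (1.0, 10.0, 100.0):
    # Simulate Q/K outputs that grow in magnitude as training progresses
    plain = max_logit(growth * Q, growth * K, normalize=False)
    normed = max_logit(growth * Q, growth * K, normalize=True)
    print(f"growth={growth:>5}: standard max|logit|={plain:10.2f}  QK-Norm max|logit|={normed:.2f}")
```

The standard logits grow with the square of the factor, while the QK-Norm logits do not change at all and stay bounded by the multiplier (here sqrt(64) = 8).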

Solution

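import numpy as np
from typing import Optional


def l2_normalize(x: np.ndarray, axis: int = -1, eps: float = 1e-12) -> np.ndarray:
    # Divide each vector by its Euclidean norm (eps guards against zero vectors)
    norm = np.sqrt(np.sum(x ** 2, axis=axis, keepdims=True) + eps)
    return x / norm


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Parameter-free LayerNorm over the head dimension: zero mean, unit variance
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def qk_norm_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    norm_type: str = "l2",
    temperature: Optional[float] = None,
) -> np.ndarray:
    # Q, K, V: (batch, num_heads, seq_len, d_k)
    d_k = Q.shape[-1]

    if norm_type == "l2":
        Q = l2_normalize(Q, axis=-1)
        K = l2_normalize(K, axis=-1)
        # Unit vectors give logits in [-1, 1], so scale them up by sqrt(d_k)
        default_temperature = np.sqrt(d_k)
    elif norm_type == "layer_norm":
        Q = layer_norm(Q)
        K = layer_norm(K)
        # Unit-variance entries give logits with std ~ sqrt(d_k), so keep the
        # standard 1/sqrt(d_k) scaling (multiplying by sqrt(d_k) would saturate softmax)
        default_temperature = 1.0 / np.sqrt(d_k)
    else:
        raise ValueError(f"unknown norm_type: {norm_type!r}")

    if temperature is None:
        temperature = default_temperature

    # Scores scaled by a fixed temperature (a trainable model would learn it)
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) * temperature

    # Numerically stable softmax over the key dimension
    scores_max = np.max(scores, axis=-1, keepdims=True)
    weights = np.exp(scores - scores_max)
    weights = weights / np.sum(weights, axis=-1, keepdims=True)

    return np.matmul(weights, V)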
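
The default temperature of sqrt(d_k) on the L2 path has a simple interpretation, checked numerically in the sketch below: multiplying unit-vector dot products by sqrt(d_k) gives the same logits as RMS-normalizing Q and K (each vector rescaled to norm sqrt(d_k), with no learnable gain) and then applying the standard 1/sqrt(d_k) scaling. The names `rms_norm`, `l2_logits`, and `rms_logits` are hypothetical helpers, not part of the solution.

```python
import numpy as np


def rms_norm(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # RMSNorm without a learnable gain: each vector gets RMS 1, i.e. norm sqrt(d)
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)


def l2_logits(Q: np.ndarray, K: np.ndarray) -> np.ndarray:
    d_k = Q.shape[-1]
    Qn = Q / np.linalg.norm(Q, axis=-1, keepdims=True)
    Kn = K / np.linalg.norm(K, axis=-1, keepdims=True)
    # L2 QK-Norm with temperature sqrt(d_k)
    return (Qn @ np.swapaxes(Kn, -1, -2)) * np.sqrt(d_k)


def rms_logits(Q: np.ndarray, K: np.ndarray) -> np.ndarray:
    d_k = Q.shape[-1]
    # RMS-normalized Q and K with the usual 1/sqrt(d_k) scaling
    return (rms_norm(Q) @ np.swapaxes(rms_norm(K), -1, -2)) / np.sqrt(d_k)


rng = np.random.default_rng(1)
Q = rng.standard_normal((2, 4, 5, 16))  # (batch, num_heads, seq_len, d_k)
K = rng.standard_normal((2, 4, 5, 16))
print(np.allclose(l2_logits(Q, K), rms_logits(Q, K)))
```

In other words, this default keeps the logits on the same scale as ordinary scaled dot-product attention over RMS-normalized inputs; a learned temperature could reasonably be initialized at this value.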

Explanation

  1. Problem: in standard attention, Q and K vectors can grow in magnitude during training, causing attention logits to explode and softmax to saturate.
  2. L2 normalization: normalizes Q and K to unit vectors, so their dot products are cosine similarities bounded in [-1, 1]. A temperature (often learnable) then scales these logits to control how sharp the softmax can be.
  3. Layer normalization: normalizes each Q and K vector to zero mean and unit variance along the head dimension (this parameter-free version omits LayerNorm's learnable gain and bias).
  4. Temperature: a multiplier on the logits that takes the place of the standard 1/sqrt(d_k) scaling. It is a fixed argument in this solution but is often learned as a parameter.
  5. QK-Norm enables stable training at large scales (used in models like ViT-22B and Chameleon).
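Items 2 and 3 can be checked numerically. The sketch below carries its own minimal copies of the two normalizers so it runs standalone; the random data, the input scale and shift, and the sizes are arbitrary choices for illustration. It confirms that L2-normalized dot products stay in [-1, 1], while LayerNorm-ed vectors have zero mean and unit variance, which makes dot products between independent vectors spread with a standard deviation of roughly sqrt(d_k). That is the regime the usual 1/sqrt(d_k) factor was designed for, which suggests a LayerNorm-based QK-Norm calls for a temperature near 1/sqrt(d_k) rather than sqrt(d_k).

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Parameter-free LayerNorm over the last (head) dimension
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
d_k, n = 64, 20_000
# Deliberately large, shifted inputs, mimicking unnormalized projections
q = 50.0 * rng.standard_normal((n, d_k)) + 3.0
k = 50.0 * rng.standard_normal((n, d_k)) + 3.0

# Item 2: L2-normalized dot products are cosine similarities in [-1, 1]
cos = np.sum(l2_normalize(q) * l2_normalize(k), axis=-1)

# Item 3: LayerNorm gives each vector zero mean and unit variance over d_k ...
ln_q, ln_k = layer_norm(q), layer_norm(k)
# ... so dot products of independent vectors have std of roughly sqrt(d_k)
ln_dots = np.sum(ln_q * ln_k, axis=-1)

print("cosine range:", cos.min(), cos.max())
print("LayerNorm dot-product std:", ln_dots.std(), "vs sqrt(d_k) =", np.sqrt(d_k))
print("after 1/sqrt(d_k) scaling:", (ln_dots / np.sqrt(d_k)).std())
```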

Complexity

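  • Time: O(B h n^2 * d), the same as standard attention (the normalization itself adds only O(B h n * d))
  • Space: O(B h n^2) for the score and weight matrices, plus O(B h n * d) for the normalized Q and K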
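For long sequences the n x n score matrix, not the normalized Q and K, dominates memory. A quick element count for one illustrative configuration (the sizes are arbitrary, and `attention_memory_elements` is a hypothetical helper) makes the gap concrete:

```python
def attention_memory_elements(B: int, h: int, n: int, d: int) -> dict:
    """Element counts for the main tensors in one attention call (hypothetical helper)."""
    return {
        "normalized Q and K": 2 * B * h * n * d,
        "score/weight matrix": B * h * n * n,
    }


counts = attention_memory_elements(B=1, h=16, n=4096, d=64)
for name, elems in counts.items():
    # float32 uses 4 bytes per element; 2**20 bytes = 1 MiB
    print(f"{name}: {elems:,} elements ({elems * 4 / 2**20:.0f} MiB in float32)")
```

The ratio of the two counts is n / (2d), so the score matrix outgrows Q and K as soon as the sequence length exceeds twice the head dimension.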