Implement QK-Norm (query-key normalization), which applies L2 normalization or layer normalization to the query and key vectors before the attention scores are computed. Normalizing Q and K bounds the magnitude of their dot products, which prevents the attention logits from growing too large and improves training stability, especially at scale.
from typing import Optional

import numpy as np
def l2_normalize(x: np.ndarray, axis: int = -1, eps: float = 1e-12) -> np.ndarray:
    """Scale ``x`` to unit Euclidean length along ``axis``.

    ``eps`` is added to the squared norm *before* the square root, so an
    all-zero slice is divided by sqrt(eps) and stays zero instead of
    producing NaNs.

    Args:
        x: Input array.
        axis: Axis along which the norm is measured.
        eps: Small stabilizer added under the square root.

    Returns:
        Array with the same shape as ``x``.
    """
    squared_length = np.sum(np.square(x), axis=axis, keepdims=True)
    return x / np.sqrt(squared_length + eps)
def layer_norm(
    x: np.ndarray,
    eps: float = 1e-5,
    axis: int = -1,
    gamma: Optional[np.ndarray] = None,
    beta: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Normalize ``x`` to zero mean and unit variance along ``axis``.

    An optional affine transform ``out * gamma + beta`` is applied afterwards,
    matching standard LayerNorm with a learnable scale and shift.

    Args:
        x: Input array.
        eps: Small stabilizer added to the variance before the square root.
        axis: Axis to normalize over. The default of -1 keeps the original
            behavior and matches ``l2_normalize``'s default.
        gamma: Optional scale, broadcast against the normalized output
            (e.g. shape ``(d,)`` when ``axis=-1``). ``None`` means no scaling.
        beta: Optional shift, broadcast like ``gamma``. ``None`` means no shift.

    Returns:
        Array with the same shape as ``x``.
    """
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)  # population variance (ddof=0)
    out = (x - mean) / np.sqrt(var + eps)
    if gamma is not None:
        out = out * gamma
    if beta is not None:
        out = out + beta
    return out
def qk_norm_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    norm_type: Optional[str] = "l2",
    temperature: Optional[float] = None,
) -> np.ndarray:
    """Dot-product attention with query-key normalization (QK-Norm).

    Q and K are normalized along their last (feature) axis before the
    attention logits are formed, which bounds the logits' magnitude.

    Args:
        Q: Queries, shape (..., seq_q, d_k), e.g. (batch, num_heads, seq_q, d_k).
        K: Keys, shape (..., seq_k, d_k).
        V: Values, shape (..., seq_k, d_v).
        norm_type: ``"l2"`` to L2-normalize Q and K (each dot product then
            lies in [-1, 1]), ``"layer_norm"`` to layer-normalize them, or
            ``"none"`` / ``None`` for plain scaled dot-product attention.
        temperature: Multiplier applied to the raw logits ``Q @ K^T``. It
            replaces the usual 1/sqrt(d_k) scaling and is often learned as a
            parameter. If ``None``, it defaults to sqrt(d_k) for ``"l2"``,
            which sharpens the otherwise [-1, 1]-bounded logits, and to the
            standard 1/sqrt(d_k) otherwise.

    Returns:
        Attention output, shape (..., seq_q, d_v).

    Raises:
        ValueError: If ``norm_type`` is not one of the supported options.
    """
    d_k = Q.shape[-1]

    if norm_type == "l2":
        Q = l2_normalize(Q, axis=-1)
        K = l2_normalize(K, axis=-1)
        default_temperature = np.sqrt(d_k)
    elif norm_type == "layer_norm":
        Q = layer_norm(Q)
        K = layer_norm(K)
        # Layer-normalized rows have squared norm ~= d_k, so the logits are
        # on the same scale as un-normalized attention. Multiplying by
        # sqrt(d_k) here would inflate them by a further factor of d_k, so
        # keep the standard 1/sqrt(d_k) scaling instead.
        default_temperature = 1.0 / np.sqrt(d_k)
    elif norm_type is None or norm_type == "none":
        default_temperature = 1.0 / np.sqrt(d_k)
    else:
        raise ValueError(
            f"norm_type must be 'l2', 'layer_norm', or 'none'; got {norm_type!r}"
        )

    if temperature is None:
        temperature = default_temperature

    # Logits with a learned or fixed temperature. swapaxes transposes only
    # the last two axes, so any number of leading batch/head axes works.
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) * temperature

    # Numerically stable softmax over the key axis.
    scores = scores - np.max(scores, axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= np.sum(weights, axis=-1, keepdims=True)

    return np.matmul(weights, V)