"""Implement Multi-Head Latent Attention (MLA) as used in DeepSeek-V2.

MLA compresses keys and values into a low-rank latent representation to reduce
KV-cache memory, then reconstructs K and V from this compressed latent during
attention computation.
"""
import numpy as np
class MultiHeadLatentAttention:
    """Multi-Head Latent Attention (MLA) layer implemented with NumPy.

    Keys and values are compressed into a latent vector
    (d_latent << d_model) using a down-projection. Only this latent vector
    is cached during autoregressive decoding: the cache stores the latent
    (size d_latent per token) instead of full K and V (size 2 * d_model per
    token), achieving significant memory savings. K and V are reconstructed
    from the latent with two up-projections when attention is computed.

    NOTE: ``forward`` recomputes the latent from its input on every call;
    this class does not keep a cache between calls, and no attention mask
    is applied.
    """

    def __init__(self, d_model: int, num_heads: int, d_latent: int):
        """Create the projection matrices.

        Args:
            d_model: Width of the input/output features.
            num_heads: Number of attention heads; must be positive and
                divide ``d_model`` evenly.
            d_latent: Width of the compressed KV latent.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not divide
                ``d_model``.
        """
        # Fail early: a truncated d_k would otherwise only surface later as
        # an obscure reshape error inside forward().
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_latent = d_latent

        # The draw order below is kept as-is so that seeding np.random
        # reproduces the same weights.
        # Down-projection for KV compression
        self.W_dkv = np.random.randn(d_model, d_latent) * 0.02
        # Up-projections to reconstruct K and V from latent
        self.W_uk = np.random.randn(d_latent, d_model) * 0.02
        self.W_uv = np.random.randn(d_latent, d_model) * 0.02
        # Query projection
        self.W_q = np.random.randn(d_model, d_model) * 0.02
        # Output projection
        self.W_o = np.random.randn(d_model, d_model) * 0.02

    def _split_heads(self, t: np.ndarray, batch: int, seq_len: int) -> np.ndarray:
        """Reshape (batch, seq_len, d_model) to (batch, heads, seq_len, d_k)."""
        return t.reshape(batch, seq_len, self.num_heads, self.d_k).transpose(0, 2, 1, 3)

    @staticmethod
    def _softmax(scores: np.ndarray) -> np.ndarray:
        """Softmax over the last axis, shifted by the row max for stability."""
        shifted = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        return shifted / np.sum(shifted, axis=-1, keepdims=True)

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Run attention over ``x``.

        Args:
            x: Array of shape (batch, seq_len, d_model).

        Returns:
            Array of shape (batch, seq_len, d_model).
        """
        batch, seq_len, _ = x.shape

        q = x @ self.W_q  # (batch, seq_len, d_model)

        # Compress KV into the latent; this is the tensor a KV cache would hold.
        c_kv = x @ self.W_dkv  # (batch, seq_len, d_latent)

        # Reconstruct K and V from the latent.
        k = c_kv @ self.W_uk  # (batch, seq_len, d_model)
        v = c_kv @ self.W_uv  # (batch, seq_len, d_model)

        q = self._split_heads(q, batch, seq_len)
        k = self._split_heads(k, batch, seq_len)
        v = self._split_heads(v, batch, seq_len)

        # Scaled dot-product attention per head.
        scores = np.matmul(q, k.transpose(0, 1, 3, 2)) / np.sqrt(self.d_k)
        weights = self._softmax(scores)
        out = np.matmul(weights, v)  # (batch, heads, seq_len, d_k)

        # Merge heads back into the model dimension.
        out = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, self.d_model)
        return out @ self.W_o