
Multi-Head Latent Attention (MLA)

#405 · Deep Learning · Hard

Source: deep-ml.com

Problem

Implement Multi-Head Latent Attention (MLA) as used in DeepSeek-V2. MLA compresses keys and values into a low-rank latent representation to reduce KV-cache memory, then reconstructs K and V from this compressed latent during attention computation.
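
For reference, here are the equations the solution below computes. They use the row-vector convention of the code: X is the n × d_model input of one sequence, h = num_heads, and subscript i selects head i's d_k columns. The symbol names mirror the code's weight names and are a notational choice, not taken from the problem statement. Only C^{KV} (n × d_latent per layer) needs to be cached.

```latex
\begin{aligned}
C^{KV} &= X W^{DKV}, & W^{DKV} &\in \mathbb{R}^{d_{\text{model}} \times d_{\text{latent}}} \\
K &= C^{KV} W^{UK}, \quad V = C^{KV} W^{UV}, & W^{UK}, W^{UV} &\in \mathbb{R}^{d_{\text{latent}} \times d_{\text{model}}} \\
Q &= X W^{Q}, & W^{Q} &\in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}} \\
\mathrm{head}_i &= \operatorname{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_k}}\right) V_i, & d_k &= d_{\text{model}} / h \\
\mathrm{MLA}(X) &= \left[\mathrm{head}_1, \dots, \mathrm{head}_h\right] W^{O}
\end{aligned}
```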

Solution

import numpy as np

class MultiHeadLatentAttention:
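    def __init__(self, d_model: int, num_heads: int, d_latent: int):
        # Simplified MLA: DeepSeek-V2 also compresses queries through a low-rank
        # latent and adds decoupled RoPE keys; both are omitted here.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension
        self.d_latent = d_latent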

        # Down-projection for KV compression
        self.W_dkv = np.random.randn(d_model, d_latent) * 0.02
        # Up-projections to reconstruct K and V from latent
        self.W_uk = np.random.randn(d_latent, d_model) * 0.02
        self.W_uv = np.random.randn(d_latent, d_model) * 0.02
        # Query projection
        self.W_q = np.random.randn(d_model, d_model) * 0.02
        # Output projection
        self.W_o = np.random.randn(d_model, d_model) * 0.02

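    def forward(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, seq_len, d_model). No causal mask is applied, so every
        # position attends to every other; add one for autoregressive use.
        batch, seq_len, _ = x.shape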

        Q = x @ self.W_q  # (batch, seq_len, d_model)

        # Compress KV into latent
        c_kv = x @ self.W_dkv  # (batch, seq_len, d_latent) <-- this is cached

        # Reconstruct K, V from latent
        K = c_kv @ self.W_uk  # (batch, seq_len, d_model)
        V = c_kv @ self.W_uv  # (batch, seq_len, d_model)

        # Reshape for multi-head
        Q = Q.reshape(batch, seq_len, self.num_heads, self.d_k).transpose(0, 2, 1, 3)
        K = K.reshape(batch, seq_len, self.num_heads, self.d_k).transpose(0, 2, 1, 3)
        V = V.reshape(batch, seq_len, self.num_heads, self.d_k).transpose(0, 2, 1, 3)

        # Scaled dot-product attention
        scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(self.d_k)
        scores_max = np.max(scores, axis=-1, keepdims=True)
        weights = np.exp(scores - scores_max)
        weights = weights / np.sum(weights, axis=-1, keepdims=True)

        out = np.matmul(weights, V)
        out = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, self.d_model)
        return out @ self.W_o
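
Because K = (x W_dkv) W_uk, MLA during training behaves like ordinary multi-head attention whose key projection is the rank-limited product W_dkv W_uk. The same holds for V with W_uv. The factorization also lets per-head scores be computed directly against the latent, by folding each head's slice of W_uk into the query. DeepSeek-V2 describes absorbing W^UK into the query projection (and W^UV into the output projection) at inference. The NumPy sketch below checks both identities on random weights. The sizes and variable names are illustrative assumptions, not part of the original solution.

```python
import numpy as np

# Illustrative sizes (assumptions, not from the problem statement).
rng = np.random.default_rng(0)
batch, seq_len, d_model, num_heads, d_latent = 2, 5, 64, 4, 8
d_k = d_model // num_heads

x = rng.standard_normal((batch, seq_len, d_model))
W_q = rng.standard_normal((d_model, d_model)) * 0.02
W_dkv = rng.standard_normal((d_model, d_latent)) * 0.02
W_uk = rng.standard_normal((d_latent, d_model)) * 0.02

# 1) Low-rank view: compressing then reconstructing equals a single key
#    projection whose weight W_dkv @ W_uk has rank at most d_latent.
c_kv = x @ W_dkv                                   # (batch, seq_len, d_latent)
K = c_kv @ W_uk                                    # (batch, seq_len, d_model)
K_fused = x @ (W_dkv @ W_uk)
rank = np.linalg.matrix_rank(W_dkv @ W_uk)

# 2) Absorption: for head i, Q_i K_i^T = (Q_i W_uk_i^T) c_kv^T, where W_uk_i
#    is head i's d_k-column slice of W_uk, so scores need no explicit K.
Qh = (x @ W_q).reshape(batch, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)
Kh = K.reshape(batch, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)
scores_explicit = Qh @ Kh.transpose(0, 1, 3, 2)                         # (B, h, n, n)

W_uk_heads = W_uk.reshape(d_latent, num_heads, d_k).transpose(1, 0, 2)  # (h, d_latent, d_k)
Q_latent = Qh @ W_uk_heads.transpose(0, 2, 1)[None]                     # (B, h, n, d_latent)
scores_absorbed = Q_latent @ c_kv.transpose(0, 2, 1)[:, None]           # (B, h, n, n)

print(np.allclose(K, K_fused), rank)                   # True 8
print(np.allclose(scores_explicit, scores_absorbed))   # True
```

The 1/sqrt(d_k) scaling and softmax are left out because they act identically on both score tensors.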

Explanation

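  1. KV compression: project the input into a low-rank latent space (d_latent << d_model) with the down-projection W_dkv. Only this latent vector is cached during autoregressive decoding.
  2. KV reconstruction: up-project the latent with W_uk and W_uv to reconstruct full K and V. This happens on the fly at every attention call, so the cache never stores K or V themselves.
  3. Standard attention: Q, K, and V are split into num_heads heads of size d_k = d_model / num_heads, and scaled dot-product attention is applied per head. The head outputs are concatenated and passed through W_o. The solution applies no causal mask, which autoregressive decoding would require.
  4. Memory: the KV-cache stores only the compressed latent (d_latent values per token per layer) instead of full K and V (2 * d_model values per token per layer). This reduces cache memory by a factor of 2 * d_model / d_latent.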
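
The caching scheme described above can be checked with a small single-sequence decoding loop. The sketch below restates the solution functionally in NumPy, using the same weight shapes and 0.02 scaling. It makes two assumptions the original does not: a causal mask, and hypothetical helper names (init_weights, split_heads, merge_heads, attend, full_causal, decode_step). The cache holds one d_latent vector per token, and K and V are rebuilt from it at every step. The step-by-step outputs should match a full causal forward pass.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def init_weights(d_model, d_latent, seed=0):
    # Same shapes and 0.02 scaling as the solution's __init__.
    rng = np.random.default_rng(seed)
    shapes = {"W_q": (d_model, d_model), "W_dkv": (d_model, d_latent),
              "W_uk": (d_latent, d_model), "W_uv": (d_latent, d_model),
              "W_o": (d_model, d_model)}
    return {name: rng.standard_normal(s) * 0.02 for name, s in shapes.items()}

def split_heads(t, num_heads):
    # (n, d_model) -> (num_heads, n, d_k)
    n, d_model = t.shape
    return t.reshape(n, num_heads, d_model // num_heads).transpose(1, 0, 2)

def merge_heads(t):
    # (num_heads, n, d_k) -> (n, num_heads * d_k)
    h, n, d_k = t.shape
    return t.transpose(1, 0, 2).reshape(n, h * d_k)

def attend(q, c_kv, w, num_heads, mask=None):
    # Rebuild K and V from the latents, then run per-head attention.
    d_k = q.shape[-1] // num_heads
    Q = split_heads(q, num_heads)
    K = split_heads(c_kv @ w["W_uk"], num_heads)
    V = split_heads(c_kv @ w["W_uv"], num_heads)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, -np.inf, scores)   # True = masked out
    return merge_heads(softmax(scores) @ V) @ w["W_o"]

def full_causal(x, w, num_heads):
    # Whole sequence at once; position t may attend to positions 0..t only.
    n = x.shape[0]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    return attend(x @ w["W_q"], x @ w["W_dkv"], w, num_heads, mask=future)

def decode_step(x_t, cache, w, num_heads):
    # Append only the new token's latent; attend over all cached latents.
    cache.append(x_t @ w["W_dkv"])                 # (d_latent,)
    out = attend((x_t @ w["W_q"])[None], np.stack(cache), w, num_heads)
    return out[0]

d_model, num_heads, d_latent, n = 32, 4, 8, 6
w = init_weights(d_model, d_latent)
x = np.random.default_rng(1).standard_normal((n, d_model))

reference = full_causal(x, w, num_heads)
cache = []
steps = np.stack([decode_step(x[t], cache, w, num_heads) for t in range(n)])
print(np.allclose(steps, reference))   # True
print(len(cache), cache[0].shape)      # 6 (8,)
```

Rebuilding K and V from every cached latent at each step trades extra compute for the smaller cache.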

Complexity

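  • Time: O(B n^2 d_model + B n d_model^2). The attention scores and weighted sum cost O(B h n^2 d_k) = O(B n^2 d_model), since h d_k = d_model. The Q and output projections cost O(B n d_model^2). The KV down- and up-projections cost O(B n d_model d_latent), which is dominated because d_latent <= d_model.
  • Space: O(B n d_latent) for the KV-cache, versus O(B n d_model) (2 B n d_model entries) for standard multi-head attention. The forward pass itself still materializes O(B h n^2) attention weights and full K and V of size O(B n d_model).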
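
To make the savings concrete, the arithmetic below compares per-sequence KV-cache sizes for one hypothetical configuration: d_model = 4096, d_latent = 512, 32 layers, 4096 tokens, and 2-byte fp16/bf16 elements. These numbers are assumptions for illustration, not DeepSeek-V2's actual settings.

```python
# Hypothetical configuration, for illustration only.
d_model = 4096
d_latent = 512
num_layers = 32
seq_len = 4096
bytes_per_elem = 2   # fp16 / bf16

def kv_cache_bytes(per_token_elems):
    # Cache size for one sequence across all layers.
    return per_token_elems * num_layers * seq_len * bytes_per_elem

mha = kv_cache_bytes(2 * d_model)   # full K and V per token per layer
mla = kv_cache_bytes(d_latent)      # one latent per token per layer

print(mha / 2**30, "GiB")   # 2.0 GiB
print(mla / 2**20, "MiB")   # 128.0 MiB
print(mha / mla)            # 16.0, i.e. 2 * d_model / d_latent
```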