Implement Multiquery Attention (MQA)

#390 · Deep Learning · Medium

Problem

Implement Multi-Query Attention (MQA). Unlike standard multi-head attention where each head has its own K and V projections, MQA shares a single K and V head across all query heads, significantly reducing KV-cache memory during inference.
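To make the shapes concrete, here is a minimal sketch of how inputs with this layout might be produced. The weight names (W_q, W_k, W_v), the sizes, and the random initialisation are illustrative assumptions and are not part of the problem statement.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d_model, num_heads = 2, 6, 32, 4   # illustrative sizes
d_k = d_model // num_heads

x = rng.standard_normal((batch, seq_len, d_model))

# Hypothetical projection weights: one per-head block for Q,
# but only a single head's worth of columns for K and V.
W_q = rng.standard_normal((d_model, num_heads * d_k))
W_k = rng.standard_normal((d_model, d_k))
W_v = rng.standard_normal((d_model, d_k))

# Q: (batch, num_heads, seq_len, d_k)
Q = (x @ W_q).reshape(batch, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)
# K, V: (batch, 1, seq_len, d_k), a single shared head
K = (x @ W_k)[:, None, :, :]
V = (x @ W_v)[:, None, :, :]
```

Only K and V would need to be cached during decoding, which is where the memory saving comes from.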

Solution

import numpy as np

def multiquery_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    # Q: (batch, num_heads, seq_len, d_k)
    # K: (batch, 1, seq_len, d_k) - single KV head
    # V: (batch, 1, seq_len, d_k) - single KV head
    d_k = Q.shape[-1]

    # K and V are broadcast across all heads
    # scores: (batch, num_heads, seq_len, seq_len)
    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(d_k)

    # Softmax along last axis
    scores_max = np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(scores - scores_max)
    weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # output: (batch, num_heads, seq_len, d_k)
    output = np.matmul(weights, V)
    return output

Explanation

  1. In MQA, queries have multiple heads but keys and values have a single head.
  2. NumPy broadcasting automatically replicates the single K and V head across all query heads when computing attention scores.
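  3. Scores are scaled by 1/sqrt(d_k) and passed through a softmax over the key axis; subtracting the row maximum first keeps the exponentials numerically stable.
  4. Each query head attends to the same shared K and V, reducing the KV-cache from O(h * n * d) to O(n * d) per sequence, where d is the per-head dimension.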
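The broadcasting claim in step 2 can be checked directly. The sketch below, using arbitrary illustrative sizes, compares scores computed against a single key head with scores computed against an explicitly replicated copy.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_heads, seq_len, d_k = 2, 4, 5, 8   # illustrative sizes

Q = rng.standard_normal((batch, num_heads, seq_len, d_k))
K = rng.standard_normal((batch, 1, seq_len, d_k))   # single shared key head

# Broadcasting: the size-1 head axis of K is stretched to num_heads.
scores_broadcast = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(d_k)

# Explicit replication: materialise one copy of K per query head.
K_repeated = np.repeat(K, num_heads, axis=1)
scores_repeated = np.matmul(Q, K_repeated.transpose(0, 1, 3, 2)) / np.sqrt(d_k)

same_scores = np.allclose(scores_broadcast, scores_repeated)
scores_shape = scores_broadcast.shape
```

The two results should agree, but the broadcast version never allocates the extra copies of K.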

Complexity

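  • Time: O(B * h * n^2 * d), where B is the batch size, h the number of query heads, n the sequence length, and d the per-head dimension
  • Space: O(B * n * d) for the KV-cache (a factor of h smaller than MHA); the intermediate score matrix in this implementation takes O(B * h * n^2)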
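As a rough worked example of the KV-cache saving, the helper below counts cached elements for one attention layer. The function name and the sizes are hypothetical, chosen only for illustration.

```python
def kv_cache_elements(batch: int, seq_len: int, head_dim: int, num_kv_heads: int) -> int:
    """Number of cached elements for one layer (hypothetical helper)."""
    # The factor of 2 accounts for storing both K and V.
    return 2 * batch * num_kv_heads * seq_len * head_dim

# Illustrative sizes: 32 query heads, head dimension 128, 2048 cached tokens.
mha_elements = kv_cache_elements(batch=1, seq_len=2048, head_dim=128, num_kv_heads=32)
mqa_elements = kv_cache_elements(batch=1, seq_len=2048, head_dim=128, num_kv_heads=1)
reduction = mha_elements // mqa_elements   # equals the number of heads
```

With these assumed sizes the MQA cache is 32 times smaller per layer, matching the factor of h stated above.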