Implement Grouped Query Attention (GQA)

#391 · Deep Learning · Medium

Problem

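Implement Grouped Query Attention (GQA), an interpolation between Multi-Head Attention (MHA) and Multi-Query Attention (MQA). MHA gives every query head its own KV head, and MQA shares a single KV head across all query heads. GQA sits in between: it uses g KV heads, and each one is shared by a group of h/g query heads, where h is the number of query heads.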
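
To make the grouping concrete, here is a minimal sketch of which KV head each query head reads from. The helper name `kv_head_for_query_head` is hypothetical, and the sketch assumes query heads are assigned to groups in contiguous blocks, which is the layout a plain reshape of the head axis produces.

```python
def kv_head_for_query_head(num_heads: int, num_groups: int) -> list[int]:
    """Return, for each query head index, the index of the KV head it shares.

    Hypothetical helper for illustration; assumes contiguous grouping.
    """
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    heads_per_group = num_heads // num_groups
    return [head // heads_per_group for head in range(num_heads)]


print(kv_head_for_query_head(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]
print(kv_head_for_query_head(8, 1))  # [0, 0, 0, 0, 0, 0, 0, 0]  (MQA)
print(kv_head_for_query_head(8, 8))  # [0, 1, 2, 3, 4, 5, 6, 7]  (MHA)
```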

Solution

import numpy as np

def grouped_query_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, num_groups: int) -> np.ndarray:
    # Q: (batch, num_heads, seq_len, d_k)
    # K: (batch, num_groups, seq_len, d_k)
    # V: (batch, num_groups, seq_len, d_k)
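    batch, num_heads, seq_len, d_k = Q.shape
    # Each KV head must serve the same number of query heads
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    heads_per_group = num_heads // num_groups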

    # Reshape Q to (batch, num_groups, heads_per_group, seq_len, d_k)
    Q = Q.reshape(batch, num_groups, heads_per_group, seq_len, d_k)

    # Expand K, V to (batch, num_groups, 1, seq_len, d_k) for broadcasting
    K = K[:, :, np.newaxis, :, :]
    V = V[:, :, np.newaxis, :, :]

    # Attention scores: (batch, num_groups, heads_per_group, seq_len, seq_len)
    scores = np.matmul(Q, K.transpose(0, 1, 2, 4, 3)) / np.sqrt(d_k)

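    # Numerically stable softmax over the key axis
    scores_max = np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(scores - scores_max)
    weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # Weighted sum of values: (batch, num_groups, heads_per_group, seq_len, d_k)
    output = np.matmul(weights, V)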
    # Reshape back to (batch, num_heads, seq_len, d_k)
    output = output.reshape(batch, num_heads, seq_len, d_k)
    return output

Explanation

  1. GQA divides query heads into groups. Each group shares one KV head.
  2. Reshape Q so group structure is explicit: (batch, num_groups, heads_per_group, seq_len, d_k).
  3. Add a broadcast dimension to K and V so each KV head is shared across its group's query heads.
  4. Compute standard scaled dot-product attention within each group.
  5. Reshape back to the original multi-head format. GQA with g=1 is MQA; with g=h is standard MHA.

Complexity

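  • Time: O(B · h · n² · d), the same as MHA
  • Space: O(B · g · n · d) for the KV cache, a factor of h/g smaller than MHA

Here B is the batch size, h the number of query heads, g the number of KV groups, n the sequence length, and d the per-head dimension.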
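
The h/g saving can be checked with simple arithmetic. The sketch below counts KV-cache elements for one attention layer; the helper `kv_cache_elements` is hypothetical and the sizes are assumed purely for illustration.

```python
def kv_cache_elements(batch: int, num_kv_heads: int, seq_len: int, head_dim: int) -> int:
    """Stored elements for keys plus values in one layer (hypothetical helper)."""
    return 2 * batch * num_kv_heads * seq_len * head_dim


# Assumed sizes, chosen only for illustration
batch, num_heads, num_groups, seq_len, head_dim = 1, 32, 8, 4096, 128

mha_cache = kv_cache_elements(batch, num_heads, seq_len, head_dim)   # g = h
gqa_cache = kv_cache_elements(batch, num_groups, seq_len, head_dim)  # g = 8
mqa_cache = kv_cache_elements(batch, 1, seq_len, head_dim)           # g = 1

print(mha_cache // gqa_cache)  # 4, i.e. h / g
print(mha_cache // mqa_cache)  # 32, i.e. h / 1
```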