Implement Multi-Head Attention

#94 · Deep Learning · Hard

Problem

Implement Multi-Head Attention from scratch. Given query, key, and value matrices along with the number of heads, compute the multi-head attention output. This is the core mechanism in Transformer architectures.

Solution

import numpy as np

def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, n_heads: int) -> np.ndarray:
    batch_size, seq_len, d_model = Q.shape
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_k = d_model // n_heads

    def split_heads(x):
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        # The sequence length is read from x so that K and V may differ in length from Q.
        x = x.reshape(batch_size, x.shape[1], n_heads, d_k)
        return x.transpose(0, 2, 1, 3)

    def scaled_dot_product_attention(q, k, v):
        # scores: (batch, heads, seq_q, seq_k)
        scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_k)
        # Numerically stable softmax over the key axis
        weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        weights = weights / np.sum(weights, axis=-1, keepdims=True)
        return weights @ v

    Q_split = split_heads(Q)
    K_split = split_heads(K)
    V_split = split_heads(V)

    attn_output = scaled_dot_product_attention(Q_split, K_split, V_split)

    # Concatenate heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
    attn_output = attn_output.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, d_model)
    return attn_output

Explanation

  1. Split heads: Reshape Q, K, V from (batch, seq, d_model) to (batch, heads, seq, d_k) where d_k = d_model / n_heads.
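  2. Scaled dot-product attention: Compute softmax(Q K^T / sqrt(d_k)) V for each head independently. If the components of q and k are roughly independent with zero mean and unit variance, their dot product has variance d_k, so dividing by sqrt(d_k) keeps the scores at unit scale and stops the softmax from saturating.
  3. Softmax: Use the numerically stable version: subtract the row-wise max before exponentiating. This leaves the result unchanged but avoids overflow in exp.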
  4. Concatenate: Merge all head outputs back into (batch, seq, d_model).
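
The split, softmax, and concatenate steps can be checked in isolation. The sketch below is illustrative only: the helper names split_heads, merge_heads, and stable_softmax are chosen here for the example and are not part of the solution's interface. It shows that head h receives columns h*d_k to (h+1)*d_k of the input, that merging undoes splitting, and that the shifted softmax stays finite for large scores.

```python
import numpy as np

def split_heads(x: np.ndarray, n_heads: int) -> np.ndarray:
    """(batch, seq, d_model) -> (batch, heads, seq, d_k)."""
    batch, seq, d_model = x.shape
    return x.reshape(batch, seq, n_heads, d_model // n_heads).transpose(0, 2, 1, 3)

def merge_heads(x: np.ndarray) -> np.ndarray:
    """(batch, heads, seq, d_k) -> (batch, seq, heads * d_k)."""
    batch, heads, seq, d_k = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch, seq, heads * d_k)

def stable_softmax(scores: np.ndarray) -> np.ndarray:
    """Softmax over the last axis, shifted by the row-wise max."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

x = np.arange(2 * 3 * 8, dtype=float).reshape(2, 3, 8)
heads = split_heads(x, n_heads=4)
print(heads.shape)  # (2, 4, 3, 2)

# Head 1 holds columns 2..3 of the original d_model axis.
print(np.array_equal(heads[:, 1], x[:, :, 2:4]))

# Merging is the exact inverse of splitting.
print(np.array_equal(merge_heads(heads), x))

# A naive exp(1000.0) would overflow; the shifted version stays finite.
big = np.array([[1000.0, 1001.0, 1002.0]])
print(np.isfinite(stable_softmax(big)).all())
```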

Note: In a full Transformer, linear projection matrices W_Q, W_K, W_V, and W_O would be applied. This implementation focuses on the core attention computation.
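
For comparison, here is a minimal sketch of how the projections might wrap the same core computation. It makes several assumptions that are not in the problem statement: the function name mha_with_projections is hypothetical, all four matrices are square (d_model x d_model), biases are omitted, and the weights are random placeholders rather than trained parameters.

```python
import numpy as np

def mha_with_projections(x_q, x_kv, W_q, W_k, W_v, W_o, n_heads):
    """Multi-head attention with input/output projections (no biases, no mask)."""
    batch, seq_q, d_model = x_q.shape
    seq_k = x_kv.shape[1]
    d_k = d_model // n_heads

    def split(x, seq):
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        return x.reshape(batch, seq, n_heads, d_k).transpose(0, 2, 1, 3)

    # Project first, then split the projected features across heads.
    q = split(x_q @ W_q, seq_q)
    k = split(x_kv @ W_k, seq_k)
    v = split(x_kv @ W_v, seq_k)

    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_k)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)

    heads = weights @ v  # (batch, heads, seq_q, d_k)
    concat = heads.transpose(0, 2, 1, 3).reshape(batch, seq_q, d_model)
    return concat @ W_o  # output projection mixes information across heads

rng = np.random.default_rng(0)
d_model, n_heads = 8, 2
# Placeholder weights; a real model would learn these.
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))
x_q = rng.normal(size=(2, 5, d_model))   # 5 query positions
x_kv = rng.normal(size=(2, 7, d_model))  # 7 key/value positions
out = mha_with_projections(x_q, x_kv, W_q, W_k, W_v, W_o, n_heads)
print(out.shape)  # (2, 5, 8)
```

Without W_O, each head's output occupies its own slice of the feature axis and never interacts with the others; the output projection is what combines them.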

Complexity

  • Time: O(batch * heads * seq^2 * d_k), dominated by the attention score computation
  • Space: O(batch * heads * seq^2) for the attention weight matrices
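
To make these bounds concrete, the following back-of-the-envelope sketch counts multiply-adds and stored attention weights. The function attention_cost is a hypothetical helper written for this illustration; it ignores the softmax itself and any constant factors from the hardware.

```python
def attention_cost(batch: int, n_heads: int, seq: int, d_model: int):
    """Return (multiply-adds, attention-weight elements) for self-attention."""
    d_k = d_model // n_heads
    score_macs = batch * n_heads * seq * seq * d_k  # Q @ K^T
    value_macs = batch * n_heads * seq * seq * d_k  # weights @ V
    weight_elems = batch * n_heads * seq * seq      # one seq x seq matrix per head
    return score_macs + value_macs, weight_elems

print(attention_cost(1, 8, 512, 512))   # (268435456, 2097152)
print(attention_cost(1, 8, 1024, 512))  # doubling seq quadruples both numbers
print(attention_cost(1, 1, 512, 512))   # same multiply-adds, 8x fewer weights
```

Because heads * d_k = d_model, the time bound does not depend on the number of heads for a fixed d_model, while the memory for attention weights grows linearly with it.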