"""Implement Multi-Head Attention from scratch. Given query, key, and value matrices along with the number of heads, compute the multi-head attention output. This is the core mechanism in Transformer architectures."""
import numpy as np
def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, n_heads: int) -> np.ndarray:
    """Compute multi-head scaled dot-product attention without learned projections.

    Steps:
      1. Split heads: reshape each input from (batch, seq, d_model) to
         (batch, heads, seq, d_k) where d_k = d_model / n_heads.
      2. Attend: compute softmax(Q K^T / sqrt(d_k)) V for each head
         independently. The scaling by sqrt(d_k) prevents dot products from
         growing too large.
      3. Concatenate heads: merge the per-head outputs back to
         (batch, seq, d_model).

    Note: In a full Transformer, linear projection matrices W_Q, W_K, W_V,
    and W_O would be applied. This implementation focuses on the core
    attention computation.

    Args:
        Q: Queries of shape (batch, seq_q, d_model).
        K: Keys of shape (batch, seq_k, d_model).
        V: Values of shape (batch, seq_k, d_model).
        n_heads: Number of attention heads; must divide d_model evenly.

    Returns:
        Attention output of shape (batch, seq_q, d_model).

    Raises:
        AssertionError: If d_model is not divisible by n_heads.
    """
    batch_size, seq_len, d_model = Q.shape
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_k = d_model // n_heads

    def split_heads(x):
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        # Use x's own sequence length so keys/values may be longer or shorter
        # than the queries (cross-attention); the last axis must still equal
        # d_model or the reshape raises.
        x = x.reshape(batch_size, x.shape[1], n_heads, d_k)
        return x.transpose(0, 2, 1, 3)

    def scaled_dot_product_attention(q, k, v):
        scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_k)
        # Subtract the row-wise max before exponentiating so the softmax does
        # not overflow for large scores; the result is mathematically unchanged.
        weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        weights = weights / np.sum(weights, axis=-1, keepdims=True)
        return weights @ v

    Q_split = split_heads(Q)
    K_split = split_heads(K)
    V_split = split_heads(V)

    attn_output = scaled_dot_product_attention(Q_split, K_split, V_split)

    # Concatenate heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
    attn_output = attn_output.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, d_model)
    return attn_output