# Implement Grouped Query Attention (GQA), an interpolation between Multi-Head
# Attention (MHA) and Multi-Query Attention (MQA). Instead of a single KV head
# shared by all query heads (MQA) or one KV head per query head (MHA), GQA uses
# g KV groups, where each group of query heads shares one KV head.
import numpy as np
def grouped_query_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, num_groups: int) -> np.ndarray:
    """Compute grouped-query scaled dot-product attention.

    The ``num_heads`` query heads are split into ``num_groups`` contiguous
    groups of ``num_heads // num_groups`` heads each; every head in group
    ``g`` attends using the single key/value head ``K[:, g]`` / ``V[:, g]``.
    With ``num_groups == 1`` this is Multi-Query Attention (MQA); with
    ``num_groups == num_heads`` it is standard Multi-Head Attention (MHA).

    Args:
        Q: Queries of shape ``(batch, num_heads, seq_len, d_k)``.
        K: Keys of shape ``(batch, num_groups, kv_len, d_k)``.
        V: Values of shape ``(batch, num_groups, kv_len, d_v)``. ``d_v`` is
            usually equal to ``d_k`` but does not have to be.
        num_groups: Number of key/value groups; must be positive and divide
            ``num_heads`` evenly.

    Returns:
        Attention output of shape ``(batch, num_heads, seq_len, d_v)``.

    Raises:
        ValueError: If ``num_groups`` is not positive or does not divide
            ``num_heads``.
    """
    batch, num_heads, seq_len, d_k = Q.shape
    if num_groups <= 0 or num_heads % num_groups != 0:
        raise ValueError(
            f"num_heads ({num_heads}) must be divisible by a positive "
            f"num_groups (got {num_groups})"
        )
    heads_per_group = num_heads // num_groups

    # Group the query heads: (batch, num_groups, heads_per_group, seq_len, d_k).
    q_grouped = Q.reshape(batch, num_groups, heads_per_group, seq_len, d_k)

    # Insert a singleton "heads in group" axis so each KV head broadcasts
    # across its query heads without being copied:
    # (batch, num_groups, 1, kv_len, d).
    k_shared = K[:, :, np.newaxis, :, :]
    v_shared = V[:, :, np.newaxis, :, :]

    # Scaled scores: (batch, num_groups, heads_per_group, seq_len, kv_len).
    scores = np.matmul(q_grouped, np.swapaxes(k_shared, -1, -2)) / np.sqrt(d_k)

    # Numerically stable softmax over the key axis: subtracting the row max
    # leaves the result unchanged but keeps exp() from overflowing.
    scores = scores - np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(scores)
    weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # Weighted sum of values: (batch, num_groups, heads_per_group, seq_len, d_v).
    output = np.matmul(weights, v_shared)

    # Merge the group axes back into a single head axis. Use the output's own
    # feature size rather than d_k so that d_v != d_k is supported.
    return output.reshape(batch, num_heads, seq_len, output.shape[-1])