Implement Multi-Query Attention (MQA). Unlike standard multi-head attention, where each head has its own K and V projections, MQA shares a single key/value (K/V) head across all query heads, significantly reducing KV-cache memory during inference.
import numpy as np
def multiquery_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Compute multi-query attention (MQA).

    Every query head attends over one shared key/value head. The attention
    scores ``Q @ K^T`` are scaled by ``1/sqrt(d_k)`` and passed through a
    softmax over the key axis. The resulting weights are then applied to ``V``.

    Args:
        Q: Queries, shape ``(batch, num_heads, q_len, d_k)``.
        K: Shared keys, shape ``(batch, 1, kv_len, d_k)``. The size-1 head
            axis broadcasts against every query head.
        V: Shared values, shape ``(batch, 1, kv_len, d_v)``. ``d_v`` may
            differ from ``d_k``.

    Returns:
        Attention output of shape ``(batch, num_heads, q_len, d_v)``.

    Raises:
        ValueError: If there are no keys (``kv_len == 0``), or if the shapes
            are incompatible for ``np.matmul``.

    Only the last two axes are treated as (sequence, feature). Any leading
    axes follow ``np.matmul`` broadcasting rules, so unbatched inputs such as
    ``Q: (num_heads, q_len, d_k)`` with ``K, V: (1, kv_len, ...)`` also work.
    """
    if K.shape[-2] == 0:
        # Softmax over an empty key axis is undefined; fail with a clear message.
        raise ValueError("multiquery_attention: K must contain at least one key")

    d_k = Q.shape[-1]
    # Swap only the trailing two axes so K^T is correct for any number of
    # leading batch/head axes. matmul then broadcasts the single KV head
    # across all query heads: (..., H, q, d) @ (..., 1, d, kv) -> (..., H, q, kv).
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(d_k)

    # Numerically stable softmax over the key axis. Subtracting the row max
    # does not change the result, but it keeps exp() from overflowing.
    scores = scores - np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(scores)
    weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # (..., H, q, kv) @ (..., 1, kv, d_v) -> (..., H, q, d_v)
    return np.matmul(weights, V)