# Implement the forward pass of a Multi-Head Cross-Attention (mHC) layer. Given queries from one sequence and keys/values from another, compute scaled dot-product attention independently in each head, then concatenate the head outputs and project them back to d_model with W_O.
import numpy as np
def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Return the softmax of ``x`` normalized along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so that
    large inputs cannot overflow ``np.exp``; the shift cancels out in the
    final ratio, so the result is unchanged mathematically.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    return weights / weights.sum(axis=axis, keepdims=True)
def multi_head_cross_attention(
    Q: np.ndarray,  # (..., seq_q, d_model)
    K: np.ndarray,  # (..., seq_k, d_model)
    V: np.ndarray,  # (..., seq_k, d_model)
    W_Q: np.ndarray,  # (num_heads, d_model, d_k)
    W_K: np.ndarray,  # (num_heads, d_model, d_k)
    W_V: np.ndarray,  # (num_heads, d_model, d_v)
    W_O: np.ndarray,  # (num_heads * d_v, d_model)
    num_heads: int,
) -> np.ndarray:
    """Run the forward pass of multi-head cross-attention.

    Queries are projected from ``Q`` and keys/values from ``K``/``V`` (which
    may belong to a different sequence than ``Q``). Each head computes scaled
    dot-product attention independently; the per-head outputs are
    concatenated along the last axis in head order and projected with
    ``W_O``.

    Leading (batch) dimensions on ``Q``, ``K`` and ``V`` are carried through;
    plain 2-D inputs behave exactly as before.

    Args:
        Q: Query-side inputs, shape ``(..., seq_q, d_model)``.
        K: Key-side inputs, shape ``(..., seq_k, d_model)``.
        V: Value-side inputs, shape ``(..., seq_k, d_model)``.
        W_Q: Per-head query projections, ``(num_heads, d_model, d_k)``.
        W_K: Per-head key projections, ``(num_heads, d_model, d_k)``.
        W_V: Per-head value projections, ``(num_heads, d_model, d_v)``.
        W_O: Output projection, ``(num_heads * d_v, d_model)``.
        num_heads: Number of heads; heads ``0 .. num_heads - 1`` of the
            projection tensors are used.

    Returns:
        Attention output of shape ``(..., seq_q, d_model)``.

    Raises:
        ValueError: If ``num_heads < 1``, or if the concatenated head
            outputs do not match ``W_O`` (raised by the final matmul).
    """
    if num_heads < 1:
        raise ValueError(f"num_heads must be >= 1, got {num_heads}")

    d_k = W_Q.shape[-1]
    scale = np.sqrt(d_k)  # loop-invariant; hoisted out of the head loop

    head_outputs = []
    for h in range(num_heads):
        Q_h = Q @ W_Q[h]  # (..., seq_q, d_k)
        K_h = K @ W_K[h]  # (..., seq_k, d_k)
        V_h = V @ W_V[h]  # (..., seq_k, d_v)
        # swapaxes(-1, -2) transposes only the last two axes. ``.T`` reverses
        # every axis, which is correct only for 2-D arrays and breaks
        # batched inputs.
        scores = Q_h @ K_h.swapaxes(-1, -2) / scale  # (..., seq_q, seq_k)
        attn_weights = softmax(scores, axis=-1)
        head_outputs.append(attn_weights @ V_h)  # (..., seq_q, d_v)

    concat = np.concatenate(head_outputs, axis=-1)  # (..., seq_q, num_heads * d_v)
    return concat @ W_O  # (..., seq_q, d_model)