Implement mHC Forward Pass

#298 · Deep Learning · Medium

⊣ Solve on deep-ml.com

Problem

Implement the forward pass of a Multi-Head Cross-Attention (mHC) layer. Given queries from one sequence and keys and values from another, compute scaled dot-product attention independently in each head, then concatenate the head outputs and apply the output projection.
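
Written out as a formula, the computation could look like the following. The notation is an assumption chosen to mirror the arguments used in the solution: H stands for num_heads, and W^Q_h, W^K_h, W^V_h are the per-head slices W_Q[h], W_K[h], W_V[h].

```latex
\[
\begin{aligned}
\mathrm{head}_h &= \mathrm{softmax}\!\left(
    \frac{(Q W^Q_h)\,(K W^K_h)^{\top}}{\sqrt{d_k}}
  \right) V W^V_h, \qquad h = 1, \dots, H, \\
\mathrm{output} &= \bigl[\, \mathrm{head}_1 \;\; \mathrm{head}_2 \;\; \cdots \;\; \mathrm{head}_H \,\bigr]\, W^O .
\end{aligned}
\]
```

The softmax is applied row-wise, so each query position gets a probability distribution over the key positions.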

Solution

import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x_max = np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(x - x_max)
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

def multi_head_cross_attention(
    Q: np.ndarray,   # (seq_q, d_model)
    K: np.ndarray,   # (seq_k, d_model)
    V: np.ndarray,   # (seq_k, d_model)
    W_Q: np.ndarray, # (num_heads, d_model, d_k)
    W_K: np.ndarray, # (num_heads, d_model, d_k)
    W_V: np.ndarray, # (num_heads, d_model, d_v)
    W_O: np.ndarray, # (num_heads * d_v, d_model)
    num_heads: int
) -> np.ndarray:
    d_k = W_Q.shape[-1]
    head_outputs = []

    for h in range(num_heads):
        Q_h = Q @ W_Q[h]       # (seq_q, d_k)
        K_h = K @ W_K[h]       # (seq_k, d_k)
        V_h = V @ W_V[h]       # (seq_k, d_v)

        scores = Q_h @ K_h.T / np.sqrt(d_k)  # (seq_q, seq_k)
        attn_weights = softmax(scores, axis=-1)
        head_out = attn_weights @ V_h          # (seq_q, d_v)
        head_outputs.append(head_out)

    concat = np.concatenate(head_outputs, axis=-1)  # (seq_q, num_heads * d_v)
    output = concat @ W_O  # (seq_q, d_model)
    return output

Explanation

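  1. Project Q, K, and V through separate learned weight matrices for each head (W_Q[h], W_K[h], W_V[h]). Each head works with representations of width d_k or d_v, which are typically smaller than d_model.
  2. Apply scaled dot-product attention per head: compute scores = (Q_h @ K_h^T) / sqrt(d_k), take the softmax over the key axis so each query's weights sum to 1, then use those weights to form a weighted sum of the rows of V_h.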
  3. Cross-attention differs from self-attention in that Q comes from one sequence while K and V come from another (e.g., decoder queries attending to encoder outputs).
  4. Concatenate all head outputs and project through W_O to get the final output.
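
The four steps above can also be sketched without the Python loop over heads. This is an illustrative variant, not the reference solution: the name `cross_attention_vectorized` and the decision to return the attention weights alongside the output are assumptions made for this example. It relies only on `np.einsum`, `transpose`, and `reshape`.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max along the axis for numerical stability.
    x = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(x)
    return e_x / np.sum(e_x, axis=axis, keepdims=True)


def cross_attention_vectorized(Q, K, V, W_Q, W_K, W_V, W_O):
    """Hypothetical batched-over-heads variant of the four steps."""
    d_k = W_Q.shape[-1]
    seq_q = Q.shape[0]

    # Step 1: project for all heads at once -> (num_heads, seq, d_k or d_v)
    Q_h = np.einsum("qm,hmd->hqd", Q, W_Q)
    K_h = np.einsum("km,hmd->hkd", K, W_K)
    V_h = np.einsum("km,hmd->hkd", V, W_V)

    # Step 2: scaled scores and softmax over the key axis
    scores = np.einsum("hqd,hkd->hqk", Q_h, K_h) / np.sqrt(d_k)
    weights = softmax(scores, axis=-1)               # (num_heads, seq_q, seq_k)
    heads = np.einsum("hqk,hkd->hqd", weights, V_h)  # (num_heads, seq_q, d_v)

    # Step 4: put heads side by side along the feature axis, then project
    concat = heads.transpose(1, 0, 2).reshape(seq_q, -1)  # (seq_q, num_heads * d_v)
    return concat @ W_O, weights


# Step 3 in practice: queries and keys/values come from sequences of different length.
rng = np.random.default_rng(0)
seq_q, seq_k, d_model, d_k, d_v, num_heads = 3, 5, 8, 4, 2, 2
Q = rng.normal(size=(seq_q, d_model))
K = rng.normal(size=(seq_k, d_model))
V = rng.normal(size=(seq_k, d_model))
W_Q = rng.normal(size=(num_heads, d_model, d_k))
W_K = rng.normal(size=(num_heads, d_model, d_k))
W_V = rng.normal(size=(num_heads, d_model, d_v))
W_O = rng.normal(size=(num_heads * d_v, d_model))

out, weights = cross_attention_vectorized(Q, K, V, W_Q, W_K, W_V, W_O)
print(out.shape)      # (3, 8): one d_model-sized row per query position
print(weights.shape)  # (2, 3, 5): one (seq_q, seq_k) matrix per head
```

The trade-off is memory: this variant holds every head's attention matrix at once, whereas a per-head loop only needs one at a time.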

Complexity

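  • Time: O(num_heads * seq_q * seq_k * (d_k + d_v)) for the attention itself, plus O(num_heads * d_model * (seq_q * d_k + seq_k * (d_k + d_v))) for the per-head projections and O(seq_q * num_heads * d_v * d_model) for the output projection.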
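  • Space: O(seq_q * seq_k) for the attention weights of the head currently being processed, since the loop handles one head at a time, plus O(seq_q * num_heads * d_v) for the stored head outputs. Keeping every head's weights in memory at once, as a batched implementation would, takes O(num_heads * seq_q * seq_k).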
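
To make these terms concrete, the sketch below counts multiply-add operations for each stage of the forward pass, including the projections. The helper name `attention_cost` and the example shapes are hypothetical; the counts come from the standard cost of multiplying an (a, b) matrix by a (b, c) matrix, which is a * b * c multiply-adds.

```python
def attention_cost(seq_q: int, seq_k: int, d_model: int,
                   d_k: int, d_v: int, num_heads: int) -> dict:
    """Count multiply-adds per stage (hypothetical helper for illustration)."""
    # Q @ W_Q[h], K @ W_K[h], V @ W_V[h], repeated for every head
    projections = num_heads * d_model * (seq_q * d_k + seq_k * d_k + seq_k * d_v)
    # Q_h @ K_h.T costs seq_q * seq_k * d_k; attn_weights @ V_h costs seq_q * seq_k * d_v
    attention = num_heads * seq_q * seq_k * (d_k + d_v)
    # concat @ W_O
    output = seq_q * (num_heads * d_v) * d_model
    return {"projections": projections, "attention": attention, "output": output}


cost = attention_cost(seq_q=4, seq_k=6, d_model=8, d_k=4, d_v=4, num_heads=2)
print(cost)  # {'projections': 1024, 'attention': 384, 'output': 256}
```

For short sequences like this one the projections dominate; the attention term takes over as seq_q * seq_k grows relative to d_model.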