Implement Masked Self-Attention

#107 · Deep Learning · Medium

Problem

Implement Masked Self-Attention (causal attention). Given query, key, and value matrices, compute scaled dot-product attention with a causal mask that prevents attending to future positions. This is the core component of decoder-only Transformers like GPT.

Solution

import numpy as np

def masked_self_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    d_k = Q.shape[-1]

    # Compute attention scores
    scores = Q @ K.T / np.sqrt(d_k)

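    # Causal mask: True strictly above the diagonal, i.e. where key position > query position
    seq_len = scores.shape[0]
    mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    # Use a large negative value (not -inf); masked weights underflow to 0 after softmax
    scores = np.where(mask, -1e9, scores)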

    # Softmax
    scores_shifted = scores - np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(scores_shifted)
    attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # Weighted sum of values
    output = attention_weights @ V
    return output
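
One way to sanity-check the causal property is to perturb a future token and confirm that earlier outputs do not move. The sketch below repeats the same computation in compact form so that it runs on its own; the sizes, the seed, and the perturbation values are arbitrary choices for illustration, not part of the problem statement.

```python
import numpy as np

def masked_self_attention(Q, K, V):
    # Same computation as the solution, repeated so this snippet is self-contained
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    n = scores.shape[0]
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(mask, -1e9, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)   # seed chosen arbitrarily
n, d = 5, 4                      # illustrative sizes
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))
out = masked_self_attention(Q, K, V)

# Position 0 can only attend to itself, so its output should equal V[0]
first_row_is_v0 = np.allclose(out[0], V[0])

# Perturb the last token's key and value; outputs for earlier positions should not change
K2, V2 = K.copy(), V.copy()
K2[-1] += 10.0
V2[-1] -= 10.0
out2 = masked_self_attention(Q, K2, V2)
prefix_unchanged = np.allclose(out[:-1], out2[:-1])

print(first_row_is_v0, prefix_unchanged)
```

Both checks should print True: the last key only enters the final column of the score matrix, which is masked for every row except the last one.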

Explanation

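  1. Scaled dot-product: Compute Q K^T / sqrt(d_k) to get raw attention scores. Dot products grow in magnitude with d_k (for independent, unit-variance entries their variance is d_k), so dividing by sqrt(d_k) keeps the scores in a range where softmax does not saturate.
  2. Causal mask: Set the entries strictly above the diagonal (where query position < key position) to a very large negative number (-1e9 in the code). After softmax, these weights become effectively 0, preventing attention to future tokens.
  3. Numerically stable softmax: Subtract the row-wise max before exponentiating to prevent overflow. The shift cancels in the normalization, so the result is unchanged.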
  4. Output: Multiply attention weights by values. Each position's output is a weighted combination of values from current and previous positions only.
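
The steps above can be traced on a tiny input. This is a minimal sketch, assuming all-zero queries and keys (so every raw score is 0) and an arbitrary sequence length of 4; with equal scores, the mask alone determines the attention weights.

```python
import numpy as np

n, d_k = 4, 8                      # illustrative sizes, not from the problem statement
Q = np.zeros((n, d_k))             # all-zero queries and keys give identical raw scores
K = np.zeros((n, d_k))

# Step 1: scaled dot-product (every entry is 0 here)
scores = Q @ K.T / np.sqrt(d_k)

# Step 2: causal mask, True marks future positions (strictly above the diagonal)
mask = np.triu(np.ones((n, n), dtype=bool), k=1)
scores = np.where(mask, -1e9, scores)

# Step 3: numerically stable softmax over each row
shifted = scores - scores.max(axis=-1, keepdims=True)
weights = np.exp(shifted)
weights /= weights.sum(axis=-1, keepdims=True)

print(mask.astype(int))
print(weights)
```

Row i of `weights` spreads its mass evenly over positions 0 through i (each entry is 1/(i+1)) and is 0 elsewhere, so the matrix is lower-triangular and each row sums to 1.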

Complexity

  • Time: O(n^2 * d), where n is the sequence length and d is the feature dimension of the queries, keys, and values
  • Space: O(n^2) for the attention score matrix
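
To make the quadratic growth concrete, the counts behind these bounds can be tallied directly. `attention_cost` is a hypothetical helper written for this illustration, and it assumes Q, K, and V all share the same feature dimension d.

```python
def attention_cost(n: int, d: int) -> dict:
    """Rough operation and memory counts for one attention call (hypothetical helper)."""
    return {
        "qk_multiply_adds": n * n * d,   # Q @ K.T: n*n dot products of length d
        "av_multiply_adds": n * n * d,   # weights @ V: n rows, each mixing n vectors of length d
        "score_entries": n * n,          # size of the score / attention-weight matrix
    }

small = attention_cost(1024, 64)
large = attention_cost(2048, 64)

# Doubling the sequence length quadruples the score-matrix size
ratio = large["score_entries"] / small["score_entries"]
print(small["score_entries"], large["score_entries"], ratio)
```

Under these assumptions, doubling n from 1024 to 2048 grows the score matrix from 1,048,576 to 4,194,304 entries, a factor of 4, while the dependence on d stays linear.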