Implement Attention Sink Detection

#457 · Deep Learning · Medium

Problem

Implement attention sink detection for a transformer model. Given an attention weight matrix, identify "sink" tokens that receive disproportionately high attention from many other tokens. Return the indices of tokens whose average received attention exceeds a given threshold.

Solution

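def detect_attention_sinks(
    attention_weights: list[list[float]], threshold: float
) -> list[int]:
    # Row i holds the attention that query token i pays to each token j.
    n = len(attention_weights)
    # Column mean: average attention that token j receives over all n queries.
    avg = [sum(attention_weights[i][j] for i in range(n)) / n for j in range(n)]
    # Keep the indices whose mean received attention is strictly above threshold.
    return [j for j, v in enumerate(avg) if v > threshold]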

Explanation

  1. attention_weights has shape (seq_len, seq_len); entry (i, j) is how much query token i attends to token j.
  2. Compute the mean attention each token j receives by averaging column j over all query positions i (axis 0 of the matrix).
  3. Tokens whose mean received attention is strictly greater than threshold are classified as attention sinks. In practice these are often the first few tokens (such as BOS) and delimiter tokens.
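
To make the three steps concrete, here is a small worked sketch on a toy matrix. The matrix values and the 0.5 threshold are made up for illustration, and the column means are computed with zip(*attention), which is just an equivalent way of averaging over axis 0.

```python
# Toy 4x4 attention matrix (values are invented for illustration).
# Row i is the attention distribution of query token i, so each row sums to 1.
attention = [
    [1.0, 0.0, 0.0, 0.0],
    [0.7, 0.3, 0.0, 0.0],
    [0.6, 0.1, 0.3, 0.0],
    [0.5, 0.1, 0.1, 0.3],
]
threshold = 0.5  # hypothetical cutoff chosen for this example

n = len(attention)

# Step 2: zip(*attention) transposes the matrix, so each tuple is one column,
# i.e. all the attention that a single token receives.
received = [sum(column) / n for column in zip(*attention)]

# Step 3: keep the indices whose mean received attention is above the cutoff.
sinks = [j for j, v in enumerate(received) if v > threshold]

print(received)
print(sinks)
```

Token 0 receives a mean attention of about 0.7, while the other three tokens stay at or below about 0.125, so only index 0 is reported as a sink.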

Complexity

  • Time: O(n^2) where n is the sequence length
  • Space: O(n) for the average-attention vector
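
The two bounds can be read off directly from a single-pass variant that accumulates column sums row by row. This is only a sketch: the name detect_attention_sinks_streaming is hypothetical, and it assumes every row has the same length.

```python
from typing import Iterable, Sequence


def detect_attention_sinks_streaming(
    rows: Iterable[Sequence[float]], threshold: float
) -> list[int]:
    """Single-pass variant of the solution; the function name is hypothetical."""
    column_sums: list[float] = []  # the only buffer kept around: O(n) space
    num_rows = 0
    for row in rows:  # n rows ...
        if not column_sums:
            column_sums = [0.0] * len(row)
        for j, weight in enumerate(row):  # ... times n additions: O(n^2) time
            column_sums[j] += weight
        num_rows += 1
    if num_rows == 0:
        return []  # no queries, so no token can be a sink
    return [j for j, s in enumerate(column_sums) if s / num_rows > threshold]
```

Because rows are consumed one at a time, this version also accepts a generator, so the full matrix never has to be held in memory by the detector itself.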