Implement attention sink detection for a transformer model. Given an attention weight matrix, identify "sink" tokens that receive disproportionately high attention from many other tokens. Return the indices of tokens whose average received attention exceeds a given threshold.

def detect_attention_sinks(
    attention_weights: list[list[float]], threshold: float
) -> list[int]:
    n = len(attention_weights)
    # Column j holds the attention that token j receives from every query i.
    avg = [sum(attention_weights[i][j] for i in range(n)) / n for j in range(n)]
    # Keep the indices whose mean received attention exceeds the threshold.
    return [j for j, v in enumerate(avg) if v > threshold]

attention_weights has shape (seq_len, seq_len), where entry (i, j) is how much token i attends to token j.
The attention that token j receives is computed by averaging column j across all query positions (axis 0 of the matrix).
Tokens whose average exceeds threshold are classified as attention sinks. These are typically the first few tokens (like BOS) and delimiter tokens.
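
A small worked example may help make the column averaging concrete. The sketch below is self-contained and uses hypothetical names (received_attention, detect_sinks) and an invented 4-token causal attention matrix; none of these come from the original solution. It expresses the same axis-0 average with zip(*weights), which iterates over columns.

```python
# Hypothetical 4-token causal attention matrix (each row sums to 1).
# The values are invented purely for illustration.
attention = [
    [1.0, 0.0, 0.0, 0.0],
    [0.7, 0.3, 0.0, 0.0],
    [0.6, 0.1, 0.3, 0.0],
    [0.5, 0.1, 0.1, 0.3],
]


def received_attention(weights: list[list[float]]) -> list[float]:
    """Mean of each column: the average attention token j receives."""
    n = len(weights)
    # zip(*weights) transposes the matrix, so each item is one column.
    return [sum(column) / n for column in zip(*weights)]


def detect_sinks(weights: list[list[float]], threshold: float) -> list[int]:
    """Indices whose average received attention exceeds the threshold."""
    return [j for j, v in enumerate(received_attention(weights)) if v > threshold]


avg = received_attention(attention)
sinks = detect_sinks(attention, threshold=0.5)
print(sinks)  # [0]
```

Here the first token receives an average of about 0.7, while the others receive 0.125, 0.1, and 0.075, so only index 0 passes a threshold of 0.5. Because each row sums to 1, the column means always average to 1 / seq_len; a threshold chosen relative to that value may transfer better across sequence lengths than a fixed constant.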