Inference Head Pruning for Transformers

#233 · Deep Learning · Medium

Problem

Implement inference head pruning for Transformers. Given attention head importance scores, prune (zero out) the least important heads to speed up inference while preserving model quality.

Solution

def prune_attention_heads(
    attention_outputs: list[list[list[float]]],
    head_importance: list[float],
    prune_ratio: float,
) -> list[list[list[float]]]:
    """
    attention_outputs: [num_heads, seq_len, head_dim]
    head_importance: [num_heads] importance scores
    prune_ratio: fraction of heads to prune (0 to 1)
    """
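    num_heads = len(attention_outputs)
    if len(head_importance) != num_heads:
        raise ValueError("head_importance must have one score per head")
    if not 0.0 <= prune_ratio <= 1.0:
        raise ValueError("prune_ratio must be between 0 and 1")

    # Round down, so a ratio that does not divide evenly prunes fewer heads
    num_to_prune = int(num_heads * prune_ratio)

    # Sort heads by importance, ascending; ties keep their original order
    indexed = sorted(enumerate(head_importance), key=lambda x: x[1])
    heads_to_prune = {idx for idx, _ in indexed[:num_to_prune]}

    # Build a new structure so the caller's input is never mutated
    pruned = []
    for h, head in enumerate(attention_outputs):
        if h in heads_to_prune:
            # Zero out a pruned head, keeping its original shape
            pruned.append([[0.0] * len(row) for row in head])
        else:
            pruned.append([row[:] for row in head])

    return pruned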
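
To make the selection step concrete, the sketch below isolates the ranking logic from the solution. The helper name lowest_heads is hypothetical and the scores are made-up illustrative values; it is meant only to show how the rounded-down count and ties behave.

```python
def lowest_heads(head_importance, prune_ratio):
    """Return the indices of the int(H * prune_ratio) least important heads."""
    num_to_prune = int(len(head_importance) * prune_ratio)
    # sorted() is stable, so heads with equal scores keep their index order
    ranked = sorted(enumerate(head_importance), key=lambda x: x[1])
    return {idx for idx, _ in ranked[:num_to_prune]}


# 4 heads, ratio 0.5: int(2.0) = 2 heads, the two lowest scores (heads 1 and 3)
even_split = lowest_heads([0.9, 0.1, 0.5, 0.3], 0.5)

# 3 heads, ratio 0.5: int(1.5) = 1 head, so only head 1 is pruned
rounded_down = lowest_heads([0.9, 0.1, 0.5], 0.5)

# Heads 0 and 1 tie; the stable sort picks the lower index (head 0)
tie_break = lowest_heads([0.2, 0.2, 0.7], 0.34)
```

Because the count is rounded down, a small prune_ratio on a model with few heads may prune nothing at all.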

Explanation

  1. Rank all attention heads by their importance scores (e.g., computed from gradient magnitudes or a Taylor expansion).
  2. Select the int(num_heads * prune_ratio) lowest-scoring heads for pruning; the count is rounded down.
  3. Zero out the outputs of the pruned heads, which removes their contribution to everything downstream.
  4. Copy the remaining heads through unchanged.
  5. Zeroing by itself saves no computation. In practice, the speedup comes from skipping the pruned heads entirely during inference, so their attention is never computed.
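
The last step can be illustrated with a minimal sketch. It assumes the usual multi-head layout, in which head outputs are concatenated and multiplied by an output projection; that projection (w_o here) is not part of the problem statement, and all function names and values below are hypothetical. The point is that zeroing a head gives the same result as dropping the head together with its rows of the projection, and only the second form avoids the work.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def project_zeroed(head_outputs, w_o, pruned):
    """Zero the pruned heads, concatenate every head, apply the full projection."""
    seq_len = len(head_outputs[0])
    concat = []
    for t in range(seq_len):
        row = []
        for h, head in enumerate(head_outputs):
            row.extend([0.0] * len(head[t]) if h in pruned else head[t])
        concat.append(row)
    return matmul(concat, w_o)


def project_skipped(head_outputs, w_o, pruned):
    """Drop the pruned heads and their rows of w_o (assumes one head survives)."""
    head_dim = len(head_outputs[0][0])
    seq_len = len(head_outputs[0])
    kept = [h for h in range(len(head_outputs)) if h not in pruned]
    # Concatenate only the surviving heads for each position
    concat = [[v for h in kept for v in head_outputs[h][t]] for t in range(seq_len)]
    # Head h owns rows h * head_dim .. (h + 1) * head_dim - 1 of w_o
    w_kept = [w_o[h * head_dim + d] for h in kept for d in range(head_dim)]
    return matmul(concat, w_kept)


# 2 heads, seq_len 2, head_dim 2, projecting to a model width of 2
head_outputs = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [7.0, 8.0]],
]
w_o = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]

zeroed = project_zeroed(head_outputs, w_o, pruned={1})
skipped = project_skipped(head_outputs, w_o, pruned={1})
```

Both calls produce the same projected output, but project_skipped multiplies smaller matrices, which is where the inference savings would come from in a real model.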

Complexity

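  • Time: O(H log H) for sorting + O(H · S · D) for zeroing and copying
  • Space: O(H · S · D) for the output
  • Here H is the number of heads, S the sequence length, and D the head dimension.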