# Implement inference-time attention-head pruning for Transformers: given
# per-head importance scores, prune (zero out) the least important heads so
# inference can skip their computation while preserving model quality.
def prune_attention_heads(
    attention_outputs: list[list[list[float]]],
    head_importance: list[float],
    prune_ratio: float,
) -> list[list[list[float]]]:
    """Zero out the least important attention heads.

    Args:
        attention_outputs: Per-head outputs indexed as
            ``[num_heads][seq_len][head_dim]``.
        head_importance: One importance score per head; lower scores are
            pruned first.
        prune_ratio: Fraction of heads to prune, in ``[0, 1]``. The number of
            pruned heads is ``floor(num_heads * prune_ratio)``.

    Returns:
        A new nested list with the same shape as ``attention_outputs``. Pruned
        heads are filled with ``0.0``; kept heads are row-wise copies of the
        input. The input is never modified. Among heads with equal scores,
        the one with the lower index is pruned first.

    Raises:
        ValueError: If ``prune_ratio`` is outside ``[0, 1]`` (or is NaN), or if
            ``head_importance`` does not contain exactly one score per head.
    """
    # `not a <= x <= b` also rejects NaN. A negative ratio would otherwise
    # turn the slice below into `[:-k]` and prune almost every head.
    if not 0.0 <= prune_ratio <= 1.0:
        raise ValueError(f"prune_ratio must be in [0, 1], got {prune_ratio!r}")

    num_heads = len(attention_outputs)
    if len(head_importance) != num_heads:
        raise ValueError(
            f"head_importance has {len(head_importance)} scores but "
            f"attention_outputs has {num_heads} heads"
        )

    # The tiny epsilon absorbs float error such as 100 * 0.29 == 28.999...,
    # which a plain int() would floor to 28 instead of 29.
    num_to_prune = int(num_heads * prune_ratio + 1e-9)

    # sorted() is stable, so ties keep index order (lower index pruned first).
    ranked = sorted(range(num_heads), key=lambda h: head_importance[h])
    heads_to_prune = set(ranked[:num_to_prune])

    # Each zeroed head takes its shape from its own rows. This handles empty
    # input and seq_len == 0 without indexing into head 0.
    return [
        [[0.0] * len(row) for row in head]
        if h in heads_to_prune
        else [row[:] for row in head]
        for h, head in enumerate(attention_outputs)
    ]