Mean Ablation for Circuit Discovery

#236 · Deep Learning · Medium

Problem

Implement mean ablation for circuit discovery in neural networks. Given a model's activations at a particular layer, replace specific neurons' activations with their dataset-level mean values to measure their contribution to the output.
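The problem assumes the per-neuron means are already available. Below is a minimal sketch of how they could be precomputed as column-wise averages over a reference dataset. The helper name `compute_mean_activations` and the `[num_examples, num_neurons]` input layout are assumptions for illustration. How the activations are recorded from a real model is model-specific and not shown.

```python
def compute_mean_activations(dataset_activations: list[list[float]]) -> list[float]:
    """Column-wise mean over a reference dataset (hypothetical helper).

    dataset_activations: [num_examples, num_neurons] activations recorded
    at the layer of interest.
    """
    if not dataset_activations:
        raise ValueError("need at least one example to compute means")
    num_examples = len(dataset_activations)
    num_neurons = len(dataset_activations[0])
    totals = [0.0] * num_neurons
    for row in dataset_activations:
        if len(row) != num_neurons:
            raise ValueError("all rows must have the same number of neurons")
        # Accumulate each neuron's activation across examples.
        for n, value in enumerate(row):
            totals[n] += value
    return [t / num_examples for t in totals]


# Usage: three examples, two neurons.
means = compute_mean_activations([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]])
print(means)  # [2.0, 5.0]
```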

Solution

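def mean_ablation(
    activations: list[list[float]],
    mean_activations: list[float],
    neurons_to_ablate: list[int],
) -> list[list[float]]:
    """
    Replace the selected neurons in every row with their dataset-level means.

    activations: [batch_size, num_neurons] activations at one layer
    mean_activations: [num_neurons] precomputed mean for each neuron
    neurons_to_ablate: indices of the neurons to replace with their means
    """
    # Take the neuron count from the means so an empty batch does not crash.
    num_neurons = len(mean_activations)
    ablate_set = set(neurons_to_ablate)
    for n in ablate_set:
        if not 0 <= n < num_neurons:
            raise IndexError(f"neuron index {n} out of range for {num_neurons} neurons")

    result = []
    for row in activations:
        if len(row) != num_neurons:
            raise ValueError("each activation row must match len(mean_activations)")
        # Ablated neurons take their mean (rounded to 6 decimal places);
        # all other neurons keep their original activation.
        result.append([
            round(mean_activations[n], 6) if n in ablate_set else value
            for n, value in enumerate(row)
        ])

    return result


def compute_ablation_effect(
    original_output: list[float],
    ablated_output: list[float],
) -> float:
    """
    Measure the effect as the L2 (Euclidean) distance between the original and ablated outputs.
    """
    # zip() would silently truncate mismatched outputs, so check lengths first.
    if len(original_output) != len(ablated_output):
        raise ValueError("outputs must have the same length")
    diff_sq = sum((a - b) ** 2 for a, b in zip(original_output, ablated_output))
    return round(diff_sq ** 0.5, 6)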
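The solution replaces ablated neurons with their means rather than with zeros. The toy comparison here is a sketch of why that choice matters. The downstream "model" is a hypothetical single linear readout, `readout`, and the effect is measured as the average absolute change in that readout. Both are stand-ins chosen for illustration, and the means are computed from the same small batch for brevity. Neuron 1 is constant, so it carries no input-specific information. Zero ablation still reports a large effect for it, while mean ablation reports none.

```python
def readout(hidden: list[float], weights: list[float]) -> float:
    # Hypothetical downstream computation: a single linear readout w . h.
    return sum(w * h for w, h in zip(weights, hidden))


def ablate(row: list[float], neurons: set[int], replacement: list[float]) -> list[float]:
    # Swap the selected neurons for the per-neuron replacement values.
    return [replacement[n] if n in neurons else v for n, v in enumerate(row)]


def average_shift(
    batch: list[list[float]],
    weights: list[float],
    neurons: set[int],
    replacement: list[float],
) -> float:
    # Mean absolute change in the readout when `neurons` are replaced.
    total = 0.0
    for row in batch:
        total += abs(readout(row, weights) - readout(ablate(row, neurons, replacement), weights))
    return total / len(batch)


# Neuron 1 is constant across the batch, so it carries no input-specific information.
batch = [[1.0, 3.0], [-1.0, 3.0], [2.0, 3.0]]
weights = [0.5, 2.0]
means = [sum(row[n] for row in batch) / len(batch) for n in range(2)]
zeros = [0.0, 0.0]

print("zero ablation:", average_shift(batch, weights, {1}, zeros))  # 6.0
print("mean ablation:", average_shift(batch, weights, {1}, means))  # 0.0
```

Zero ablation conflates "this neuron carries information" with "this neuron has a nonzero baseline." Mean ablation removes only the variation around the baseline.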

Explanation

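  1. Mean ablation replaces the selected neurons' activations with their dataset-level means. Using the mean rather than zero keeps each ablated neuron at a typical value, so downstream layers still see in-distribution inputs and only the input-specific information carried by that neuron is removed.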
  2. The intuition: if ablating a neuron significantly changes the output, that neuron is important for the computation (part of the "circuit").
  3. Neurons not in the ablation set retain their original activations.
  4. The ablation effect is measured by the L2 distance between the original and ablated model outputs.
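  5. Mean ablation is a standard tool in mechanistic interpretability for identifying which components of a network (for example, neurons or attention heads in a transformer) make up the circuit responsible for a behavior.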
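Points 2 and 4 suggest a simple discovery loop: mean-ablate each neuron on its own, measure the L2 effect on the output, and rank neurons by that effect. The sketch below assumes a hypothetical linear head (`linear_head`) as a stand-in for the rest of the model and averages the effect over the batch. A real analysis would run the actual forward pass from the ablated layer onward, and the averaging choice is an assumption.

```python
import math


def linear_head(hidden: list[float], weight: list[list[float]]) -> list[float]:
    # Hypothetical downstream model: output[j] = sum_n weight[j][n] * hidden[n].
    return [sum(w * h for w, h in zip(w_row, hidden)) for w_row in weight]


def l2_distance(a: list[float], b: list[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def rank_neurons_by_ablation(
    activations: list[list[float]],
    mean_activations: list[float],
    weight: list[list[float]],
) -> list[tuple[int, float]]:
    """Mean-ablate each neuron individually; return (neuron, avg L2 effect), largest first."""
    scores = []
    for n in range(len(mean_activations)):
        total = 0.0
        for row in activations:
            ablated = list(row)
            ablated[n] = mean_activations[n]  # ablate only neuron n
            total += l2_distance(linear_head(row, weight), linear_head(ablated, weight))
        scores.append((n, total / len(activations)))
    return sorted(scores, key=lambda item: item[1], reverse=True)


acts = [[1.0, 0.0, 0.0], [3.0, 0.0, 4.0]]
means = [2.0, 0.0, 2.0]  # column means of `acts`
weight = [[1.0, 5.0, 0.0], [0.0, 0.0, 1.0]]
print(rank_neurons_by_ablation(acts, means, weight))
# [(2, 2.0), (0, 1.0), (1, 0.0)]
```

Neuron 1 has the largest weight but scores zero because its activation never deviates from its mean. The score reflects how much a neuron's variation matters, not the size of its weights.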

Complexity

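  • Time: O(batch_size * num_neurons) for mean_ablation (set membership checks are O(1) on average), plus O(d) for compute_ablation_effect on outputs of length d
  • Space: O(batch_size * num_neurons) for the ablated activations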