# Mean ablation for circuit discovery in neural networks: given a model's
# activations at a particular layer, replace selected neurons' activations with
# their dataset-level mean values, then measure how much the output changes to
# estimate those neurons' contribution.
def mean_ablation(
    activations: list[list[float]],
    mean_activations: list[float],
    neurons_to_ablate: list[int],
) -> list[list[float]]:
    """Replace selected neurons' activations with their precomputed means.

    Builds a new activation matrix in which every column listed in
    ``neurons_to_ablate`` is overwritten by that neuron's value from
    ``mean_activations``. All other entries are copied unchanged. The input
    lists are not modified.

    Args:
        activations: Activation matrix, ``[batch_size, num_neurons]``.
        mean_activations: Precomputed mean for each neuron, ``[num_neurons]``.
        neurons_to_ablate: Column indices to replace with their means.
            Indices are matched against column positions ``0 .. len(row) - 1``.
            Negative or out-of-range indices match no column and are ignored.

    Returns:
        A new list of rows with the same shape as ``activations``. Ablated
        entries are rounded to 6 decimal places; untouched entries are copied
        as-is. An empty ``activations`` yields ``[]``.

    Raises:
        IndexError: if a column selected for ablation has no corresponding
            entry in ``mean_activations``.
    """
    # A set gives O(1) membership checks inside the per-element loop.
    ablate_set = set(neurons_to_ablate)
    # Walk each row by its own length. An empty batch then returns [] instead
    # of failing on activations[0], and a wider row is never silently
    # truncated to the width of the first row.
    return [
        [
            round(mean_activations[n], 6) if n in ablate_set else value
            for n, value in enumerate(row)
        ]
        for row in activations
    ]
def compute_ablation_effect(
    original_output: list[float],
    ablated_output: list[float],
) -> float:
    """Measure the ablation effect as the L2 distance between two outputs.

    Args:
        original_output: Model output without ablation.
        ablated_output: Model output with ablation applied. Must have the same
            length as ``original_output``.

    Returns:
        Euclidean distance between the two vectors, rounded to 6 decimal
        places. Two empty outputs give ``0.0``.

    Raises:
        ValueError: if the two outputs have different lengths.
    """
    import math

    # zip() would silently drop the tail of the longer vector and under-report
    # the effect, so mismatched outputs are rejected explicitly.
    if len(original_output) != len(ablated_output):
        raise ValueError(
            "original_output and ablated_output must have the same length "
            f"(got {len(original_output)} and {len(ablated_output)})"
        )
    # math.dist computes sqrt(sum((a - b) ** 2)) with better protection
    # against intermediate overflow/underflow than a hand-rolled sum.
    return round(math.dist(original_output, ablated_output), 6)