---
name: activation-patching
description: Causal intervention via activation patching to identify important model components. Use when determining which layers, heads, or positions are causally responsible for model behavior.
---
# Activation Patching
Activation patching is a causal intervention technique that identifies which model components are responsible for specific behaviors by swapping activations between different inputs.
## Core Concept
- **Clean run**: Run the model on a prompt that produces the desired behavior
- **Corrupted run**: Run on a modified prompt that changes the behavior
- **Patch**: Replace corrupted activations with clean ones and measure whether the behavior is restored
If patching a component restores the clean behavior, that component is causally important.
## Basic Setup
```python
from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

# Indirect Object Identification (IOI) task
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

# Target tokens
correct_token = model.tokenizer(" John")["input_ids"][0]    # Clean answer
incorrect_token = model.tokenizer(" Mary")["input_ids"][0]  # Corrupted answer
```
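Patching swaps activations at matching token positions, so both prompts must tokenize to the same length (here " John" and " Mary" are each a single GPT-2 token). A quick sanity check, as a minimal sketch:

```python
# Sanity check: position-aligned patching requires equal-length prompts
clean_ids = model.tokenizer.encode(clean_prompt)
corrupted_ids = model.tokenizer.encode(corrupted_prompt)
assert len(clean_ids) == len(corrupted_ids), "Prompts must tokenize to the same length"
```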
## Metric: Logit Difference
```python
def logit_diff(logits, correct_idx, incorrect_idx):
    """Measure how much the model prefers the correct over the incorrect token.

    Assumes batch size 1 and scores the final sequence position.
    """
    return (logits[0, -1, correct_idx] - logits[0, -1, incorrect_idx]).item()
```
## Three-Run Patching Pattern
```python
n_layers = len(model.transformer.h)
results = torch.zeros(n_layers)

# Run 1: Clean - save activations
with model.trace(clean_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]
    clean_logits = model.lm_head.output.save()

# Run 2: Corrupted baseline
with model.trace(corrupted_prompt):
    corrupted_logits = model.lm_head.output.save()

# Runs 3+: Patch each layer (separate forward passes)
for layer_idx in range(n_layers):
    with model.trace(corrupted_prompt):
        # Replace the corrupted activation with the clean one
        model.transformer.h[layer_idx].output[0][:] = clean_hiddens[layer_idx]
        patched_logits = model.lm_head.output.save()
    results[layer_idx] = logit_diff(patched_logits.value, correct_token, incorrect_token)

# Normalize results
clean_diff = logit_diff(clean_logits.value, correct_token, incorrect_token)
corrupted_diff = logit_diff(corrupted_logits.value, correct_token, incorrect_token)
normalized = (results - corrupted_diff) / (clean_diff - corrupted_diff)
```
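On this normalized scale, 0 means the patch leaves the corrupted baseline unchanged and 1 means it fully restores clean behavior. A minimal sketch for reporting the most influential layer:

```python
# Report the layer whose patch restores the most clean behavior
best_layer = normalized.argmax().item()
print(f"Layer {best_layer} restores {normalized[best_layer].item():.1%} of clean behavior")
```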
## Position-Specific Patching
Patch only specific token positions:
```python
seq_len = len(model.tokenizer.encode(clean_prompt))
results = torch.zeros(n_layers, seq_len)

# Clean run - save activations
with model.trace(clean_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# Patch each layer x position (separate forward passes)
for layer_idx in range(n_layers):
    for pos_idx in range(seq_len):
        with model.trace(corrupted_prompt):
            # Patch only this position
            model.transformer.h[layer_idx].output[0][:, pos_idx, :] = \
                clean_hiddens[layer_idx][:, pos_idx, :]
            patched_logits = model.lm_head.output.save()
        results[layer_idx, pos_idx] = logit_diff(
            patched_logits.value, correct_token, incorrect_token
        )
```
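The same normalization applies elementwise, assuming `clean_diff` and `corrupted_diff` from the three-run pattern above:

```python
# Broadcast the scalar baselines over the layers x positions matrix
normalized_pos = (results - corrupted_diff) / (clean_diff - corrupted_diff)
```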
## Attention Head Patching
Patch individual attention heads:
```python
n_heads = model.config.n_head
head_dim = model.config.n_embd // n_heads
results = torch.zeros(n_layers, n_heads)

# Clean run - save attention outputs (before the output projection)
with model.trace(clean_prompt):
    clean_attn = [layer.attn.c_proj.input[0][0].save()
                  for layer in model.transformer.h]

# Patch each layer x head (separate forward passes)
for layer_idx in range(n_layers):
    for head_idx in range(n_heads):
        with model.trace(corrupted_prompt):
            # Patch a single head's slice of the concatenated head outputs
            start = head_idx * head_dim
            end = (head_idx + 1) * head_dim
            model.transformer.h[layer_idx].attn.c_proj.input[0][0][:, :, start:end] = \
                clean_attn[layer_idx][:, :, start:end]
            patched_logits = model.lm_head.output.save()
        results[layer_idx, head_idx] = logit_diff(
            patched_logits.value, correct_token, incorrect_token
        )
```
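To surface candidate circuit components, you can rank heads by their patching effect. A minimal sketch:

```python
# Rank heads by the patched logit difference (higher = more behavior restored)
top = torch.topk(results.flatten(), k=5)
for score, flat_idx in zip(top.values, top.indices):
    layer_idx, head_idx = divmod(flat_idx.item(), n_heads)
    print(f"Layer {layer_idx}, Head {head_idx}: logit diff {score.item():.3f}")
```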
## Noising (Reverse Patching)
Instead of restoring clean activations in a corrupted run, inject corrupted activations into a clean run:
```python
# Corrupted run - save activations
with model.trace(corrupted_prompt):
    corrupted_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# For each layer, inject the corrupted activation into a clean run
noising_results = torch.zeros(n_layers)
for layer_idx in range(n_layers):
    with model.trace(clean_prompt):
        model.transformer.h[layer_idx].output[0][:] = corrupted_hiddens[layer_idx]
        noised_logits = model.lm_head.output.save()
    noising_results[layer_idx] = logit_diff(noised_logits.value, correct_token, incorrect_token)
```
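Noising asks the complementary question: patching (denoising) finds components sufficient to restore the behavior, while noising finds components necessary for it. The same normalization applies with the baselines swapped, assuming `clean_diff` and `corrupted_diff` from the three-run pattern:

```python
# 0 = clean behavior intact, 1 = behavior fully destroyed by the injected corruption
noising_normalized = (noising_results - clean_diff) / (corrupted_diff - clean_diff)
```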
## Visualization
```python
import matplotlib.pyplot as plt
import seaborn as sns

# Plot the layers x positions matrix from position-specific patching
plt.figure(figsize=(12, 8))
sns.heatmap(
    results.numpy(),
    xticklabels=[f"Pos {i}" for i in range(seq_len)],
    yticklabels=[f"Layer {i}" for i in range(n_layers)],
    cmap="RdBu_r",
    center=0,
    annot=True,
    fmt=".2f",
)
plt.title("Activation Patching Results")
plt.xlabel("Token Position")
plt.ylabel("Layer")
plt.tight_layout()
plt.show()
```
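For readability, positions can be labeled with the actual tokens instead of indices, using the tokenizer already loaded:

```python
# Decode each position's token to use as an x-axis label
token_labels = [model.tokenizer.decode([t])
                for t in model.tokenizer.encode(clean_prompt)]
# Then pass xticklabels=token_labels to sns.heatmap above
```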
## Interpretation
- **High positive values**: Component is important for correct behavior
- **Values near 0**: Component doesn't affect this behavior
- **Negative values**: Component actively pushes toward the wrong answer
- **Clusters of importance**: Suggest circuits or computational stages