---
name: attribution-patching
description: Gradient-based approximation to activation patching for scalable circuit analysis. Use when activation patching is too slow or when analyzing many components simultaneously.
---
# Attribution Patching
Attribution patching uses gradients to approximate activation patching: instead of one patched forward pass per component, a single backward pass yields scores for thousands of components at once.
## Core Idea
Instead of running separate forward passes for each component:
- Run clean and corrupted forward passes
- Compute gradients of the metric w.r.t. corrupted activations
- Multiply gradients by (clean - corrupted) activation differences

This first-order (linear) approximation is accurate when the clean and corrupted runs are close; it degrades when the patch would push the model through strongly nonlinear behavior.
## Mathematical Formula
```
attribution(component) ≈ grad_corrupted(metric) * (clean_activation - corrupted_activation)
```

where the product is summed over the elements of the component's activation.
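This is the first-order Taylor expansion of the metric $m$ around the corrupted run. Writing $x$ for a component's activation:

$$
m(x_{\text{clean}}) \approx m(x_{\text{corr}}) + \nabla_x m(x_{\text{corr}})^{\top}\,(x_{\text{clean}} - x_{\text{corr}})
$$

The second term is the attribution score: the predicted change in the metric if that component's corrupted activation were swapped for its clean value.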
## Setup
```python
from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

# Single-token ids for the two candidate completions
correct_token = model.tokenizer(" John")["input_ids"][0]
incorrect_token = model.tokenizer(" Mary")["input_ids"][0]

def logit_diff(logits):
    return logits[0, -1, correct_token] - logits[0, -1, incorrect_token]
```
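As an optional sanity check, the metric should come out positive on the clean prompt and negative on the corrupted one. A quick sketch using the same tracing API:

```python
with model.trace(clean_prompt):
    clean_diff = logit_diff(model.lm_head.output).save()

with model.trace(corrupted_prompt):
    corrupted_diff = logit_diff(model.lm_head.output).save()

print(f"clean logit diff: {clean_diff.value.item():+.3f}")          # expect > 0
print(f"corrupted logit diff: {corrupted_diff.value.item():+.3f}")  # expect < 0
```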
## Basic Attribution Patching
```python
n_layers = len(model.transformer.h)

clean_acts = []
corrupted_acts = []
corrupted_grads = []

# Clean forward pass - save activations
with model.trace(clean_prompt):
    for layer in model.transformer.h:
        act = layer.output[0]
        clean_acts.append(act.save())

# Corrupted forward + backward pass
with model.trace(corrupted_prompt):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        act = layer.output[0]
        act.requires_grad = True
        corrupted_acts.append(act.save())

    # Compute metric
    logits = model.lm_head.output
    metric = logit_diff(logits)

    # Access gradients in REVERSE order within the backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            corrupted_grads.insert(0, layer.output[0].grad.save())

# Compute attributions
attributions = []
for i in range(n_layers):
    clean = clean_acts[i].value
    corrupted = corrupted_acts[i].value
    grad = corrupted_grads[i].value
    # Attribution = grad * (clean - corrupted)
    attr = (grad * (clean - corrupted)).sum()
    attributions.append(attr.item())

attributions = torch.tensor(attributions)
```
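To surface the layers the approximation flags as most influential, a short usage sketch:

```python
# Rank layers by the magnitude of their predicted effect
for idx in attributions.abs().argsort(descending=True)[:5]:
    i = idx.item()
    print(f"layer {i}: attribution = {attributions[i]:+.4f}")
```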
## Per-Position Attribution
```python
seq_len = clean_acts[0].value.shape[1]
position_attrs = torch.zeros(n_layers, seq_len)

for layer_idx in range(n_layers):
    clean = clean_acts[layer_idx].value
    corrupted = corrupted_acts[layer_idx].value
    grad = corrupted_grads[layer_idx].value
    # Sum over hidden dimension only, keep position
    diff = clean - corrupted
    attr = (grad * diff).sum(dim=-1).squeeze()  # [seq_len]
    position_attrs[layer_idx] = attr
```
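A heatmap makes the layer-by-position structure easy to scan. A plotting sketch, assuming matplotlib is available (for GPT-2 the tokenized prompt length matches `seq_len`, so the tokens can label the x-axis):

```python
import matplotlib.pyplot as plt

tokens = model.tokenizer.tokenize(clean_prompt)

plt.figure(figsize=(10, 4))
plt.imshow(position_attrs, aspect="auto", cmap="RdBu")
plt.colorbar(label="Attribution")
plt.xticks(range(seq_len), tokens, rotation=90)
plt.ylabel("Layer")
plt.title("Per-position attribution")
plt.tight_layout()
plt.show()
```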
## Attention Head Attribution
```python
from einops import rearrange

n_heads = model.config.n_head
head_dim = model.config.n_embd // n_heads
head_attrs = torch.zeros(n_layers, n_heads)

# Collect clean attention outputs
clean_attn = []
with model.trace(clean_prompt):
    for layer in model.transformer.h:
        attn_out = layer.attn.c_proj.input[0][0]  # Before the output projection
        clean_attn.append(attn_out.save())

# Collect corrupted attention outputs and gradients
corrupted_attn = []
attn_grads = []
with model.trace(corrupted_prompt):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        attn_out = layer.attn.c_proj.input[0][0]
        attn_out.requires_grad = True
        corrupted_attn.append(attn_out.save())

    metric = logit_diff(model.lm_head.output)

    # Access gradients in REVERSE order within the backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            attn_grads.insert(0, layer.attn.c_proj.input[0][0].grad.save())

# Compute per-head attributions
for layer_idx in range(n_layers):
    clean = clean_attn[layer_idx].value
    corrupted = corrupted_attn[layer_idx].value
    grad = attn_grads[layer_idx].value

    # Reshape to [batch, seq, heads, head_dim]
    clean_heads = rearrange(clean, 'b s (h d) -> b s h d', h=n_heads)
    corrupted_heads = rearrange(corrupted, 'b s (h d) -> b s h d', h=n_heads)
    grad_heads = rearrange(grad, 'b s (h d) -> b s h d', h=n_heads)

    # Attribution per head
    diff = clean_heads - corrupted_heads
    attr = (grad_heads * diff).sum(dim=(0, 1, 3))  # Sum batch, seq, head_dim
    head_attrs[layer_idx] = attr
```
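To pull out candidate heads for follow-up validation, a usage sketch:

```python
# Report the five heads with the largest predicted effect
values, flat_indices = head_attrs.abs().flatten().topk(5)
for flat_idx in flat_indices:
    layer, head = divmod(flat_idx.item(), n_heads)
    print(f"L{layer}H{head}: attribution = {head_attrs[layer, head]:+.4f}")
```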
## Efficient Batched Version
Process both prompts in a single forward pass using batching:
```python
# Batch both prompts together in a single trace
all_acts = []
all_grads = []

with model.trace([clean_prompt, corrupted_prompt]):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        act = layer.output[0]
        act.requires_grad = True
        all_acts.append(act.save())

    logits = model.lm_head.output
    # Metric on the corrupted example (batch index 1)
    metric = logit_diff(logits[1:2])

    # Access gradients in REVERSE order within the backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            all_grads.insert(0, layer.output[0].grad.save())

# Split clean/corrupted and compute attributions
attributions = []
for i in range(n_layers):
    acts = all_acts[i].value
    grads = all_grads[i].value
    clean = acts[0:1]
    corrupted = acts[1:2]
    grad = grads[1:2]  # The metric depends only on batch index 1, so this is the only nonzero gradient
    attr = (grad * (clean - corrupted)).sum()
    attributions.append(attr.item())
```
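One caveat: batching the two prompts into one trace assumes they tokenize to the same length, which holds here because only a single name differs. A cheap guard:

```python
clean_len = len(model.tokenizer(clean_prompt)["input_ids"])
corrupted_len = len(model.tokenizer(corrupted_prompt)["input_ids"])
assert clean_len == corrupted_len, "Prompts must align token-for-token for batched patching"
```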
## Comparison with Activation Patching
| Aspect | Activation Patching | Attribution Patching |
|---|---|---|
| Accuracy | Exact | First-order approximation |
| Speed | O(n_components) forward passes | O(1) forward + backward pass |
| Memory | Lower per run | Higher (stores gradients) |
| Best for | A few specific components | Many components at once |
## Validation
Compare attribution scores against ground-truth activation patching before trusting them:
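First compute the ground truth. A minimal sketch of per-layer activation patching, reusing `clean_acts` from the basic example; the effect is measured as the change in the metric relative to the corrupted baseline, which is exactly what the attribution score approximates:

```python
# Baseline metric on the unpatched corrupted run
with model.trace(corrupted_prompt):
    corrupted_metric = logit_diff(model.lm_head.output).save()

actual_patching_results = []
for layer_idx in range(n_layers):
    with model.trace(corrupted_prompt):
        # Overwrite this layer's output with the saved clean activation
        model.transformer.h[layer_idx].output[0][:] = clean_acts[layer_idx].value
        patched_metric = logit_diff(model.lm_head.output).save()
    actual_patching_results.append(
        patched_metric.value.item() - corrupted_metric.value.item()
    )

actual_patching_results = torch.tensor(actual_patching_results)
```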
```python
# Scatter plot: attribution vs actual patching effect
import matplotlib.pyplot as plt

plt.scatter(attributions, actual_patching_results)
plt.xlabel("Attribution Score")
plt.ylabel("Actual Patching Effect")
plt.title("Attribution vs Patching Correlation")

correlation = torch.corrcoef(torch.stack([attributions, actual_patching_results]))[0, 1]
plt.text(0.1, 0.9, f"r = {correlation:.3f}", transform=plt.gca().transAxes)
plt.show()
```
## When to Use

- Use attribution patching: initial exploration, many components, large models
- Use activation patching: validating specific components, when exact measurements are needed
- Combine both: attribution for screening, activation patching for confirmation (see the sketch below)
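A combined workflow might look like the following sketch, screening with the attribution scores from above and confirming the top candidates with exact patches:

```python
# Confirm the three highest-scoring layers with real activation patching
for layer_idx in attributions.abs().argsort(descending=True)[:3].tolist():
    with model.trace(corrupted_prompt):
        model.transformer.h[layer_idx].output[0][:] = clean_acts[layer_idx].value
        confirmed = logit_diff(model.lm_head.output).save()
    print(f"layer {layer_idx}: predicted change {attributions[layer_idx]:+.3f}, "
          f"patched metric {confirmed.value.item():+.3f}")
```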