diff --git a/tests/integration/model_bridge/test_direct_logit_attribution.py b/tests/integration/model_bridge/test_direct_logit_attribution.py index 5f7c941d7..5c175737e 100644 --- a/tests/integration/model_bridge/test_direct_logit_attribution.py +++ b/tests/integration/model_bridge/test_direct_logit_attribution.py @@ -1,102 +1,167 @@ -"""Integration tests for the Direct Logit Attribution tool on a TransformerBridge. +"""Integration tests for the Direct Logit Attribution tool. -DLA decomposes the residual-stream part of a logit (or logit difference) via the -unembedding direction ``W_U[:, token]``. The unembedding bias ``b_U`` is a -per-token constant that no component produces, so a complete decomposition -reconstructs ``logit - b_U`` (and, for a difference, ``logit_diff - (b_U[c] - -b_U[w])``), not the raw logit. We assert that invariant against the bridge's own -forward-pass logits — the point of issue #1263 is that DLA must be correct on a -``TransformerBridge`` (compatibility mode), not only ``HookedTransformer``. +DLA decomposes the part of a logit that comes from the residual stream via the +unembedding *direction* ``W_U[:, token]``. The unembedding *bias* ``b_U`` is a +per-token constant that no component produces, so the exact correctness +invariant is:: -Uses distilgpt2 (CI-cached), matching test_analysis_methods.py. + sum(component DLA for token) + b_U[token] == logit[token] + +and, for a difference of two tokens, the two bias terms do **not** generally +cancel (gpt2's folded ``ln_final`` bias makes them differ), so:: + + sum(component DLA, correct vs incorrect) == logit_diff - (b_U[c] - b_U[i]) + +We assert these for ``HookedTransformer`` and for ``TransformerBridge`` +(compatibility mode) — the latter is the reason issue #1263 exists. + +These tests load gpt2 (cached), so they live in ``integration/`` per +``tests/AGENTS.md``. 
""" import pytest -import torch - -from transformer_lens.tools.analysis import dla PROMPT = "The Eiffel Tower is in the city of" CORRECT = " Paris" -WRONG = " London" - - -def _last_token_logits(bridge, prompt): - """Final-position logits from the bridge's own forward pass.""" - logits, _ = bridge.run_with_cache(prompt) - if logits.ndim == 3: # [batch, pos, vocab] - logits = logits[0] - return logits[-1] # [vocab] - - -class TestDirectLogitAttributionCorrectness: - """The component contributions must reconstruct the model's real logits.""" - - def test_decompose_reconstructs_logit_difference(self, distilgpt2_bridge_compat): - bridge = distilgpt2_bridge_compat - c, w = bridge.to_single_token(CORRECT), bridge.to_single_token(WRONG) - scores, labels = dla(bridge, [PROMPT], torch.tensor([[c, w]])) - - logits = _last_token_logits(bridge, PROMPT) - expected = (logits[c] - logits[w]) - (bridge.b_U[c] - bridge.b_U[w]) - assert scores.sum().item() == pytest.approx(expected.item(), abs=1e-2) - assert len(scores) == len(labels) - - def test_decompose_reconstructs_single_token_logit(self, distilgpt2_bridge_compat): - bridge = distilgpt2_bridge_compat - c = bridge.to_single_token(CORRECT) - scores, _ = dla(bridge, [PROMPT], torch.tensor([[c]])) - - logits = _last_token_logits(bridge, PROMPT) - expected = logits[c] - bridge.b_U[c] - assert scores.sum().item() == pytest.approx(expected.item(), abs=1e-2) - - def test_accumulated_last_entry_reconstructs(self, distilgpt2_bridge_compat): - # accumulated_resid is cumulative -> reconstruction is the LAST entry, not the sum - bridge = distilgpt2_bridge_compat - c, w = bridge.to_single_token(CORRECT), bridge.to_single_token(WRONG) - scores, labels = dla(bridge, [PROMPT], torch.tensor([[c, w]]), accumulated=True) - - logits = _last_token_logits(bridge, PROMPT) - expected = (logits[c] - logits[w]) - (bridge.b_U[c] - bridge.b_U[w]) - assert scores[-1].item() == pytest.approx(expected.item(), abs=1e-2) - assert len(scores) == len(labels) - - 
-class TestDirectLogitAttributionShape: - def test_decompose_labels_and_shape(self, distilgpt2_bridge_compat): - bridge = distilgpt2_bridge_compat - c = bridge.to_single_token(CORRECT) - scores, labels = dla(bridge, [PROMPT], torch.tensor([[c]])) - - n_layers = bridge.cfg.n_layers - assert scores.ndim == 1 - assert len(scores) == len(labels) - assert "embed" in labels - assert sum(label.endswith("attn_out") for label in labels) == n_layers - assert sum(label.endswith("mlp_out") for label in labels) == n_layers - - def test_batch_of_prompts_is_averaged(self, distilgpt2_bridge_compat): - # two identical prompts -> the batch-mean equals the single-prompt result - bridge = distilgpt2_bridge_compat - c, w = bridge.to_single_token(CORRECT), bridge.to_single_token(WRONG) - single, _ = dla(bridge, [PROMPT], torch.tensor([[c, w]])) - doubled, _ = dla(bridge, [PROMPT, PROMPT], torch.tensor([[c, w], [c, w]])) - assert torch.allclose(single, doubled, atol=1e-4) - - -class TestDirectLogitAttributionGuardsOnRealBridge: - """The guards must also fire on a genuine bridge, not just a mock.""" - - def test_non_compat_bridge_raises(self, distilgpt2_bridge): - bridge = distilgpt2_bridge # compatibility mode NOT enabled - c = bridge.to_single_token(CORRECT) +INCORRECT = " London" + + +def _refs(model): + """Reference values: (logit_correct, logit_incorrect, b_U[c], b_U[i]).""" + logits = model(PROMPT) + if logits.ndim == 2: # some Bridge configs may drop the batch dim + logits = logits[None] + c = model.to_single_token(CORRECT) + i = model.to_single_token(INCORRECT) + return ( + logits[0, -1, c].item(), + logits[0, -1, i].item(), + model.b_U[c].item(), + model.b_U[i].item(), + ) + + +def _assert_complete_decomposition(model, unit): + """sum(DLA) reconstructs the logit / logit-diff up to the b_U constant.""" + from transformer_lens.tools.analysis import direct_logit_attribution + + logit_c, logit_i, bu_c, bu_i = _refs(model) + + diff = direct_logit_attribution( + model, PROMPT, 
answer_tokens=CORRECT, incorrect_tokens=INCORRECT, unit=unit + ) + single = direct_logit_attribution(model, PROMPT, answer_tokens=CORRECT, unit=unit) + + # accumulated_resid ("layer") is cumulative: the last entry is the full + # residual stream, so it (not the column sum) is the reconstruction. + diff_total = diff.attribution[-1].sum() if unit == "layer" else diff.attribution.sum() + single_total = single.attribution[-1].sum() if unit == "layer" else single.attribution.sum() + + assert diff_total.item() == pytest.approx((logit_c - logit_i) - (bu_c - bu_i), abs=1e-2) + assert single_total.item() == pytest.approx(logit_c - bu_c, abs=1e-2) + + +@pytest.fixture(scope="module") +def gpt2_ht(): + from transformer_lens import HookedTransformer + + return HookedTransformer.from_pretrained("gpt2", device="cpu") + + +class TestDirectLogitAttributionHooked: + """Correctness on HookedTransformer (the reference numerics).""" + + @pytest.mark.parametrize("unit", ["component", "layer", "head"]) + def test_decomposition_reconstructs_logit(self, gpt2_ht, unit): + _assert_complete_decomposition(gpt2_ht, unit) + + def test_component_labels_and_shape(self, gpt2_ht): + from transformer_lens.tools.analysis import direct_logit_attribution + + res = direct_logit_attribution( + gpt2_ht, PROMPT, answer_tokens=CORRECT, incorrect_tokens=INCORRECT, unit="component" + ) + assert res.unit == "component" + assert res.attribution.shape[0] == len(res.labels) + # Embedding term(s) plus each layer's attn_out and mlp_out. 
+ assert "embed" in res.labels + assert sum(label.endswith("_attn_out") for label in res.labels) == gpt2_ht.cfg.n_layers + assert sum(label.endswith("_mlp_out") for label in res.labels) == gpt2_ht.cfg.n_layers + + def test_head_labels_include_remainder(self, gpt2_ht): + from transformer_lens.tools.analysis import direct_logit_attribution + + res = direct_logit_attribution(gpt2_ht, PROMPT, answer_tokens=CORRECT, unit="head") + assert len(res.labels) == gpt2_ht.cfg.n_layers * gpt2_ht.cfg.n_heads + 1 + assert res.labels[-1] == "remainder" + + def test_reuses_precomputed_cache(self, gpt2_ht): + from transformer_lens.tools.analysis import direct_logit_attribution + + logit_c, _, bu_c, _ = _refs(gpt2_ht) + _, cache = gpt2_ht.run_with_cache(PROMPT) + res = direct_logit_attribution(gpt2_ht, answer_tokens=CORRECT, cache=cache) + assert res.attribution.sum().item() == pytest.approx(logit_c - bu_c, abs=1e-2) + + def test_pos_none_keeps_position_axis(self, gpt2_ht): + from transformer_lens.tools.analysis import direct_logit_attribution + + n_tokens = gpt2_ht.to_tokens(PROMPT).shape[1] + res = direct_logit_attribution( + gpt2_ht, PROMPT, answer_tokens=CORRECT, unit="component", pos=None + ) + assert res.attribution.ndim == 3 # [component, batch, pos] + assert res.attribution.shape[-1] == n_tokens + + def test_top_returns_sorted_pairs(self, gpt2_ht): + from transformer_lens.tools.analysis import direct_logit_attribution + + res = direct_logit_attribution( + gpt2_ht, PROMPT, answer_tokens=CORRECT, incorrect_tokens=INCORRECT, unit="head" + ) + top = res.top(3) + assert len(top) == 3 + values = [v for _, v in top] + assert values == sorted(values, reverse=True) + + +class TestDirectLogitAttributionBridge: + """The point of #1263: DLA must work on TransformerBridge.""" + + @pytest.mark.parametrize("unit", ["component", "head"]) + def test_decomposition_reconstructs_logit(self, gpt2_bridge_compat, unit): + _assert_complete_decomposition(gpt2_bridge_compat, unit) + + +class 
TestDirectLogitAttributionBridgeGuards: + """Guards required for correctness on TransformerBridge (from PR #1316). + + Without compatibility mode the projection direction is wrong on a Bridge and + DLA would silently return incorrect numbers — verify the explicit refusal. + """ + + def test_non_compat_bridge_raises(self, gpt2_bridge): + from transformer_lens.tools.analysis import direct_logit_attribution + with pytest.raises(ValueError, match="compatibility mode"): - dla(bridge, [PROMPT], torch.tensor([[c]])) - - def test_hybrid_layer_types_raises(self, distilgpt2_bridge_compat, monkeypatch): - bridge = distilgpt2_bridge_compat - monkeypatch.setattr(bridge, "layer_types", lambda: ["attn+mlp", "mamba+mlp"]) - c = bridge.to_single_token(CORRECT) - with pytest.raises(NotImplementedError, match="hybrid"): - dla(bridge, [PROMPT], torch.tensor([[c]])) + direct_logit_attribution(gpt2_bridge, PROMPT, answer_tokens=CORRECT) + + +class TestDirectLogitAttributionValidation: + def test_invalid_unit_raises(self, gpt2_ht): + from transformer_lens.tools.analysis import direct_logit_attribution + + with pytest.raises(ValueError, match="unit must be one of"): + direct_logit_attribution(gpt2_ht, PROMPT, answer_tokens=CORRECT, unit="neuron") + + def test_missing_answer_tokens_raises(self, gpt2_ht): + from transformer_lens.tools.analysis import direct_logit_attribution + + with pytest.raises(ValueError, match="answer_tokens is required"): + direct_logit_attribution(gpt2_ht, PROMPT) + + def test_missing_input_and_cache_raises(self, gpt2_ht): + from transformer_lens.tools.analysis import direct_logit_attribution + + with pytest.raises(ValueError, match="either `input`"): + direct_logit_attribution(gpt2_ht, answer_tokens=CORRECT) diff --git a/tests/unit/tools/test_direct_logit_attribution.py b/tests/unit/tools/test_direct_logit_attribution.py index 44332ddf5..2b77805e7 100644 --- a/tests/unit/tools/test_direct_logit_attribution.py +++ b/tests/unit/tools/test_direct_logit_attribution.py 
@@ -1,53 +1,53 @@ -"""Unit tests for the Direct Logit Attribution guards and argument validation. +"""Unit tests for Direct Logit Attribution guards and argument validation. -These exercise the fast-failing checks (argument validation, the +These exercise the fast-failing checks (argument validation, the Bridge compatibility-mode requirement, and the hybrid-architecture refusal) without -loading a real model, using a ``spec``-ed mock bridge so the checks are reached -before any forward pass. +loading a real model — using a ``spec``-ed mock TransformerBridge so the checks +fire before any forward pass. """ from unittest.mock import MagicMock import pytest -import torch from transformer_lens.model_bridge import TransformerBridge -from transformer_lens.tools.analysis import dla +from transformer_lens.tools.analysis import direct_logit_attribution def _mock_bridge(compatibility_mode=True, layer_types=("attn+mlp",)): - """A mock TransformerBridge that satisfies isinstance/type checks.""" + """A mock TransformerBridge that satisfies isinstance checks.""" bridge = MagicMock(spec=TransformerBridge) bridge.compatibility_mode = compatibility_mode bridge.layer_types.return_value = list(layer_types) return bridge -def test_prompt_answer_length_mismatch_raises(): +def test_invalid_unit_raises(): bridge = _mock_bridge() - with pytest.raises(ValueError, match="matching row"): - dla(bridge, ["a", "b"], torch.tensor([[1]])) + with pytest.raises(ValueError, match="unit must be one of"): + direct_logit_attribution(bridge, "hi", answer_tokens=" world", unit="neuron") -def test_invalid_answer_columns_raises(): +def test_missing_answer_tokens_raises(): bridge = _mock_bridge() - with pytest.raises(ValueError, match="columns"): - dla(bridge, ["a"], torch.tensor([[1, 2, 3]])) + with pytest.raises(ValueError, match="answer_tokens is required"): + direct_logit_attribution(bridge, "hi") def test_requires_compatibility_mode(): bridge = _mock_bridge(compatibility_mode=False) with 
pytest.raises(ValueError, match="compatibility mode"): - dla(bridge, ["a"], torch.tensor([[1]])) + direct_logit_attribution(bridge, "hi", answer_tokens=" world") def test_rejects_hybrid_architecture(): bridge = _mock_bridge(layer_types=("attn+mlp", "mamba+mlp")) with pytest.raises(NotImplementedError, match="hybrid"): - dla(bridge, ["a"], torch.tensor([[1]])) + direct_logit_attribution(bridge, "hi", answer_tokens=" world") -def test_dla_is_exported_from_analysis_package(): +def test_direct_logit_attribution_is_exported_from_analysis_package(): from transformer_lens.tools import analysis - assert analysis.dla is dla + assert analysis.direct_logit_attribution is direct_logit_attribution + assert hasattr(analysis, "DirectLogitAttribution") diff --git a/transformer_lens/tools/__init__.py b/transformer_lens/tools/__init__.py index ea31a6af7..2b78fa84e 100644 --- a/transformer_lens/tools/__init__.py +++ b/transformer_lens/tools/__init__.py @@ -4,7 +4,7 @@ including the model registry for discovering compatible HuggingFace models. Subpackages: - - analysis: Interpretability analyses such as Direct Logit Attribution + - analysis: High-level interpretability analyses (e.g. Direct Logit Attribution) - model_registry: Tools for discovering and documenting supported models """ diff --git a/transformer_lens/tools/analysis/__init__.py b/transformer_lens/tools/analysis/__init__.py index 137ec42e3..a9cc5d19e 100644 --- a/transformer_lens/tools/analysis/__init__.py +++ b/transformer_lens/tools/analysis/__init__.py @@ -1,5 +1,17 @@ -"""Analysis tools for the TransformerBridge system.""" +"""Analysis tools for TransformerLens. -from transformer_lens.tools.analysis.direct_logit_attribution import dla +This subpackage collects high-level, single-call interpretability analyses that +sit on top of the hook/cache system. They work with both ``HookedTransformer`` +and the newer ``TransformerBridge`` (the two share the ``ActivationCache`` API). 
-__all__ = ["dla"] +Tools: + - direct_logit_attribution: Direct Logit Attribution (DLA) over components, + layers, or attention heads. +""" + +from transformer_lens.tools.analysis.direct_logit_attribution import ( + DirectLogitAttribution, + direct_logit_attribution, +) + +__all__ = ["DirectLogitAttribution", "direct_logit_attribution"] diff --git a/transformer_lens/tools/analysis/direct_logit_attribution.py b/transformer_lens/tools/analysis/direct_logit_attribution.py index b5814d15b..96c31d5e9 100644 --- a/transformer_lens/tools/analysis/direct_logit_attribution.py +++ b/transformer_lens/tools/analysis/direct_logit_attribution.py @@ -1,177 +1,264 @@ -"""Direct Logit Attribution. +"""Direct Logit Attribution (DLA). -Decomposes a model's logit difference into per-component contributions from -the residual stream. See :func:`dla` for usage. +Direct Logit Attribution decomposes a model's output logit (or a logit +*difference* between a correct and an incorrect token) into the additive +contributions of upstream components — the embedding, each attention and MLP +sublayer, or each individual attention head. Because the unembedding is linear +and the residual stream is a sum of component outputs, the final logit is +(up to the final LayerNorm) a sum of per-component dot products with the +unembedding direction of the token of interest. DLA reads off those dot +products. See the `logit lens +<https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`_ +and `Interpretability in the Wild <https://arxiv.org/abs/2211.00593>`_ for the +canonical uses. + +This module exposes a single entry point, :func:`direct_logit_attribution`, +that wraps the lower-level ``ActivationCache`` primitives +(:meth:`~transformer_lens.ActivationCache.ActivationCache.decompose_resid`, +:meth:`~transformer_lens.ActivationCache.ActivationCache.accumulated_resid`, +:meth:`~transformer_lens.ActivationCache.ActivationCache.stack_head_results` +and :meth:`~transformer_lens.ActivationCache.ActivationCache.logit_attrs`) into +one call. 
It works unchanged with both ``HookedTransformer`` and +``TransformerBridge`` because they share the cache API. + +Example:: + + from transformer_lens import HookedTransformer + from transformer_lens.tools.analysis import direct_logit_attribution + + model = HookedTransformer.from_pretrained("gpt2", device="cpu") + result = direct_logit_attribution( + model, + "The Eiffel Tower is in the city of", + answer_tokens=" Paris", + incorrect_tokens=" London", + unit="component", + ) + for label, value in zip(result.labels, result.attribution.squeeze()): + print(f"{label:>12}: {value.item():+.3f}") """ -from typing import List, Tuple +from dataclasses import dataclass +from typing import List, Optional, Union -import einops import torch -from jaxtyping import Float, Int +from jaxtyping import Float from transformer_lens.ActivationCache import ActivationCache -from transformer_lens.model_bridge import TransformerBridge -from transformer_lens.utilities import get_act_name +from transformer_lens.utilities import SliceInput + +# Token-like inputs accepted for the correct/incorrect answers, mirroring +# ActivationCache.logit_attrs. +TokenInput = Union[ + str, + int, + torch.Tensor, +] + +# Which structural unit the residual stream is decomposed into. +Unit = str # one of: "component", "layer", "head" + +_VALID_UNITS = ("component", "layer", "head") # Block variants that lack the attn_out + mlp_out structure decompose_resid expects. +# When TransformerBridge.layer_types() reports any of these we refuse early — the +# downstream decompose_resid would otherwise raise a confusing KeyError. _HYBRID_VARIANT_NAMES = ("mamba", "ssm", "mixer", "linear_attn") -def dla( - bridge: TransformerBridge, - prompts: List[str], - answer_tokens: Int[torch.Tensor, "batch answers"], - accumulated: bool = False, -) -> Tuple[Float[torch.Tensor, "component"], List[str]]: - """Compute Direct Logit Attribution for a TransformerBridge model. 
- - Decomposes the logit (or logit difference between a correct and wrong token) - into per-component contributions from the residual stream, averaged across - the batch of prompts. Two modes: +@dataclass +class DirectLogitAttribution: + """Result of a :func:`direct_logit_attribution` call. - * ``accumulated=False`` (default): the contribution of each individual - component — embedding, per-layer attention output, per-layer MLP output — - via :meth:`transformer_lens.ActivationCache.decompose_resid`. These are - additive: their sum reconstructs the (bias-excluded) logit difference. - * ``accumulated=True``: the cumulative residual stream at each layer boundary - (logit-lens style) via - :meth:`transformer_lens.ActivationCache.accumulated_resid`. These are - cumulative, so the full reconstruction is the *last* entry, not the sum. + Attributes: + attribution: + Tensor of logit (or logit-difference) attributions with shape + ``[component, *batch_and_pos]``. The leading axis is aligned with + ``labels``. When ``pos`` selects a single position (the default) the + position axis is dropped, leaving ``[component, batch]`` — or + ``[component]`` if the cache had its batch dimension removed. + labels: + Human-readable name for each component, aligned with the leading + axis of ``attribution`` (e.g. ``"embed"``, ``"0_attn_out"``, + ``"L3H7"``). + unit: + The decomposition unit used ("component", "layer", or "head"). + """ - Warning: + attribution: Float[torch.Tensor, "component *batch_and_pos"] + labels: List[str] + unit: Unit - Returned scores sum (decompose mode) or end (accumulated mode) at - ``actual_logit_diff - (b_U[correct] - b_U[wrong])``, not ``actual_logit_diff`` - directly. The unembedding bias :math:`b_U` is a constant offset added after - unembedding rather than a residual-stream contribution, so it is excluded by - convention — matching :meth:`transformer_lens.ActivationCache.logit_attrs`. 
+ def top(self, k: int = 5) -> List[tuple]: + """Return the ``k`` highest-attribution ``(label, value)`` pairs. + + Attribution is reduced to a scalar per component by averaging over any + remaining batch/position dimensions, so this is most meaningful when a + single position was selected. + """ + flat = self.attribution + if flat.ndim > 1: + flat = flat.flatten(start_dim=1).mean(dim=-1) + values, indices = torch.topk(flat, min(k, flat.shape[0])) + return [(self.labels[i], values[j].item()) for j, i in enumerate(indices.tolist())] - This function requires the bridge to be in compatibility mode so the final - LayerNorm weights are folded into :math:`W_U`. Without folding, the - projection direction is wrong and the scores do not reflect actual logit - contributions. Call ``bridge.enable_compatibility_mode()`` after loading. - Warning: +def _residual_stack_and_labels( + cache: ActivationCache, + unit: Unit, + pos_slice: SliceInput, +): + """Decompose the residual stream into ``unit`` components plus labels. - Hybrid architectures (Mamba, SSM, Mixer, LinearAttention) are not yet - supported and raise :class:`NotImplementedError`; support requires extending - :meth:`transformer_lens.ActivationCache.decompose_resid` to handle - non-attention blocks. + LayerNorm is intentionally *not* applied here — ``logit_attrs`` applies the + final-layer scaling itself, so applying it twice would double-count. + """ + if unit == "component": + # embed (+ pos_embed) and each layer's attn_out / mlp_out. + return cache.decompose_resid(apply_ln=False, pos_slice=pos_slice, return_labels=True) + if unit == "layer": + # Cumulative residual stream after each sublayer — logit-lens style. + return cache.accumulated_resid( + apply_ln=False, incl_mid=True, pos_slice=pos_slice, return_labels=True + ) + if unit == "head": + # Each attention head's contribution, plus the MLP/embedding remainder. 
+ return cache.stack_head_results( + apply_ln=False, pos_slice=pos_slice, incl_remainder=True, return_labels=True + ) + raise ValueError(f"unit must be one of {_VALID_UNITS}, got {unit!r}") - Args: - bridge: - A :class:`transformer_lens.model_bridge.TransformerBridge` with - compatibility mode enabled. - prompts: - Prompt strings to evaluate. Length must match ``answer_tokens.shape[0]``. - answer_tokens: - Tensor of shape ``(batch, answers)``. ``answers == 1`` decomposes the - single target token's logit; ``answers == 2`` treats each row as a - ``(correct, wrong)`` pair and decomposes the logit difference. - accumulated: - If ``True``, return cumulative per-layer contributions (logit lens). - If ``False`` (default), return per-component contributions. - Returns: - ``(scores, labels)``: ``scores`` is a 1D tensor of per-component (or - per-layer) contributions averaged across the batch, and ``labels`` is the - matching list of human-readable component names. +def _validate_bridge_compatibility(model) -> None: + """Reject Bridge inputs that DLA can't produce correct numbers for. - Raises: - ValueError: If ``len(prompts)`` does not match ``answer_tokens.shape[0]``, - if ``answer_tokens.shape[1]`` is not ``1`` or ``2``, or if the bridge - is not in compatibility mode. - NotImplementedError: If the bridge contains hybrid (non attention + MLP) - blocks. + HookedTransformer always has LN folded into W_U, so these checks only fire + for TransformerBridge. The compatibility-mode check catches a silent- + correctness footgun: without folded LN, the projection direction in + ``logit_attrs`` is wrong on a Bridge. The hybrid-arch check catches Mamba/ + SSM blocks early with a clear error rather than letting ``decompose_resid`` + raise a confusing KeyError downstream. """ + # Lazy import — keeps the module importable without dragging in the bridge. 
+ from transformer_lens.model_bridge import TransformerBridge - # input validation - if len(prompts) != answer_tokens.shape[0]: - raise ValueError( - "Each prompt needs a matching row of answer tokens: got " - f"{len(prompts)} prompts but {answer_tokens.shape[0]} answer-token rows." - ) - if answer_tokens.shape[1] not in (1, 2): - raise ValueError( - "answer_tokens must have 1 (single token) or 2 (correct, wrong) columns, " - f"got {answer_tokens.shape[1]}." - ) + if not isinstance(model, TransformerBridge): + return - # safeguard: DLA needs LayerNorm folded into W_U, which compatibility mode does - if not getattr(bridge, "compatibility_mode", False): + if not getattr(model, "compatibility_mode", False): raise ValueError( - "DLA requires the bridge to be in compatibility mode so that LayerNorm " - "weights are folded into W_U. Call `bridge.enable_compatibility_mode()` " + "DLA on a TransformerBridge requires compatibility mode so that LayerNorm " + "weights are folded into W_U. Call `model.enable_compatibility_mode()` " "after loading the bridge, then re-run DLA." ) - # safeguard: hybrid blocks (Mamba/SSM/...) have no attn_out + mlp_out to decompose - hybrid_blocks = [ - layer_type - for layer_type in bridge.layer_types() - if any(part in _HYBRID_VARIANT_NAMES for part in layer_type.split("+")) - ] - if hybrid_blocks: + layer_types = model.layer_types() + hybrid = [lt for lt in layer_types if any(p in _HYBRID_VARIANT_NAMES for p in lt.split("+"))] + if hybrid: raise NotImplementedError( - f"DLA does not yet support hybrid architectures (found block types " - f"{hybrid_blocks}). Only standard attention + MLP transformers (e.g. " - f"GPT-2, LLaMA, Pythia) are supported; hybrid support requires extending " - f"ActivationCache.decompose_resid." + f"DLA does not yet support hybrid architectures (found block types {hybrid}). " + f"Only standard attention + MLP transformers (e.g. 
GPT-2, LLaMA, Pythia) are " + f"supported; hybrid support requires extending ActivationCache.decompose_resid." ) - # unembedding direction per prompt; single-token tensors collapse to [d_model] - answer_residual_directions = bridge.tokens_to_residual_directions(answer_tokens).reshape( - answer_tokens.shape[0], answer_tokens.shape[1], -1 - ) - if answer_tokens.shape[1] == 1: - logit_diff_directions: Float[torch.Tensor, "batch d_model"] = answer_residual_directions[ - :, 0, : - ] - else: # (correct, wrong) pair -> direction of the logit difference - correct_direction, wrong_direction = answer_residual_directions.unbind(dim=1) - logit_diff_directions = correct_direction - wrong_direction - - def residual_stack_to_logit_diff( - residual_stack: Float[torch.Tensor, "... batch d_model"], - cache: ActivationCache, - ) -> Float[torch.Tensor, "..."]: - # apply the final LayerNorm once, project onto the answer direction, average over prompts - batch_size = residual_stack.size(-2) - scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1) - return ( - einops.einsum( - scaled_residual_stack, - logit_diff_directions, - "... 
batch d_model, batch d_model -> ...", - ) - / batch_size - ) - if accumulated: - n_layers = bridge.cfg.n_layers - _, cache = bridge.run_with_cache( - prompts, - names_filter=lambda name: name == get_act_name("resid_post", n_layers - 1) - or name == "ln_final.hook_scale" - or name.endswith("resid_pre") - or name.endswith("resid_mid"), - ) - accumulated_residual, labels = cache.accumulated_resid( - layer=-1, pos_slice=-1, incl_mid=True, return_labels=True - ) - return residual_stack_to_logit_diff(accumulated_residual, cache), labels - - _, cache = bridge.run_with_cache( - prompts, - names_filter=lambda name: name == "ln_final.hook_scale" - or name in ("hook_embed", "hook_pos_embed") - or name.endswith("attn_out") - or name.endswith("mlp_out"), - ) - per_component_residual, labels = cache.decompose_resid( - layer=-1, pos_slice=-1, return_labels=True +def direct_logit_attribution( + model, + input: Union[str, List[str], torch.Tensor, None] = None, + answer_tokens: Optional[TokenInput] = None, + incorrect_tokens: Optional[TokenInput] = None, + *, + unit: Unit = "component", + pos: SliceInput = -1, + cache: Optional[ActivationCache] = None, +) -> DirectLogitAttribution: + """Compute Direct Logit Attribution for a prompt. + + Decomposes the contribution of model components to the logit of + ``answer_tokens`` (or, if ``incorrect_tokens`` is given, to the logit + *difference* ``answer - incorrect`` along the ``W_U`` direction, which is + usually what you want for circuit analysis). + + The model is run once with caching unless a precomputed ``cache`` is passed. + Works with both ``HookedTransformer`` and ``TransformerBridge``. + + Note that DLA attributes only the part of a logit that comes from the + residual stream through the unembedding direction; the unembedding bias + ``b_U`` is a per-token constant that no component produces. So a complete + decomposition reconstructs ``logit[token] - b_U[token]`` rather than the raw + logit. 
+ + On a ``TransformerBridge``, compatibility mode must be enabled (so the final + LayerNorm is folded into ``W_U``) — otherwise the projection direction is + wrong and DLA returns silently incorrect numbers. Hybrid architectures + (Mamba/SSM/Mixer/LinearAttention) are not yet supported because + ``decompose_resid`` only understands the ``attn_out + mlp_out`` block layout; + both conditions raise an explicit error at call time. + + Args: + model: + A ``HookedTransformer`` or ``TransformerBridge`` (the latter with + ``enable_compatibility_mode()`` already called). + input: + Prompt to run — a string, list of strings, or token tensor. Optional + only when a precomputed ``cache`` is supplied. + answer_tokens: + The correct token(s) to attribute, as a string, id, or tensor. A + string is converted with ``model.to_single_token``. + incorrect_tokens: + Optional baseline token(s). When given, attribution is computed for + the ``answer - incorrect`` residual direction. Must broadcast to the + same shape as ``answer_tokens``. + unit: + Decomposition granularity: + + - ``"component"`` (default): embedding + each layer's attention and + MLP output (via ``decompose_resid``). + - ``"layer"``: cumulative residual stream after each sublayer, i.e. + logit-lens trajectory (via ``accumulated_resid``). + - ``"head"``: each attention head individually, plus a remainder + term for everything else (via ``stack_head_results``). + pos: + Sequence position(s) to attribute. Defaults to ``-1`` (the final + token, the usual choice for next-token DLA). Pass ``None`` to keep + every position (the result then has a trailing position axis). + cache: + Optional precomputed ``ActivationCache`` to reuse instead of running + the model again. + + Returns: + A :class:`DirectLogitAttribution` with ``attribution`` (shape + ``[component, *batch_and_pos]``) and aligned ``labels``. 
+ + Raises: + ValueError: If ``unit`` is invalid, ``answer_tokens`` is ``None``, + neither ``input`` nor ``cache`` is provided, or a + ``TransformerBridge`` is passed without compatibility mode enabled. + NotImplementedError: If a ``TransformerBridge`` reports a hybrid block + layout (Mamba/SSM/Mixer/LinearAttention). + """ + if unit not in _VALID_UNITS: + raise ValueError(f"unit must be one of {_VALID_UNITS}, got {unit!r}") + if answer_tokens is None: + raise ValueError("answer_tokens is required") + + _validate_bridge_compatibility(model) + + if cache is None: + if input is None: + raise ValueError("provide either `input` to run the model, or a precomputed `cache`") + _, cache = model.run_with_cache(input) + + residual_stack, labels = _residual_stack_and_labels(cache, unit, pos) + + # logit_attrs applies the final LayerNorm scaling (with the same pos slice) + # and dots each component against the (correct - incorrect) unembed direction. + attribution = cache.logit_attrs( + residual_stack, + tokens=answer_tokens, + incorrect_tokens=incorrect_tokens, + pos_slice=pos, + has_batch_dim=cache.has_batch_dim, ) - return residual_stack_to_logit_diff(per_component_residual, cache), labels + + return DirectLogitAttribution(attribution=attribution, labels=labels, unit=unit)