Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 158 additions & 93 deletions tests/integration/model_bridge/test_direct_logit_attribution.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,167 @@
"""Integration tests for the Direct Logit Attribution tool on a TransformerBridge.
"""Integration tests for the Direct Logit Attribution tool.

DLA decomposes the residual-stream part of a logit (or logit difference) via the
unembedding direction ``W_U[:, token]``. The unembedding bias ``b_U`` is a
per-token constant that no component produces, so a complete decomposition
reconstructs ``logit - b_U`` (and, for a difference, ``logit_diff - (b_U[c] -
b_U[w])``), not the raw logit. We assert that invariant against the bridge's own
forward-pass logits — the point of issue #1263 is that DLA must be correct on a
``TransformerBridge`` (compatibility mode), not only ``HookedTransformer``.
DLA decomposes the part of a logit that comes from the residual stream via the
unembedding *direction* ``W_U[:, token]``. The unembedding *bias* ``b_U`` is a
per-token constant that no component produces, so the exact correctness
invariant is::

Uses distilgpt2 (CI-cached), matching test_analysis_methods.py.
sum(component DLA for token) + b_U[token] == logit[token]

and, for a difference of two tokens, the two bias terms do **not** generally
cancel (gpt2's folded ``ln_final`` bias makes them differ), so::

sum(component DLA, correct vs incorrect) == logit_diff - (b_U[c] - b_U[i])

We assert these for ``HookedTransformer`` and for ``TransformerBridge``
(compatibility mode) — the latter is the reason issue #1263 exists.

These tests load gpt2 (cached), so they live in ``integration/`` per
``tests/AGENTS.md``.
"""

import pytest
import torch

from transformer_lens.tools.analysis import dla

PROMPT = "The Eiffel Tower is in the city of"
CORRECT = " Paris"
WRONG = " London"


def _last_token_logits(bridge, prompt):
"""Final-position logits from the bridge's own forward pass."""
logits, _ = bridge.run_with_cache(prompt)
if logits.ndim == 3: # [batch, pos, vocab]
logits = logits[0]
return logits[-1] # [vocab]


class TestDirectLogitAttributionCorrectness:
"""The component contributions must reconstruct the model's real logits."""

def test_decompose_reconstructs_logit_difference(self, distilgpt2_bridge_compat):
bridge = distilgpt2_bridge_compat
c, w = bridge.to_single_token(CORRECT), bridge.to_single_token(WRONG)
scores, labels = dla(bridge, [PROMPT], torch.tensor([[c, w]]))

logits = _last_token_logits(bridge, PROMPT)
expected = (logits[c] - logits[w]) - (bridge.b_U[c] - bridge.b_U[w])
assert scores.sum().item() == pytest.approx(expected.item(), abs=1e-2)
assert len(scores) == len(labels)

def test_decompose_reconstructs_single_token_logit(self, distilgpt2_bridge_compat):
bridge = distilgpt2_bridge_compat
c = bridge.to_single_token(CORRECT)
scores, _ = dla(bridge, [PROMPT], torch.tensor([[c]]))

logits = _last_token_logits(bridge, PROMPT)
expected = logits[c] - bridge.b_U[c]
assert scores.sum().item() == pytest.approx(expected.item(), abs=1e-2)

def test_accumulated_last_entry_reconstructs(self, distilgpt2_bridge_compat):
# accumulated_resid is cumulative -> reconstruction is the LAST entry, not the sum
bridge = distilgpt2_bridge_compat
c, w = bridge.to_single_token(CORRECT), bridge.to_single_token(WRONG)
scores, labels = dla(bridge, [PROMPT], torch.tensor([[c, w]]), accumulated=True)

logits = _last_token_logits(bridge, PROMPT)
expected = (logits[c] - logits[w]) - (bridge.b_U[c] - bridge.b_U[w])
assert scores[-1].item() == pytest.approx(expected.item(), abs=1e-2)
assert len(scores) == len(labels)


class TestDirectLogitAttributionShape:
def test_decompose_labels_and_shape(self, distilgpt2_bridge_compat):
bridge = distilgpt2_bridge_compat
c = bridge.to_single_token(CORRECT)
scores, labels = dla(bridge, [PROMPT], torch.tensor([[c]]))

n_layers = bridge.cfg.n_layers
assert scores.ndim == 1
assert len(scores) == len(labels)
assert "embed" in labels
assert sum(label.endswith("attn_out") for label in labels) == n_layers
assert sum(label.endswith("mlp_out") for label in labels) == n_layers

def test_batch_of_prompts_is_averaged(self, distilgpt2_bridge_compat):
# two identical prompts -> the batch-mean equals the single-prompt result
bridge = distilgpt2_bridge_compat
c, w = bridge.to_single_token(CORRECT), bridge.to_single_token(WRONG)
single, _ = dla(bridge, [PROMPT], torch.tensor([[c, w]]))
doubled, _ = dla(bridge, [PROMPT, PROMPT], torch.tensor([[c, w], [c, w]]))
assert torch.allclose(single, doubled, atol=1e-4)


class TestDirectLogitAttributionGuardsOnRealBridge:
"""The guards must also fire on a genuine bridge, not just a mock."""

def test_non_compat_bridge_raises(self, distilgpt2_bridge):
bridge = distilgpt2_bridge # compatibility mode NOT enabled
c = bridge.to_single_token(CORRECT)
INCORRECT = " London"


def _refs(model):
"""Reference values: (logit_correct, logit_incorrect, b_U[c], b_U[i])."""
logits = model(PROMPT)
if logits.ndim == 2: # some Bridge configs may drop the batch dim
logits = logits[None]
c = model.to_single_token(CORRECT)
i = model.to_single_token(INCORRECT)
return (
logits[0, -1, c].item(),
logits[0, -1, i].item(),
model.b_U[c].item(),
model.b_U[i].item(),
)


def _assert_complete_decomposition(model, unit):
"""sum(DLA) reconstructs the logit / logit-diff up to the b_U constant."""
from transformer_lens.tools.analysis import direct_logit_attribution

logit_c, logit_i, bu_c, bu_i = _refs(model)

diff = direct_logit_attribution(
model, PROMPT, answer_tokens=CORRECT, incorrect_tokens=INCORRECT, unit=unit
)
single = direct_logit_attribution(model, PROMPT, answer_tokens=CORRECT, unit=unit)

# accumulated_resid ("layer") is cumulative: the last entry is the full
# residual stream, so it (not the column sum) is the reconstruction.
diff_total = diff.attribution[-1].sum() if unit == "layer" else diff.attribution.sum()
single_total = single.attribution[-1].sum() if unit == "layer" else single.attribution.sum()

assert diff_total.item() == pytest.approx((logit_c - logit_i) - (bu_c - bu_i), abs=1e-2)
assert single_total.item() == pytest.approx(logit_c - bu_c, abs=1e-2)


@pytest.fixture(scope="module")
def gpt2_ht():
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained("gpt2", device="cpu")


class TestDirectLogitAttributionHooked:
"""Correctness on HookedTransformer (the reference numerics)."""

@pytest.mark.parametrize("unit", ["component", "layer", "head"])
def test_decomposition_reconstructs_logit(self, gpt2_ht, unit):
_assert_complete_decomposition(gpt2_ht, unit)

def test_component_labels_and_shape(self, gpt2_ht):
from transformer_lens.tools.analysis import direct_logit_attribution

res = direct_logit_attribution(
gpt2_ht, PROMPT, answer_tokens=CORRECT, incorrect_tokens=INCORRECT, unit="component"
)
assert res.unit == "component"
assert res.attribution.shape[0] == len(res.labels)
# Embedding term(s) plus each layer's attn_out and mlp_out.
assert "embed" in res.labels
assert sum(label.endswith("_attn_out") for label in res.labels) == gpt2_ht.cfg.n_layers
assert sum(label.endswith("_mlp_out") for label in res.labels) == gpt2_ht.cfg.n_layers

def test_head_labels_include_remainder(self, gpt2_ht):
from transformer_lens.tools.analysis import direct_logit_attribution

res = direct_logit_attribution(gpt2_ht, PROMPT, answer_tokens=CORRECT, unit="head")
assert len(res.labels) == gpt2_ht.cfg.n_layers * gpt2_ht.cfg.n_heads + 1
assert res.labels[-1] == "remainder"

def test_reuses_precomputed_cache(self, gpt2_ht):
from transformer_lens.tools.analysis import direct_logit_attribution

logit_c, _, bu_c, _ = _refs(gpt2_ht)
_, cache = gpt2_ht.run_with_cache(PROMPT)
res = direct_logit_attribution(gpt2_ht, answer_tokens=CORRECT, cache=cache)
assert res.attribution.sum().item() == pytest.approx(logit_c - bu_c, abs=1e-2)

def test_pos_none_keeps_position_axis(self, gpt2_ht):
from transformer_lens.tools.analysis import direct_logit_attribution

n_tokens = gpt2_ht.to_tokens(PROMPT).shape[1]
res = direct_logit_attribution(
gpt2_ht, PROMPT, answer_tokens=CORRECT, unit="component", pos=None
)
assert res.attribution.ndim == 3 # [component, batch, pos]
assert res.attribution.shape[-1] == n_tokens

def test_top_returns_sorted_pairs(self, gpt2_ht):
from transformer_lens.tools.analysis import direct_logit_attribution

res = direct_logit_attribution(
gpt2_ht, PROMPT, answer_tokens=CORRECT, incorrect_tokens=INCORRECT, unit="head"
)
top = res.top(3)
assert len(top) == 3
values = [v for _, v in top]
assert values == sorted(values, reverse=True)


class TestDirectLogitAttributionBridge:
"""The point of #1263: DLA must work on TransformerBridge."""

@pytest.mark.parametrize("unit", ["component", "head"])
def test_decomposition_reconstructs_logit(self, gpt2_bridge_compat, unit):
_assert_complete_decomposition(gpt2_bridge_compat, unit)


class TestDirectLogitAttributionBridgeGuards:
"""Guards required for correctness on TransformerBridge (from PR #1316).

Without compatibility mode the projection direction is wrong on a Bridge and
DLA would silently return incorrect numbers — verify the explicit refusal.
"""

def test_non_compat_bridge_raises(self, gpt2_bridge):
from transformer_lens.tools.analysis import direct_logit_attribution

with pytest.raises(ValueError, match="compatibility mode"):
dla(bridge, [PROMPT], torch.tensor([[c]]))

def test_hybrid_layer_types_raises(self, distilgpt2_bridge_compat, monkeypatch):
bridge = distilgpt2_bridge_compat
monkeypatch.setattr(bridge, "layer_types", lambda: ["attn+mlp", "mamba+mlp"])
c = bridge.to_single_token(CORRECT)
with pytest.raises(NotImplementedError, match="hybrid"):
dla(bridge, [PROMPT], torch.tensor([[c]]))
direct_logit_attribution(gpt2_bridge, PROMPT, answer_tokens=CORRECT)


class TestDirectLogitAttributionValidation:
def test_invalid_unit_raises(self, gpt2_ht):
from transformer_lens.tools.analysis import direct_logit_attribution

with pytest.raises(ValueError, match="unit must be one of"):
direct_logit_attribution(gpt2_ht, PROMPT, answer_tokens=CORRECT, unit="neuron")

def test_missing_answer_tokens_raises(self, gpt2_ht):
from transformer_lens.tools.analysis import direct_logit_attribution

with pytest.raises(ValueError, match="answer_tokens is required"):
direct_logit_attribution(gpt2_ht, PROMPT)

def test_missing_input_and_cache_raises(self, gpt2_ht):
from transformer_lens.tools.analysis import direct_logit_attribution

with pytest.raises(ValueError, match="either `input`"):
direct_logit_attribution(gpt2_ht, answer_tokens=CORRECT)
34 changes: 17 additions & 17 deletions tests/unit/tools/test_direct_logit_attribution.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,53 @@
"""Unit tests for the Direct Logit Attribution guards and argument validation.
"""Unit tests for Direct Logit Attribution guards and argument validation.

These exercise the fast-failing checks (argument validation, the
These exercise the fast-failing checks (argument validation, the Bridge
compatibility-mode requirement, and the hybrid-architecture refusal) without
loading a real model, using a ``spec``-ed mock bridge so the checks are reached
before any forward pass.
loading a real modelusing a ``spec``-ed mock TransformerBridge so the checks
fire before any forward pass.
"""

from unittest.mock import MagicMock

import pytest
import torch

from transformer_lens.model_bridge import TransformerBridge
from transformer_lens.tools.analysis import dla
from transformer_lens.tools.analysis import direct_logit_attribution


def _mock_bridge(compatibility_mode=True, layer_types=("attn+mlp",)):
"""A mock TransformerBridge that satisfies isinstance/type checks."""
"""A mock TransformerBridge that satisfies isinstance checks."""
bridge = MagicMock(spec=TransformerBridge)
bridge.compatibility_mode = compatibility_mode
bridge.layer_types.return_value = list(layer_types)
return bridge


def test_prompt_answer_length_mismatch_raises():
def test_invalid_unit_raises():
bridge = _mock_bridge()
with pytest.raises(ValueError, match="matching row"):
dla(bridge, ["a", "b"], torch.tensor([[1]]))
with pytest.raises(ValueError, match="unit must be one of"):
direct_logit_attribution(bridge, "hi", answer_tokens=" world", unit="neuron")


def test_invalid_answer_columns_raises():
def test_missing_answer_tokens_raises():
bridge = _mock_bridge()
with pytest.raises(ValueError, match="columns"):
dla(bridge, ["a"], torch.tensor([[1, 2, 3]]))
with pytest.raises(ValueError, match="answer_tokens is required"):
direct_logit_attribution(bridge, "hi")


def test_requires_compatibility_mode():
bridge = _mock_bridge(compatibility_mode=False)
with pytest.raises(ValueError, match="compatibility mode"):
dla(bridge, ["a"], torch.tensor([[1]]))
direct_logit_attribution(bridge, "hi", answer_tokens=" world")


def test_rejects_hybrid_architecture():
bridge = _mock_bridge(layer_types=("attn+mlp", "mamba+mlp"))
with pytest.raises(NotImplementedError, match="hybrid"):
dla(bridge, ["a"], torch.tensor([[1]]))
direct_logit_attribution(bridge, "hi", answer_tokens=" world")


def test_dla_is_exported_from_analysis_package():
def test_direct_logit_attribution_is_exported_from_analysis_package():
from transformer_lens.tools import analysis

assert analysis.dla is dla
assert analysis.direct_logit_attribution is direct_logit_attribution
assert hasattr(analysis, "DirectLogitAttribution")
2 changes: 1 addition & 1 deletion transformer_lens/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
including the model registry for discovering compatible HuggingFace models.

Subpackages:
- analysis: Interpretability analyses such as Direct Logit Attribution
- analysis: High-level interpretability analyses (e.g. Direct Logit Attribution)
- model_registry: Tools for discovering and documenting supported models
"""

Expand Down
18 changes: 15 additions & 3 deletions transformer_lens/tools/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
"""Analysis tools for the TransformerBridge system."""
"""Analysis tools for TransformerLens.

from transformer_lens.tools.analysis.direct_logit_attribution import dla
This subpackage collects high-level, single-call interpretability analyses that
sit on top of the hook/cache system. They work with both ``HookedTransformer``
and the newer ``TransformerBridge`` (the two share the ``ActivationCache`` API).

__all__ = ["dla"]
Tools:
- direct_logit_attribution: Direct Logit Attribution (DLA) over components,
layers, or attention heads.
"""

from transformer_lens.tools.analysis.direct_logit_attribution import (
DirectLogitAttribution,
direct_logit_attribution,
)

__all__ = ["DirectLogitAttribution", "direct_logit_attribution"]
Loading
Loading