Conversation

@speediedan

Firstly, thank you so much for building and maintaining TransformerLens - it's a seminal, foundational and invaluable component enabling increasingly vital open-source interpretability research!

This PR enhances the TransformerLens generate() API to support HuggingFace's ModelOutput abstraction, enabling better interoperability with the HuggingFace Transformers ecosystem while preserving full backward compatibility.

Note

This PR builds on the device-dtype-sync-fixes branch to ensure all tests pass cleanly before adding generation enhancements.

Motivation

Improved HuggingFace Transformers Interface Compatibility

HuggingFace Transformers uses ModelOutput dataclasses as the primary abstraction for model outputs. In v4.x, return_dict=True became the default, making ModelOutput the standard return type. The upcoming v5 continues this pattern while streamlining the output class hierarchy. By adopting ModelOutput in TransformerLens, we ensure:

  • Ecosystem alignment: Consistent interfaces with HuggingFace's generation API
  • Future-proofing: Compatibility with both v4.x and the upcoming v5 release
  • Developer ergonomics: Familiar .sequences and .logits attributes for users coming from HF

Enabling Expanded/Unanticipated Mechanistic Interpretability Workflows

It's convenient and efficient for downstream libraries to have access to generation logits (or other intermediate generation data the HF generation interface provides). For example, the pre-MVP package interpretune uses this capability for various world model analysis and collaborative attribution graph analysis workflows.

We propose adding an optional output_logits flag to the HookedTransformer/TransformerBridge generate() methods, allowing users to retrieve per-step logits alongside generated tokens in a ModelOutput format. To enable access to the full HuggingFace generation API, we also add a new hf_generate() method to TransformerBridge that forwards all generation parameters directly to the underlying HF model. Given the current prioritization of HookedTransformer/TransformerBridge parity, we keep generate() as the main API for standard use cases on both classes while providing hf_generate() for users needing full HF generation features. In the future, hf_generate() could optionally be added to HookedTransformer as well, or generate() could be pivoted to support the full HuggingFace generation API as hf_generate() does.

Changes

1. HookedTransformer Generation Enhancement (transformer_lens/HookedTransformer.py)

Added output_logits flag to generate() method:

def generate(
    self,
    input: str | list[str] | torch.Tensor = "",
    max_new_tokens: int = 10,
    # ... existing parameters
    output_logits: bool = False,  # New parameter
    **generation_kwargs,
) -> Union[str, List[str], torch.Tensor, Any]:  # Any for ModelOutput

Behavior:

  • When output_logits=True: Returns GenerateDecoderOnlyOutput (or ModelOutput fallback) with:
    • .sequences: Generated token IDs [batch, seq_len]
    • .logits: Tuple of per-step logits, each [batch, vocab_size]
  • When output_logits=False (default): Returns original format (str, list, or tensor)

Implementation:

# During generation, optionally collect logits
logits_seq_list: Optional[List[torch.Tensor]] = [] if output_logits else None

for _ in range(max_new_tokens):
    logits = self(current_tokens, ...)
    final_logits = logits[:, -1, :]  # [batch, vocab_size]

    if output_logits:
        logits_seq_list.append(final_logits)

    # ... sampling logic

# Return ModelOutput if requested
if output_logits:
    from transformers.generation.utils import GenerateDecoderOnlyOutput
    return GenerateDecoderOnlyOutput(
        sequences=output_tokens,
        logits=tuple(logits_seq_list)
    )

Compatibility handling:

  • Silently ignores return_dict_in_generate (TransformerLens doesn't need this flag)
  • Warns on unsupported HF kwargs (output_scores, output_attentions, etc.)
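
A rough sketch of how this compatibility handling could be exercised; the kwargs shown and the exact warning text are assumptions based on the description above, not the literal test code:

import warnings

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = model.generate(
        "Hello",
        max_new_tokens=3,
        return_dict_in_generate=True,  # silently ignored
        output_scores=True,            # unsupported here, expected to warn
    )

assert isinstance(out, str)  # default return format is unchanged
assert any("output_scores" in str(w.message) for w in caught)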

2. TransformerBridge Generation Enhancements (transformer_lens/model_bridge/bridge.py)

Enhanced generate() method:

Added output_logits parameter to match HookedTransformer API:

def generate(
    self,
    input: str | list[str] | torch.Tensor = "",
    max_new_tokens: int = 10,
    # ... existing parameters
    output_logits: bool = False,  # New parameter
) -> str | list[str] | torch.Tensor | Any:

Collects logits during generation and returns GenerateDecoderOnlyOutput when requested.
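
For a quick sketch (assuming a bridge built via TransformerBridge.boot_transformers("gpt2"), as in the usage examples below):

out = bridge.generate("Hello", max_new_tokens=3, output_logits=True, do_sample=False)
print(out.sequences.shape)  # [batch, seq_len]
print(len(out.logits))      # 3, one entry per generated token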

New hf_generate() method:

Full HuggingFace generation API pass-through:

def hf_generate(
    self,
    input: str | list[str] | torch.Tensor = "",
    max_new_tokens: int = 10,
    # ... standard parameters
    **generation_kwargs,  # Any HF generation parameter
) -> str | list[str] | torch.Tensor | Any:
    """Generate using underlying HuggingFace model with full HF API support.
    
    Forwards all generation parameters (output_scores, output_logits,
    output_attentions, output_hidden_states) to the HF model.
    
    Use this when you need full HuggingFace generation features.
    For standard TransformerLens-compatible generation, use generate().
    """

Features:

  • Auto-sets return_dict_in_generate=True when HF dict flags are present
  • Handles ModelOutput returns with .sequences attribute extraction
  • Full pass-through of HF generation parameters
  • Consistent return type handling based on return_type parameter

Example usage:

# Get full HF ModelOutput with logits and attentions
result = bridge.hf_generate(
    "Hello world",
    max_new_tokens=5,
    output_logits=True,
    output_attentions=True,
    return_dict_in_generate=True
)
print(result.sequences)     # Generated tokens
print(result.logits)        # Logits for each step
print(result.attentions)    # Attention weights

3. BlockBridge Tuple Return Format Fix (transformer_lens/model_bridge/generalized_components/block.py)

Fixed a batch dimension preservation issue:

# Before: Could return bare tensor, causing batch collapse
if len(output) == 1:
    return first  # [seq, hidden] - batch dimension lost!

# After: Always return tuple for HF compatibility
if len(output) == 1:
    return (first,)  # ([batch, seq, hidden],) - batch preserved!

Why this matters:

HuggingFace's GPT2Model does hidden_states = outputs[0] in its forward loop. When BlockBridge returns a bare tensor instead of a tuple:

  1. The indexing operation outputs[0] treats the tensor as a sequence
  2. Selects the first element along the batch dimension
  3. Collapses batch from [batch, seq, hidden] to [seq, hidden]
  4. Causes dimension mismatches in subsequent blocks

By always returning tuples, BlockBridge maintains consistency with HF's expected format and preserves batch dimensions through the generation pipeline.
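
A minimal standalone illustration of the indexing difference (toy shapes, not the library code):

import torch

hidden = torch.randn(2, 5, 8)   # [batch, seq, hidden]

as_tensor = hidden              # bare tensor return
print(as_tensor[0].shape)       # torch.Size([5, 8]) - batch dimension lost

as_tuple = (hidden,)            # tuple return, as HF expects
print(as_tuple[0].shape)        # torch.Size([2, 5, 8]) - batch preserved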

4. Type Annotations

Used Any type for ModelOutput returns due to beartype's forward reference resolution limitations:

from typing import Any  # For ModelOutput - see beartype issue #546

def generate(...) -> str | list[str] | torch.Tensor | Any:
    # Any for transformers.utils.ModelOutput
    # Using Any due to beartype's forward reference resolution limitations.
    # See: https://github.com/beartype/beartype/issues/546

Standard TYPE_CHECKING imports don't work with from __future__ import annotations and beartype's runtime checks, so we use Any with explanatory comments.

5. Comprehensive Test Coverage (tests/integration/test_generation_compatibility.py)

New test file with 510 lines covering:

HookedTransformer Tests:

  • test_generate_with_output_logits_returns_modeloutput: Validates ModelOutput structure
  • test_generate_without_output_logits_returns_normal: Ensures backward compatibility
  • test_generate_output_logits_with_return_type_tokens: Tests token return format
  • test_return_dict_in_generate_silently_ignored: Validates compatibility flag handling
  • test_unsupported_hf_flags_trigger_warning: Ensures proper warnings
  • test_logits_consistency_with_forward_pass: Validates logit correctness
  • test_output_logits_batch_generation: Tests batch processing

TransformerBridge Tests:

  • test_generate_with_output_logits_returns_modeloutput: Validates Bridge ModelOutput
  • test_generate_without_output_logits_returns_normal: Ensures backward compatibility
  • test_generate_output_logits_batch: Tests batch generation

TransformerBridge HF API Tests:

  • test_hf_generate_with_output_scores: Validates HF flag forwarding
  • test_hf_generate_sets_return_dict_in_generate: Tests auto-flag behavior
  • test_hf_generate_multiple_flags_simultaneously: Tests complex flag combinations
  • test_hf_generate_return_type_tokens: Tests token return format
  • test_hf_generate_flags_coerced_to_bool: Validates flag coercion
  • test_hf_generate_batch_generation: Tests batch processing

Backward Compatibility Tests:

  • test_hooked_transformer_basic_generation_unchanged: Validates HookedTransformer compatibility
  • test_bridge_basic_generation_unchanged: Validates Bridge compatibility
  • test_hooked_transformer_return_types_unchanged: Tests all return_type options

BlockBridge Batch Compatibility Tests:

  • test_block_bridge_batched_generation_compatibility: Critical test validating:
    • BlockBridge returns tuples (not bare tensors)
    • Batch dimensions preserved through multi-block generation
    • Independent batch items remain independent
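
As a rough illustration of what these tests assert (test name from the list above; the fixture and exact assertions are assumptions):

def test_generate_with_output_logits_returns_modeloutput(model):
    result = model.generate("Hello", max_new_tokens=3, output_logits=True, do_sample=False)
    assert hasattr(result, "sequences") and hasattr(result, "logits")
    assert len(result.logits) == 3                                  # one entry per generated token
    assert result.logits[0].shape[0] == result.sequences.shape[0]   # batch dimensions match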

Test Results

All tests pass with the enhancements:

========= 1087 passed, 63 skipped, 152 warnings in 1014.55s (0:16:54) ==========

Including all new generation compatibility tests.

Backward Compatibility

Zero breaking changes - all existing usage patterns remain unchanged:

Default behavior unchanged:

# Still returns string/list/tensor as before
result = model.generate("Hello", max_new_tokens=5)
assert isinstance(result, str)

All return_type options work:

# return_type='str', 'tokens', 'embeds' all unchanged
result = model.generate(prompt, return_type="tokens")
assert isinstance(result, torch.Tensor)

Existing tests continue passing:

  • All 1087 existing tests pass without modification
  • No changes required to existing user code

Usage Examples

Basic generation (unchanged):
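
result = model.generate("Hello", max_new_tokens=5)
print(result)  # still returns a plain string by default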

New: Get logits for analysis:

result = model.generate(
    "The quick brown",
    max_new_tokens=5,
    output_logits=True,
    do_sample=False
)

# Access HF-style attributes
print(result.sequences.shape)  # [batch, seq_len]
print(len(result.logits))      # 5 (one per generated token)
print(result.logits[0].shape)  # [batch, vocab_size]

# Decode generated text
text = model.tokenizer.decode(result.sequences[0])

Full HF API (TransformerBridge only):

from transformer_lens.model_bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("gpt2")

# Use full HuggingFace generation API
result = bridge.hf_generate(
    "Hello world",
    max_new_tokens=5,
    output_logits=True,
    output_attentions=True,
    output_hidden_states=True,
    return_dict_in_generate=True
)

# Access all HF ModelOutput attributes
print(result.sequences)        # Generated tokens
print(result.logits)           # Per-step logits
print(result.attentions)       # Attention weights (if available)
print(result.hidden_states)    # Hidden states (if available)

Batch generation:

prompts = ["Hello", "Goodbye"]
result = model.generate(
    prompts,
    max_new_tokens=3,
    output_logits=True
)

print(result.sequences.shape)  # [2, seq_len]
print(result.logits[0].shape)  # [2, vocab_size]
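
To recover the generated text for each batch item (same decoding pattern as the single-prompt example above):

texts = [model.tokenizer.decode(seq) for seq in result.sequences]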

Benefits

✅ HuggingFace Ecosystem Compatibility: Seamless integration with HF-based tools and evaluation frameworks

✅ Mechanistic Interpretability: Enable more efficient logit-level analysis during generation for attribution studies

✅ Developer Experience: Familiar .sequences and .logits attributes for users from HF

✅ Performance: Collect logits during generation (no separate forward pass needed)

✅ Flexibility: Two APIs (generate() for TL workflows, hf_generate() for full HF features)

✅ Robustness: BlockBridge tuple fix ensures batch dimensions preserved

✅ Future-Proof: Compatible with HF Transformers v4.x and upcoming v5

Implementation Notes

Why Any for ModelOutput types?

Beartype's runtime type checking doesn't support forward references well with from __future__ import annotations. Using TYPE_CHECKING imports causes runtime failures. We use Any with comprehensive inline documentation pointing to beartype issue #546.

Why two generation methods in TransformerBridge?

  • generate(): TransformerLens-compatible API, limited HF features
  • hf_generate(): Full HuggingFace API pass-through with all generation parameters supported.

This separation keeps the main API simple while providing users full HF generation API capabilities if desired.

Why always return tuples from BlockBridge?

HuggingFace models expect transformer blocks to return tuples. Returning bare tensors causes the outputs[0] indexing operation (used in a number of models) to select along the batch dimension instead of extracting the first tuple element, collapsing batch size. Always returning tuples improves HF compatibility.

Type of change

  • New feature (non-breaking change which adds functionality)

Summary

This PR adds HuggingFace ModelOutput support to TransformerLens generation APIs, enabling more seamless HF ecosystem integration and novel/unanticipated mechanistic interpretability workflows. The enhancements maintain full backward compatibility while providing optional enhanced intermediate generation artifact collection.

Thank you so much to the TransformerLens maintainers/community for making this foundationally valuable contribution to the open-source interpretability ecosystem!

I'm using it extensively in a downstream analysis framework I'm building and couldn't appreciate your work more! I should also thank Claude 1 for its help (somewhat obviously) in creating the scaffolding of this PR and for generating/refining the unit tests.

Footnotes

  1. Transitively thanking the authors of TransformerLens again for my ability to thank Claude, they really deserve the attribution :)

…xts.

We also update .gitignore to exclude .env (a commonly used local file exclusion), e.g. to allow collaborators to add their own HF_TOKEN for the test suite.

Core Fixes:
-----------

transformer_lens/components/abstract_attention.py:
  - Replace pattern.to(self.cfg.dtype) with pattern.to(v.dtype) to handle cases
    where tensors are upcast to float32 for numerical stability while cfg.dtype
    remains float16/bfloat16
  - Add explicit device/dtype synchronization for output projection:
    * Move weights (W_O) and bias (b_O) to match input device (z.device)
    * Ensure z matches weight dtype before final linear operation
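
A standalone sketch of the dtype issue described above (toy shapes and names are assumptions, not the library code):

import torch

# attention pattern upcast to float32 for numerical stability
pattern = torch.softmax(torch.randn(1, 2, 4, 4, dtype=torch.float32), dim=-1)
v = torch.randn(1, 4, 2, 8, dtype=torch.bfloat16)  # [batch, pos, head, d_head]

# cast back to the value tensor's dtype (not cfg.dtype) before combining
pattern = pattern.to(v.dtype)
z = torch.einsum("bhqk,bkhd->bqhd", pattern, v)    # no dtype mismatch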

transformer_lens/model_bridge/bridge.py:
  - Replace direct original_model.to() call with move_to_and_update_config()
    utility to ensure:
    * All bridge components (not just original_model) are moved to target device
    * cfg.device and cfg.dtype stay synchronized with actual model state
    * Multi-GPU cache tensors remain on correct devices

Test Fixes:
-----------

tests/acceptance/test_hooked_encoder.py:
  - Fix test_cuda() to use correct fixture name 'tokens' instead of 'mlm_tokens'

tests/acceptance/test_multi_gpu.py:
  - Update test_cache_device() to pass torch.device("cpu") instead of string
    "cpu" for proper device type validation

tests/unit/components/test_attention.py:
  - Add test_attention_forward_half_precisions() to validate attention works
    correctly with bfloat16/float16 dtypes on CUDA devices

tests/unit/factored_matrix/test_multiply_by_scalar.py:
  - Add test IDs to parametrize decorators to avoid pytest cache issues when
    random numbers appear in test names

Tests Fixed by This Commit:
---------------------------
- tests/acceptance/test_multi_gpu.py::test_cache_device
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
- tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]

Enhance both HookedTransformer and TransformerBridge generate() methods to support
HuggingFace-style generation outputs for improved interoperability:

- HookedTransformer: Add output_logits flag to return ModelOutput with sequences and logits
- TransformerBridge: Forward HF dict flags (output_scores, output_logits, output_attentions,
  output_hidden_states) to underlying HF model
- Maintain full backward compatibility with existing generate() usage patterns
- Add 17 integration tests covering ModelOutput behavior and flag handling
- Use Any type for ModelOutput returns due to beartype forward reference limitations (beartype/beartype#546)

This enables downstream libraries to leverage HF's standard generation
output format for advanced mechanistic interpretability workflows.

Fix batch dimension bug where BlockBridge.forward() returned bare tensors
instead of tuples, causing HuggingFace generation to incorrectly index into
batch dimensions.

Changes:
- BlockBridge: Always return (first,) tuple for single-element outputs
- Add regression test for batch dimension preservation during generation
- Rename test_generation_modeloutput.py -> test_generation_compatibility.py

The bug manifested when HF's GPT2Model did `hidden_states = outputs[0]`
expecting a tuple but got a tensor, causing it to index the batch dimension
instead of extracting the first tuple element.

Fixes batch generation for TransformerBridge with multiple blocks.

Enhance to() method to properly handle both device and dtype arguments in
all supported PyTorch formats (positional, keyword, combined). Separately
invoke move_to_and_update_config for device/dtype to update cfg while
delegating the actual tensor movement to original_model.to() with original
args/kwargs. This ensures TransformerBridge respects standard PyTorch
behavior for model.to() calls.
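
For reference, the to() call forms this covers (standard PyTorch semantics; the bridge construction follows the usage example in the PR description):

import torch
from transformer_lens.model_bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("gpt2")
bridge.to("cuda")                               # positional device
bridge.to(torch.bfloat16)                       # positional dtype
bridge.to("cuda", torch.bfloat16)               # combined positional
bridge.to(device="cuda", dtype=torch.bfloat16)  # keyword arguments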

Enhance HF compatibility of HookedTransformer and TransformerBridge

Core changes:
- HookedTransformer: Add output_logits flag returning ModelOutput with sequences and logits
- TransformerBridge.generate(): Add output_logits flag for consistency with HookedTransformer
- TransformerBridge.hf_generate(): New method for full HF API passthrough (output_scores,
  output_logits, output_attentions, output_hidden_states)
- Maintain unified API: Both classes share same generate() signature per upstream design
- hf_generate() for users needing full HF features; we may evaluate making it the default generate option in the future.

Architecture:
- Respects API consistency vision (unified generate() across both HookedTransformer/TransformerBridge classes)
- Adds escape hatch for advanced HF use cases without compromising clean API
- Clear separation: .generate() = TL-style, .hf_generate() = full HF

Testing:
- Comprehensive test suite (20 tests) covering ModelOutput behavior and flag handling
- Full backward compatibility maintained with existing generate() usage

This enables downstream libraries to leverage HF's standard generation
output format for advanced workflows while maintaining
TransformerLens's clean, consistent API.
speediedan marked this pull request as ready for review November 29, 2025 22:47
Copilot AI review requested due to automatic review settings November 29, 2025 22:47
Copilot finished reviewing on behalf of speediedan November 29, 2025 22:50

Copilot AI left a comment

Pull request overview

This PR adds HuggingFace ModelOutput support to TransformerLens generation APIs, enabling better interoperability with the HuggingFace ecosystem while maintaining full backward compatibility. The changes introduce an optional output_logits parameter to collect per-step logits during generation and return them in a HF-compatible format, along with a new hf_generate() method for full HF generation API access.

Key Changes:

  • Added output_logits parameter to generate() methods in both HookedTransformer and TransformerBridge to return ModelOutput objects with sequences and per-step logits
  • Introduced TransformerBridge.hf_generate() method for direct access to the full HuggingFace generation API
  • Fixed BlockBridge to always return tuples (not bare tensors) to maintain batch dimension consistency with HF models
  • Enhanced device/dtype handling in attention components for better multi-precision support
  • Added comprehensive test coverage (510 lines) for generation compatibility and batch processing

Reviewed changes

Copilot reviewed 10 out of 11 changed files in this pull request and generated 10 comments.

Summary per file:

  • transformer_lens/model_bridge/generalized_components/block.py: Fixed BlockBridge to always return tuples for HF compatibility, preventing batch dimension collapse
  • transformer_lens/model_bridge/bridge.py: Added output_logits parameter to generate(), implemented hf_generate() for full HF API access, and enhanced to() method for device/dtype handling
  • transformer_lens/components/abstract_attention.py: Improved device/dtype synchronization in attention operations for multi-precision support
  • transformer_lens/HookedTransformer.py: Added output_logits support with ModelOutput returns and kwargs handling for HF compatibility
  • tests/integration/test_generation_compatibility.py: Comprehensive new test suite (510 lines) covering ModelOutput returns, batch generation, and backward compatibility
  • tests/unit/factored_matrix/test_multiply_by_scalar.py: Added test IDs for better parametrize readability
  • tests/unit/components/test_attention.py: Added half-precision (bfloat16/float16) attention tests
  • tests/acceptance/test_multi_gpu.py: Fixed device parameter to use torch.device object for consistency
  • tests/acceptance/test_hooked_encoder.py: Fixed test parameter name from mlm_tokens to tokens
  • .vscode/settings.json: Changed pytest path from "transformer_lens" to "tests"
  • .gitignore: Added .env to ignored files
