Add HuggingFace ModelOutput support to TransformerLens generation API #1144
Conversation
…xts.
We also update .gitignore to exclude .env (a commonly excluded local file), e.g. to allow collaborators to add their own HF_TOKEN for the test suite.
Core Fixes:
-----------
transformer_lens/components/abstract_attention.py:
- Replace pattern.to(self.cfg.dtype) with pattern.to(v.dtype) to handle cases
  where tensors are upcast to float32 for numerical stability while cfg.dtype
  remains float16/bfloat16 (illustrated in the sketch after this list)
- Add explicit device/dtype synchronization for output projection:
  * Move weights (W_O) and bias (b_O) to match the input device (z.device)
  * Ensure z matches the weight dtype before the final linear operation
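Below is a minimal standalone illustration of the dtype issue (hypothetical shapes and values, not the actual AbstractAttention code):

```python
import torch

# Hypothetical repro: v is upcast to float32 for numerical stability while
# cfg.dtype stays float16. Casting the pattern to cfg.dtype would then make
# `pattern @ v` a mixed-dtype matmul (a RuntimeError on CPU); casting to
# v.dtype keeps the operands aligned.
cfg_dtype = torch.float16
v = torch.randn(1, 4, 8, dtype=torch.float32)       # upcast value tensor
pattern = torch.rand(1, 4, 4, dtype=torch.float32)  # softmaxed attention pattern

pattern = pattern.to(v.dtype)  # the fix: follow v.dtype, not cfg_dtype
z = pattern @ v                # dtypes match regardless of upcasting
```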
transformer_lens/model_bridge/bridge.py:
- Replace direct original_model.to() call with move_to_and_update_config()
utility to ensure:
* All bridge components (not just original_model) are moved to target device
* cfg.device and cfg.dtype stay synchronized with actual model state
* Multi-GPU cache tensors remain on correct devices
Test Fixes:
-----------
tests/acceptance/test_hooked_encoder.py:
- Fix test_cuda() to use correct fixture name 'tokens' instead of 'mlm_tokens'
tests/acceptance/test_multi_gpu.py:
- Update test_cache_device() to pass torch.device("cpu") instead of string
"cpu" for proper device type validation
tests/unit/components/test_attention.py:
- Add test_attention_forward_half_precisions() to validate attention works
correctly with bfloat16/float16 dtypes on CUDA devices
tests/unit/factored_matrix/test_multiply_by_scalar.py:
- Add test IDs to parametrize decorators to avoid pytest cache issues when
  random numbers appear in test names (see the sketch below)
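A minimal sketch of the parametrize-ID pattern, with illustrative values rather than the real test file's parameters:

```python
import pytest

# Stable, human-readable IDs keep collected test names deterministic even
# when the parametrized values themselves are randomly generated, so
# pytest's cache no longer churns on value reprs.
@pytest.mark.parametrize(
    "scalar",
    [2.0, -0.5],
    ids=["positive_scalar", "negative_scalar"],
)
def test_multiply_by_scalar(scalar):
    assert scalar * 1.0 == scalar
```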
Tests Fixed by This Commit:
---------------------------
- tests/acceptance/test_multi_gpu.py::test_cache_device
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
- tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]
Enhance both HookedTransformer and TransformerBridge generate() methods to
support HuggingFace-style generation outputs for improved interoperability:
- HookedTransformer: Add output_logits flag to return ModelOutput with
  sequences and logits
- TransformerBridge: Forward HF dict flags (output_scores, output_logits,
  output_attentions, output_hidden_states) to the underlying HF model
- Maintain full backward compatibility with existing generate() usage patterns
- Add 17 integration tests covering ModelOutput behavior and flag handling
- Use Any type for ModelOutput returns due to beartype forward reference
  limitations (TransformerLensOrg#546)
This enables downstream libraries to leverage HF's standard generation output
format for advanced mechanistic interpretability workflows.
Fix batch dimension bug where BlockBridge.forward() returned bare tensors
instead of tuples, causing HuggingFace generation to incorrectly index into
batch dimensions.
Changes:
- BlockBridge: Always return (first,) tuple for single-element outputs
- Add regression test for batch dimension preservation during generation
- Rename test_generation_modeloutput.py -> test_generation_compatibility.py
The bug manifested when HF's GPT2Model did `hidden_states = outputs[0]`
expecting a tuple but got a tensor, causing it to index the batch dimension
instead of extracting the first tuple element.
Fixes batch generation for TransformerBridge with multiple blocks.
Enhance the to() method to properly handle both device and dtype arguments in
all supported PyTorch formats (positional, keyword, combined). Separately
invoke move_to_and_update_config for device/dtype to update cfg while
delegating the actual tensor movement to original_model.to() with the
original args/kwargs. This ensures TransformerBridge respects standard
PyTorch behavior for model.to() calls (see the sketch below).
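A hedged sketch of the described behavior; the argument parsing mirrors PyTorch's own convention, and `move_to_and_update_config` is the utility named above, with an assumed signature:

```python
import torch

def to(self, *args, **kwargs):
    # Accept device/dtype in positional, keyword, or combined form
    device, dtype = kwargs.get("device"), kwargs.get("dtype")
    for arg in args:
        if isinstance(arg, (torch.device, str)):
            device = arg
        elif isinstance(arg, torch.dtype):
            dtype = arg
    # Update cfg (and bridge components) for each argument that was given;
    # move_to_and_update_config's signature is assumed here
    if device is not None:
        move_to_and_update_config(self, device)  # keeps cfg.device in sync
    if dtype is not None:
        move_to_and_update_config(self, dtype)   # keeps cfg.dtype in sync
    # Delegate the actual tensor movement with the original args/kwargs
    self.original_model.to(*args, **kwargs)
    return self
```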
Enhance HF compatibility of HookedTransformer and TransformerBridge
Core changes:
- HookedTransformer: Add output_logits flag returning ModelOutput with
  sequences and logits
- TransformerBridge.generate(): Add output_logits flag for consistency with
  HookedTransformer
- TransformerBridge.hf_generate(): New method for full HF API passthrough
  (output_scores, output_logits, output_attentions, output_hidden_states)
- Maintain unified API: Both classes share the same generate() signature per
  upstream design
- hf_generate() for users needing full HF features; evaluate the possibility
  of making it the default generate option in the future
Architecture:
- Respects the API consistency vision (unified generate() across both
  HookedTransformer/TransformerBridge classes)
- Adds an escape hatch for advanced HF use cases without compromising the
  clean API
- Clear separation: .generate() = TL-style, .hf_generate() = full HF
Testing:
- Comprehensive test suite (20 tests) covering ModelOutput behavior and flag
  handling
- Full backward compatibility maintained with existing generate() usage
This enables downstream libraries to leverage HF's standard generation output
format for advanced workflows while maintaining TransformerLens's clean,
consistent API.
Pull request overview
This PR adds HuggingFace ModelOutput support to TransformerLens generation APIs, enabling better interoperability with the HuggingFace ecosystem while maintaining full backward compatibility. The changes introduce an optional output_logits parameter to collect per-step logits during generation and return them in a HF-compatible format, along with a new hf_generate() method for full HF generation API access.
Key Changes:
- Added `output_logits` parameter to `generate()` methods in both `HookedTransformer` and `TransformerBridge` to return `ModelOutput` objects with sequences and per-step logits
- Introduced `TransformerBridge.hf_generate()` method for direct access to the full HuggingFace generation API
- Fixed BlockBridge to always return tuples (not bare tensors) to maintain batch dimension consistency with HF models
- Enhanced device/dtype handling in attention components for better multi-precision support
- Added comprehensive test coverage (510 lines) for generation compatibility and batch processing
Reviewed changes
Copilot reviewed 10 out of 11 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| `transformer_lens/model_bridge/generalized_components/block.py` | Fixed BlockBridge to always return tuples for HF compatibility, preventing batch dimension collapse |
| `transformer_lens/model_bridge/bridge.py` | Added `output_logits` parameter to `generate()`, implemented `hf_generate()` for full HF API access, and enhanced `to()` method for device/dtype handling |
| `transformer_lens/components/abstract_attention.py` | Improved device/dtype synchronization in attention operations for multi-precision support |
| `transformer_lens/HookedTransformer.py` | Added `output_logits` support with `ModelOutput` returns and kwargs handling for HF compatibility |
| `tests/integration/test_generation_compatibility.py` | Comprehensive new test suite (510 lines) covering `ModelOutput` returns, batch generation, and backward compatibility |
| `tests/unit/factored_matrix/test_multiply_by_scalar.py` | Added test IDs for better parametrize readability |
| `tests/unit/components/test_attention.py` | Added half-precision (bfloat16/float16) attention tests |
| `tests/acceptance/test_multi_gpu.py` | Fixed device parameter to use `torch.device` object for consistency |
| `tests/acceptance/test_hooked_encoder.py` | Fixed test parameter name from `mlm_tokens` to `tokens` |
| `.vscode/settings.json` | Changed pytest path from "transformer_lens" to "tests" |
| `.gitignore` | Added `.env` to ignored files |
Firstly, thank you so much for building and maintaining TransformerLens - it's a seminal, foundational and invaluable component enabling increasingly vital open-source interpretability research!
This PR enhances the TransformerLens `generate()` API to support HuggingFace's `ModelOutput` abstraction, enabling better interoperability with the HuggingFace Transformers ecosystem while preserving full backward compatibility.

Note: This PR builds on the `device-dtype-sync-fixes` branch to ensure all tests pass cleanly before adding generation enhancements.

Motivation
Improved HuggingFace Transformers Interface Compatibility
HuggingFace Transformers uses `ModelOutput` dataclasses as the primary abstraction for model outputs. In v4.x, `return_dict=True` became the default, making `ModelOutput` the standard return type. The upcoming v5 continues this pattern while streamlining the output class hierarchy. By adopting `ModelOutput` in TransformerLens, we ensure:
- Familiar `.sequences` and `.logits` attributes for users coming from HF

Enabling Expanded/Unanticipated Mechanistic Interpretability Workflows
It's convenient and efficient for downstream libraries to have access to generation logits (or other intermediate generation data the HF generation interface provides). For example, the pre-MVP package interpretune uses this capability for various world model analysis and collaborative attribution graph analysis workflows.
We propose adding an optional `output_logits` flag to the HookedTransformer/TransformerBridge `generate()` methods, allowing users to retrieve per-step logits alongside generated tokens in a `ModelOutput` format. To enable access to the full HuggingFace generation API, we also add a new `hf_generate()` method to TransformerBridge that forwards all generation parameters directly to the underlying HF model. Given the current prioritization of HookedTransformer/TransformerBridge parity, we keep HookedTransformer/TransformerBridge `generate()` as the main API for standard use cases while providing `hf_generate()` for users needing full HF generation features. We could optionally add `hf_generate()` to HookedTransformer in the future if desired, or eventually pivot `generate()` to support the full HuggingFace generation API as `hf_generate()` does.

Changes
1. HookedTransformer Generation Enhancement (`transformer_lens/HookedTransformer.py`)

Added `output_logits` flag to the `generate()` method.

Behavior:
- `output_logits=True`: Returns `GenerateDecoderOnlyOutput` (or `ModelOutput` fallback) with:
  - `.sequences`: Generated token IDs `[batch, seq_len]`
  - `.logits`: Tuple of per-step logits, each `[batch, vocab_size]`
- `output_logits=False` (default): Returns the original format (str, list, or tensor)

Implementation:
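A hedged sketch of the per-step logit collection (greedy decoding for brevity; the real `generate()` also handles sampling, EOS stopping, and return types):

```python
import torch
from transformers.generation.utils import GenerateDecoderOnlyOutput

def _generate_sketch(model, tokens, max_new_tokens, output_logits=False):
    collected = []
    for _ in range(max_new_tokens):
        logits = model(tokens)            # [batch, seq, vocab]
        final_logits = logits[:, -1, :]   # [batch, vocab], logits for next token
        if output_logits:
            collected.append(final_logits)
        next_token = final_logits.argmax(dim=-1, keepdim=True)  # greedy step
        tokens = torch.cat([tokens, next_token], dim=-1)
    if output_logits:
        # HF-style output: .sequences plus one logits tensor per step
        return GenerateDecoderOnlyOutput(sequences=tokens, logits=tuple(collected))
    return tokens
```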
Compatibility handling:
- Silently ignores `return_dict_in_generate` (TransformerLens doesn't need this flag)
- Warns on unsupported HF dict flags (`output_scores`, `output_attentions`, etc.)

2. TransformerBridge Generation Enhancements (`transformer_lens/model_bridge/bridge.py`)

Enhanced `generate()` method:

Added `output_logits` parameter to match the HookedTransformer API. Collects logits during generation and returns `GenerateDecoderOnlyOutput` when requested.

New `hf_generate()` method:

Full HuggingFace generation API pass-through.

Features:
- Sets `return_dict_in_generate=True` automatically when HF dict flags are present
- Handles `ModelOutput` returns with `.sequences` attribute extraction
- Supports the `return_type` parameter

Example usage:
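A hedged usage sketch — the flag names are standard HF generation arguments, and `bridge` is assumed to be a TransformerBridge instance:

```python
out = bridge.hf_generate(
    "The quick brown fox",
    max_new_tokens=10,
    output_scores=True,
    output_hidden_states=True,
)  # return_dict_in_generate=True is set automatically for HF dict flags

print(out.sequences.shape)  # [batch, seq_len]
print(len(out.scores))      # one entry per generated step
```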
3. BlockBridge Tuple Return Format Fix (`transformer_lens/model_bridge/generalized_components/block.py`)

Fixed a batch dimension preservation issue.

Why this matters:

HuggingFace's GPT2Model does `hidden_states = outputs[0]` in its forward loop. When BlockBridge returns a bare tensor instead of a tuple:
- `outputs[0]` treats the tensor as a sequence and indexes along the batch dimension
- This collapses `[batch, seq, hidden]` to `[seq, hidden]`

By always returning tuples, BlockBridge maintains consistency with HF's expected format and preserves batch dimensions through the generation pipeline.
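A tiny standalone demonstration of the indexing difference:

```python
import torch

hidden = torch.randn(2, 5, 8)  # [batch, seq, hidden]

# Bare tensor: [0] indexes the batch dimension and silently drops a sample
print(hidden[0].shape)         # torch.Size([5, 8])

# Tuple-wrapped: [0] extracts the tensor with all dimensions intact
print((hidden,)[0].shape)      # torch.Size([2, 5, 8])
```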
4. Type Annotations

Used `Any` type for `ModelOutput` returns due to beartype's forward reference resolution limitations.

Standard `TYPE_CHECKING` imports don't work with `from __future__ import annotations` and beartype's runtime checks, so we use `Any` with explanatory comments.
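A hedged sketch of the annotation pattern (illustrative signature, not the verbatim one from this PR):

```python
from __future__ import annotations

from typing import Any

class HookedTransformerSketch:
    def generate(self, input: Any = "", output_logits: bool = False, **kwargs: Any) -> Any:
        # Return type is Any rather than ModelOutput: beartype cannot resolve
        # the forward reference under `from __future__ import annotations`
        # (beartype issue #546, referenced in this PR).
        ...
```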
5. Comprehensive Test Coverage (`tests/integration/test_generation_compatibility.py`)

New test file with 510 lines covering:
HookedTransformer Tests:
- `test_generate_with_output_logits_returns_modeloutput`: Validates ModelOutput structure
- `test_generate_without_output_logits_returns_normal`: Ensures backward compatibility
- `test_generate_output_logits_with_return_type_tokens`: Tests token return format
- `test_return_dict_in_generate_silently_ignored`: Validates compatibility flag handling
- `test_unsupported_hf_flags_trigger_warning`: Ensures proper warnings
- `test_logits_consistency_with_forward_pass`: Validates logit correctness
- `test_output_logits_batch_generation`: Tests batch processing

TransformerBridge Tests:
- `test_generate_with_output_logits_returns_modeloutput`: Validates Bridge ModelOutput
- `test_generate_without_output_logits_returns_normal`: Ensures backward compatibility
- `test_generate_output_logits_batch`: Tests batch generation

TransformerBridge HF API Tests:
- `test_hf_generate_with_output_scores`: Validates HF flag forwarding
- `test_hf_generate_sets_return_dict_in_generate`: Tests auto-flag behavior
- `test_hf_generate_multiple_flags_simultaneously`: Tests complex flag combinations
- `test_hf_generate_return_type_tokens`: Tests token return format
- `test_hf_generate_flags_coerced_to_bool`: Validates flag coercion
- `test_hf_generate_batch_generation`: Tests batch processing

Backward Compatibility Tests:
- `test_hooked_transformer_basic_generation_unchanged`: Validates HookedTransformer compatibility
- `test_bridge_basic_generation_unchanged`: Validates Bridge compatibility
- `test_hooked_transformer_return_types_unchanged`: Tests all return_type options

BlockBridge Batch Compatibility Tests:
- `test_block_bridge_batched_generation_compatibility`: Critical test validating batch dimension preservation during batched generation

Test Results
All tests pass with these enhancements, including all new generation compatibility tests.
Backward Compatibility

✅ Zero breaking changes - all existing usage patterns remain unchanged:
- Default behavior unchanged
- All return_type options work
- Existing tests continue passing
Usage Examples
Basic generation (unchanged):
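A hedged sketch; the model name and prompt are illustrative:

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
text = model.generate("The capital of France is", max_new_tokens=5)
print(text)  # plain string, exactly as before this PR
```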
New: Get logits for analysis:
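A hedged sketch of the new flag, continuing from the snippet above:

```python
out = model.generate(
    "The capital of France is",
    max_new_tokens=5,
    output_logits=True,      # new flag added by this PR
)
print(out.sequences.shape)   # [batch, seq_len]
print(len(out.logits))       # one [batch, vocab_size] tensor per step
```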
Full HF API (TransformerBridge only):
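A hedged sketch; the import path and `boot_transformers` constructor are assumptions based on this PR's file layout and TransformerLens 3.x conventions:

```python
from transformer_lens.model_bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("gpt2")  # assumed constructor
out = bridge.hf_generate(
    "The capital of France is",
    max_new_tokens=5,
    output_scores=True,
    output_attentions=True,
)
print(type(out).__name__)  # an HF ModelOutput subclass with .sequences
```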
Batch generation:
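A hedged sketch, reusing `model` from above; `to_tokens` pads a list of prompts into a single batch:

```python
prompts = ["Hello world", "The quick brown fox jumps"]
tokens = model.to_tokens(prompts)  # [2, pos] with padding
out = model.generate(tokens, max_new_tokens=5, output_logits=True)
print(out.sequences.shape)         # [2, seq_len] -- batch dimension preserved
```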
Benefits
✅ HuggingFace Ecosystem Compatibility: Seamless integration with HF-based tools and evaluation frameworks
✅ Mechanistic Interpretability: Enable more efficient logit-level analysis during generation for attribution studies
✅ Developer Experience: Familiar `.sequences` and `.logits` attributes for users from HF

✅ Performance: Collect logits during generation (no separate forward pass needed)

✅ Flexibility: Two APIs (`generate()` for TL workflows, `hf_generate()` for full HF features)

✅ Robustness: BlockBridge tuple fix ensures batch dimensions are preserved
✅ Future-Proof: Compatible with HF Transformers v4.x and upcoming v5
Implementation Notes
Why `Any` for ModelOutput types?

Beartype's runtime type checking doesn't support forward references well with `from __future__ import annotations`. Using `TYPE_CHECKING` imports causes runtime failures. We use `Any` with comprehensive inline documentation pointing to beartype issue #546.

Why two generation methods in TransformerBridge?

- `generate()`: TransformerLens-compatible API, limited HF features
- `hf_generate()`: Full HuggingFace API pass-through with all generation parameters supported

This separation keeps the main API simple while providing full HF generation API capabilities when desired.
Why always return tuples from BlockBridge?
HuggingFace models expect transformer blocks to return tuples. Returning bare tensors causes the `outputs[0]` indexing operation (used in a number of models) to select along the batch dimension instead of extracting the first tuple element, collapsing the batch size. Always returning tuples improves HF compatibility.

Type of change
Summary
This PR adds HuggingFace `ModelOutput` support to TransformerLens generation APIs, enabling more seamless HF ecosystem integration and novel/unanticipated mechanistic interpretability workflows. The enhancements maintain full backward compatibility while providing optional collection of intermediate generation artifacts.

Thank you so much to the TransformerLens maintainers/community for making this foundationally valuable contribution to the open-source interpretability ecosystem!
I'm using it extensively in a downstream analysis framework I'm building and couldn't appreciate your work more! I should also thank Claude [1] for its help (somewhat obviously) creating the scaffolding of this PR and for generating/refining the unit tests.
Footnotes
1. Transitively thanking the authors of TransformerLens again for my ability to thank Claude, they really deserve the attribution :)