Improve TransformerBridge optimizer compatibility via dual PyTorch/TransformerLens parameter access API #1143
base: dev-3.x-folding
Conversation
…xts.
We also update .gitignore to exclude .env (a commonly used local-file exclusion), e.g. to allow collaborators to add their own HF_TOKEN for the test suite.
Core Fixes:
-----------
transformer_lens/components/abstract_attention.py:
- Replace pattern.to(self.cfg.dtype) with pattern.to(v.dtype) to handle cases
where tensors are upcast to float32 for numerical stability while cfg.dtype
remains float16/bfloat16
- Add explicit device/dtype synchronization for output projection:
* Move weights (W_O) and bias (b_O) to match input device (z.device)
* Ensure z matches weight dtype before final linear operation
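A minimal sketch of the attention-side change (illustrative, not the verbatim diff; this is in-method context, so `pattern`, `v`, and `z` come from the surrounding forward pass):

```python
# `pattern` may have been upcast to float32 for numerical stability, so match
# it to the value tensor's dtype rather than trusting cfg.dtype:
pattern = pattern.to(v.dtype)

# Output projection: align weights and bias with the activations' device,
# then align z with the weight dtype before the final linear operation
# (the per-head einsum itself is elided here).
w = self.W_O.to(z.device)
b = self.b_O.to(z.device)
z = z.to(w.dtype)
```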
transformer_lens/model_bridge/bridge.py:
- Replace direct original_model.to() call with move_to_and_update_config()
utility to ensure:
* All bridge components (not just original_model) are moved to target device
* cfg.device and cfg.dtype stay synchronized with actual model state
* Multi-GPU cache tensors remain on correct devices
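For reference, a hedged sketch of the bridge-side change (assuming move_to_and_update_config lives in transformer_lens.utilities.devices, as it does for HookedTransformer; the exact call site may differ):

```python
from transformer_lens.utilities import devices  # assumed utility location

# Before: only the wrapped HF model moved; cfg.device/cfg.dtype and other
# bridge components could fall out of sync.
# self.original_model.to(device_or_dtype)

# After: one utility moves every bridge component and updates cfg in lockstep.
devices.move_to_and_update_config(self, device_or_dtype)
```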
Test Fixes:
-----------
tests/acceptance/test_hooked_encoder.py:
- Fix test_cuda() to use correct fixture name 'tokens' instead of 'mlm_tokens'
tests/acceptance/test_multi_gpu.py:
- Update test_cache_device() to pass torch.device("cpu") instead of string
"cpu" for proper device type validation
tests/unit/components/test_attention.py:
- Add test_attention_forward_half_precisions() to validate attention works
correctly with bfloat16/float16 dtypes on CUDA devices
tests/unit/factored_matrix/test_multiply_by_scalar.py:
- Add test IDs to parametrize decorators to avoid pytest cache issues when
random numbers appear in test names
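Illustrative only (not the actual test file): explicit ids keep pytest node IDs stable when parameters contain random values, which otherwise change every run and defeat cache-based flags like --last-failed:

```python
import random

import pytest

@pytest.mark.parametrize(
    "scalar",
    [random.random(), random.random() * 10],
    ids=["random_unit", "random_scaled"],  # hypothetical, stable IDs
)
def test_multiply_by_scalar_smoke(scalar):
    # Without ids, the node ID embeds the random value's repr, so pytest's
    # cache cannot match results across runs.
    assert isinstance(scalar, float)
```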
Tests Fixed by This Commit:
---------------------------
- tests/acceptance/test_multi_gpu.py::test_cache_device
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
- tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]
…rameter/named_parameter generator API

Add parameters()/named_parameters() returning leaf tensors for PyTorch optimizers and tl_parameters()/tl_named_parameters() for TransformerLens-style format. Update SVDInterpreter to use tl_parameters(). Add comprehensive tests.
Enhance to() method to properly handle both device and dtype arguments in all supported PyTorch formats (positional, keyword, combined). Separately invoke move_to_and_update_config for device/dtype to update cfg while delegating the actual tensor movement to original_model.to() with original args/kwargs. This ensures TransformerBridge respects standard PyTorch behavior for model.to() calls.
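A hedged sketch of the to() handling described above (torch._C._nn._parse_to is the private helper nn.Module.to uses internally; the PR may parse arguments differently):

```python
import torch
from transformer_lens.utilities import devices  # assumed utility location

def to(self, *args, **kwargs):
    # Accept every PyTorch call form: to(device), to(dtype),
    # to(device, dtype), to(device=..., dtype=...).
    device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)

    # Update cfg separately for device and dtype...
    if device is not None:
        devices.move_to_and_update_config(self, device)
    if dtype is not None:
        devices.move_to_and_update_config(self, dtype)

    # ...while the actual tensor movement is delegated with the original
    # args/kwargs, preserving standard nn.Module.to() semantics.
    self.original_model.to(*args, **kwargs)
    return self
```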
Add comprehensive test validating optimizer parity between TransformerBridge and HookedTransformer across optimization steps (1, 10). The test checks:
- Initial forward pass alignment (logits, loss)
- Parameter updates (unembed weights remain close)
- Post-update forward pass convergence

Uses a NamedTuple for threshold definitions with magnitude-based validation thresholds (1e-3, 1e-5, etc.) for cleaner readability.
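A sketch of the threshold scheme this refers to; field names and values are illustrative:

```python
from typing import NamedTuple

class ParityThresholds(NamedTuple):
    logits_atol: float       # initial forward-pass agreement
    loss_atol: float         # initial loss agreement
    weight_atol: float       # unembed weights after optimizer steps
    post_update_atol: float  # forward pass after updates

# Magnitude-based tolerances keyed by number of optimization steps.
THRESHOLDS = {
    1: ParityThresholds(1e-5, 1e-5, 1e-5, 1e-3),
    10: ParityThresholds(1e-5, 1e-5, 1e-4, 1e-3),
}
```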
Pull request overview
This PR resolves a critical bug where TransformerBridge.parameters() returned non-leaf tensors that couldn't be optimized by PyTorch optimizers. The solution introduces a dual API approach: standard PyTorch methods (parameters(), named_parameters()) delegate to the underlying HuggingFace model, while new TransformerLens-specific methods (tl_parameters(), tl_named_parameters()) provide processed parameters for analysis tools.
Key Changes:
- Introduced dual parameter access API to support both optimization and analysis workflows
- Updated SVDInterpreter to use tl_parameters() for TransformerBridge compatibility
- Enhanced device/dtype handling in TransformerBridge.to() method and attention components
Reviewed changes
Copilot reviewed 10 out of 12 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| transformer_lens/model_bridge/bridge.py | Added parameters(), named_parameters(), tl_parameters(), and tl_named_parameters() methods; enhanced to() method for improved device/dtype handling |
| transformer_lens/SVDInterpreter.py | Updated to use tl_parameters() for TransformerBridge models while maintaining backward compatibility with HookedTransformer |
| transformer_lens/components/abstract_attention.py | Improved device and dtype consistency for attention operations by consolidating device/dtype conversions |
| tests/unit/model_bridge/test_optimizer_compatibility.py | Comprehensive unit tests validating leaf tensor guarantees, optimizer compatibility, and dual API correctness |
| tests/integration/model_bridge/test_optimizer_compatibility.py | Integration tests validating end-to-end optimization workflows and multi-step training parity |
| tests/unit/factored_matrix/test_multiply_by_scalar.py | Added test IDs for better test identification |
| tests/unit/components/test_attention.py | Added half-precision (bfloat16/float16) attention tests for CUDA |
| tests/acceptance/test_multi_gpu.py | Fixed device specification to use torch.device object |
| tests/acceptance/test_hooked_encoder.py | Fixed test fixture reference from mlm_tokens to tokens |
| .vscode/settings.json | Updated pytest args to target the tests directory instead of transformer_lens |
| .gitignore | Added .env to ignore list |
Thanks so much for this! I'll review your PRs relatively quickly, and likely get it all folded into the coming beta release.
Resolves #1141
Firstly, thank you so much for building and maintaining TransformerLens - it's a seminal, foundational and invaluable component enabling increasingly vital open-source interpretability research!
This PR fixes an issue where TransformerBridge.parameters() returns non-leaf tensors that cannot be optimized, while preserving TransformerLens-style parameter access for analysis tools.

Note
This PR builds on the device-dtype-sync-fixes branch to ensure all tests pass cleanly before adding these optimizer compatibility enhancements.

The Issue
TransformerBridge's parameters() method returned non-leaf tensors (created by einops.rearrange()), breaking PyTorch's fundamental requirement that optimizer parameters be leaf tensors. This prevented users from fine-tuning or training TransformerBridge models:
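A hedged reproduction of the failure mode (the constructor and import path below are assumptions drawn from this PR's context, not confirmed API):

```python
import torch
from transformer_lens.model_bridge import TransformerBridge  # assumed import path

bridge = TransformerBridge.boot_transformers("gpt2")  # assumed constructor

# Before this PR, parameters() yielded tensors produced by einops.rearrange(),
# which are non-leaf, so constructing an optimizer failed with
# "ValueError: can't optimize a non-leaf Tensor":
optimizer = torch.optim.Adam(bridge.parameters(), lr=1e-4)
```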
The Solution

Dual Parameter API - Adhere to standard PyTorch parameters()/named_parameters() semantics while providing explicit TransformerLens-style parameter access methods:
1. PyTorch-style API - Follows standard PyTorch semantics while delegating to underlying HuggingFace model:
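A sketch of the delegation in (1); assumes self.original_model is the wrapped HuggingFace module:

```python
def parameters(self, recurse: bool = True):
    # Leaf tensors straight from the HF model, safe for torch.optim.
    return self.original_model.parameters(recurse=recurse)

def named_parameters(self, prefix: str = "", recurse: bool = True):
    # HF-style names, e.g. "transformer.h.0.attn.c_attn.weight".
    return self.original_model.named_parameters(prefix=prefix, recurse=recurse)
```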
2. TransformerLens-style API - preserves TL conventions:
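A sketch of the TL-style accessors in (2), assuming the pre-existing get_params() supplies the processed TL-format tensors:

```python
def tl_parameters(self):
    # TL-style parameter dict (names like "blocks.0.attn.W_Q"), matching
    # the behavior of the old get_params().
    return self.get_params()

def tl_named_parameters(self):
    # Iterator convenience over the same mapping.
    yield from self.tl_parameters().items()
```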
3. Update SVDInterpreter - Use tl_parameters() for TransformerBridge:
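And a sketch of the SVDInterpreter change in (3); attribute and branch structure are illustrative:

```python
# Prefer the TL-style dict when the model provides it (TransformerBridge);
# fall back to the existing HookedTransformer behavior otherwise.
if hasattr(model, "tl_parameters"):
    self.params = dict(model.tl_parameters())
else:
    self.params = dict(model.named_parameters())
```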
Before settling on the dual API approach, we considered several alternatives, including monolithic APIs with flags or instance configuration modes. However, these had significant drawbacks, among them:
Test Coverage
New test files:
- tests/unit/model_bridge/test_optimizer_compatibility.py - Unit tests for parameter API correctness
- tests/integration/model_bridge/test_optimizer_compatibility.py - Integration tests for end-to-end workflows

Test coverage includes:
1. Basic Optimizer Compatibility
2. Leaf Tensor Validation (see the sketch after this list)
3. Gradient Flow
4. Parameter Updates
5. TL Parameter API
6. Multi-Step Optimization Parity (integration test)
7. Compatibility Mode
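For instance, the leaf-tensor guarantee (item 2) reduces to assertions like these (illustrative, not the exact test code):

```python
def test_parameters_are_leaf_tensors(bridge):
    # Every tensor handed to an optimizer must be a leaf with grad enabled.
    for name, param in bridge.named_parameters():
        assert param.is_leaf, f"{name} is not a leaf tensor"
        assert param.requires_grad, f"{name} does not require grad"
```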
Test Results
All 255 unit tests and 334 integration tests pass, including the complete optimizer compatibility suite.
Unit tests validate:
Integration tests validate:
Benefits
✅ Optimizer Compatibility: Standard PyTorch optimizers work out of the box
✅ Analysis Tool Support: SVDInterpreter and other TL tools continue working unchanged
✅ Clear API Separation: Explicit methods for training (parameters()) vs analysis (tl_parameters())
✅ Zero Breaking Changes: Existing code using bridge.original_model.parameters() continues working
✅ Discoverable API: Clear naming convention (tl_*) signals TransformerLens-specific functionality

Backward Compatibility
For optimization workflows:
- parameters() now returns optimizable leaf tensors (was broken, now fixed)
- named_parameters() returns HF-style names (consistent with underlying model)

For analysis workflows:
- tl_parameters() provides a TL-style parameter dict (same as the old get_params())
- tl_named_parameters() provides an iterator version (new convenience method)
- SVDInterpreter updated to use tl_parameters() (transparent to users)

No breaking changes for users who:
- used bridge.original_model.parameters() as a workaround (still works)

Type of change
Summary
This PR fixes a bug preventing TransformerBridge optimization while preserving full TransformerLens analysis functionality. By introducing a clear dual API, parameters() for standard PyTorch parameter semantics and tl_parameters() for analysis, we allow both standard PyTorch workflows and TransformerLens-specific tooling to coexist seamlessly.

Thank you so much to the TransformerLens maintainers/community for making this foundationally valuable contribution to the open-source interpretability ecosystem!
I'm using it extensively in a downstream analysis framework I'm building and couldn't appreciate your work more! I should also thank Claude[^1] for its help (somewhat obviously) in creating the scaffolding of this PR and for generating/refining the unit tests.
Footnotes

[^1]: Transitively thanking the authors of TransformerLens again for my ability to thank Claude, they really deserve the attribution :)