Fix device and dtype synchronization in attention and model bridge #1142
Conversation
…xts.
We also update .gitignore to exclude .env (a commonly used local file exclusion), e.g. to allow collaborators to add their own HF_TOKEN for the test suite.
Core Fixes:
-----------
transformer_lens/components/abstract_attention.py:
- Replace pattern.to(self.cfg.dtype) with pattern.to(v.dtype) to handle cases
where tensors are upcast to float32 for numerical stability while cfg.dtype
remains float16/bfloat16
- Add explicit device/dtype synchronization for output projection:
* Move weights (W_O) and bias (b_O) to match input device (z.device)
* Ensure z matches weight dtype before final linear operation
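A minimal, hypothetical illustration of the dtype mismatch this fix addresses (standalone sketch with made-up shapes, not the actual abstract_attention.py code):

```python
import torch

# cfg.dtype is float16, but the values were upcast to float32 for numerical stability.
cfg_dtype = torch.float16
pattern = torch.rand(1, 4, 8, 8, dtype=torch.float32)   # [batch, head, query_pos, key_pos]
v = torch.randn(1, 8, 4, 16, dtype=torch.float32)       # [batch, key_pos, head, d_head]

# Old behaviour: pattern.to(cfg_dtype) yields float16 and clashes with float32 v.
# New behaviour: follow the value tensor's actual dtype instead.
pattern = pattern.to(v.dtype)
z = torch.einsum("bhqk,bkhd->bqhd", pattern, v)          # dtypes now agree
```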
transformer_lens/model_bridge/bridge.py:
- Replace direct original_model.to() call with move_to_and_update_config()
utility to ensure:
* All bridge components (not just original_model) are moved to target device
* cfg.device and cfg.dtype stay synchronized with actual model state
* Multi-GPU cache tensors remain on correct devices
Test Fixes:
-----------
tests/acceptance/test_hooked_encoder.py:
- Fix test_cuda() to use correct fixture name 'tokens' instead of 'mlm_tokens'
tests/acceptance/test_multi_gpu.py:
- Update test_cache_device() to pass torch.device("cpu") instead of string
"cpu" for proper device type validation
tests/unit/components/test_attention.py:
- Add test_attention_forward_half_precisions() to validate attention works
correctly with bfloat16/float16 dtypes on CUDA devices
tests/unit/factored_matrix/test_multiply_by_scalar.py:
- Add test IDs to parametrize decorators to avoid pytest cache issues when
random numbers appear in test names
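A hypothetical sketch of the parametrize pattern referred to here; the actual parameters in test_multiply_by_scalar.py differ, but explicit ids keep the generated test names stable instead of embedding random values:

```python
import pytest
import torch

@pytest.mark.parametrize(
    "scalar",
    [torch.rand(1), 10 * torch.rand(1)],
    ids=["unit_scale", "ten_scale"],  # stable names; no random numbers in test IDs
)
def test_multiply_by_scalar_shape(scalar):
    assert (2 * scalar).shape == scalar.shape
```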
Tests Fixed by This Commit:
---------------------------
- tests/acceptance/test_multi_gpu.py::test_cache_device
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
- tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]
Enhance to() method to properly handle both device and dtype arguments in all supported PyTorch formats (positional, keyword, combined). Separately invoke move_to_and_update_config for device/dtype to update cfg while delegating the actual tensor movement to original_model.to() with original args/kwargs. This ensures TransformerBridge respects standard PyTorch behavior for model.to() calls.
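A rough sketch of how such argument handling can look (the helper name _parse_to_args is hypothetical; only the torch.nn.Module.to() calling conventions are taken as given, and the actual bridge.py implementation may differ):

```python
import torch

def _parse_to_args(*args, **kwargs):
    """Hypothetical helper: extract device/dtype from the usual .to() forms,
    e.g. .to("cuda"), .to(torch.float16), .to("cuda", torch.float16),
    .to(device=..., dtype=...)."""
    device = kwargs.get("device")
    dtype = kwargs.get("dtype")
    for arg in args:
        if isinstance(arg, torch.dtype):
            dtype = arg
        elif isinstance(arg, (str, torch.device)):
            device = torch.device(arg)
    return device, dtype

# Inside TransformerBridge.to(), the description above amounts to roughly:
#   device, dtype = _parse_to_args(*args, **kwargs)
#   if device is not None:
#       move_to_and_update_config(self, device)   # keep cfg.device in sync
#   if dtype is not None:
#       move_to_and_update_config(self, dtype)    # keep cfg.dtype in sync
#   self.original_model.to(*args, **kwargs)       # actual tensor movement
```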
Pull request overview
This PR addresses device and dtype synchronization issues in TransformerLens that caused test failures in mixed-precision (float16/bfloat16) and multi-GPU contexts. The changes ensure consistent tensor dtypes during attention operations and proper device synchronization across model components.
Key Changes:
- Enhanced attention operations to synchronize pattern dtype with value tensor dtype instead of config dtype
- Improved output projection device/dtype alignment in attention forward pass
- Extended TransformerBridge.to() to properly update configuration state when moving devices/dtypes
- Fixed test issues including incorrect fixture references and device type parameters
Reviewed changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| transformer_lens/components/abstract_attention.py | Updated pattern dtype synchronization to match value tensors; added device/dtype checks for output projection weights |
| transformer_lens/model_bridge/bridge.py | Enhanced to() method to parse device/dtype arguments and update config using the move_to_and_update_config utility |
| tests/unit/components/test_attention.py | Added half-precision forward pass test for float16/bfloat16 validation |
| tests/acceptance/test_multi_gpu.py | Fixed device parameter to use a torch.device object instead of a string |
| tests/acceptance/test_hooked_encoder.py | Corrected fixture reference from mlm_tokens to tokens |
| tests/unit/factored_matrix/test_multiply_by_scalar.py | Added test IDs to parametrize for better test identification |
| .vscode/settings.json | Updated pytest path from "transformer_lens" to "tests" for proper test discovery |
| .gitignore | Added .env file to ignore list for local environment configuration |
Resolves #1140
Firstly, thank you so much for building and maintaining TransformerLens - it's a seminal, foundational and invaluable component enabling increasingly vital open-source interpretability research!
This PR resolves device/dtype synchronization issues that caused test failures in mixed-precision (float16/bfloat16) and multi-GPU contexts. The changes ensure tensors maintain consistent dtypes during attention operations and that model components stay synchronized with their configured devices.
Note
This PR represents foundational fixes that are included in two companion PRs (non-leaf-tensor-fix and tlens-generate-enhancement) to ensure their test suites pass cleanly. Those PRs build on these device/dtype fixes to add additional functionality.

The Issues
Half-Precision Type Mismatches
Attention operations were casting patterns to cfg.dtype even when value tensors had been upcast to float32 for numerical stability, causing dtype mismatches in matrix operations.

Multi-Device Cache Failures
TransformerBridge.to() only moved the underlying HuggingFace model, leaving bridge configuration and components unsynchronized, causing cache tensors to remain on incorrect devices.

Output Projection Device Conflicts
Output projection weights and biases weren't consistently moved to match input activation devices.
The Solution
1. Dynamic dtype synchronization in attention (transformer_lens/components/abstract_attention.py)
This change allows patterns to adapt to the actual dtype of value tensors, which may differ from cfg.dtype when operations require float32 for numerical stability.

Output projection device/dtype alignment:
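The code block for this did not survive the copy; below is a sketch of what the alignment amounts to (the standalone-function form and shapes are illustrative, while the real logic sits inline in the attention forward pass):

```python
import torch

def output_projection(z: torch.Tensor, W_O: torch.Tensor, b_O: torch.Tensor) -> torch.Tensor:
    # z: [batch, pos, n_heads, d_head]; W_O: [n_heads, d_head, d_model]; b_O: [d_model]
    W_O = W_O.to(z.device)   # weights follow the activations' device
    b_O = b_O.to(z.device)
    z = z.to(W_O.dtype)      # activations follow the weights' dtype
    return torch.einsum("bpnh,nhm->bpm", z, W_O) + b_O
```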
2. Comprehensive device/dtype tracking in TransformerBridge (transformer_lens/model_bridge/bridge.py)
Enhanced the to() method to use the move_to_and_update_config() utility. This ensures:
- cfg.device and cfg.dtype stay synchronized with actual tensor locations
- All bridge components (not just original_model) are properly moved

3. Test improvements
- Fixed the test_cuda() fixture reference (tokens instead of mlm_tokens)
- Passed torch.device("cpu") instead of the string "cpu" for proper type validation
- Added test_attention_forward_half_precisions() to validate float16/bfloat16 compatibility

4. Repository hygiene
- Added .env to .gitignore to support local environment configuration (e.g., HF_TOKEN for the test suite)
- Updated the VS Code pytest path from "transformer_lens" to "tests"

Test Results
Before:
After:
All previously failing tests now pass:
- test_multi_gpu.py::test_cache_device
- test_legacy_hooked_transformer_coverage.py::test_memory_efficiency[gpt2]
- test_legacy_hooked_transformer_coverage.py::test_consistent_outputs[gpt2]
- test_hooked_transformer.py::test_half_precision[dtype0]
- test_hooked_transformer.py::test_half_precision[dtype1]
- test_attention.py::test_attention_forward_half_precisions[dtype0]
- test_attention.py::test_attention_forward_half_precisions[dtype1]
- test_utils.py::test_device_compatibility[gpt2]

Benefits
Robustness: Mixed-precision workflows now work correctly without manual dtype management
Multi-GPU support: Proper device synchronization enables reliable multi-device inference
Maintainability: Configuration state now accurately reflects model state, reducing debugging complexity
Backward Compatibility
✅ No breaking changes - all existing usage patterns remain unchanged. The fixes are purely internal implementation improvements that make the library more robust in edge cases.
Summary
This PR resolves device and dtype synchronization issues in TransformerLens v3's attention mechanism and model bridge architecture.
Thank you so much to the TransformerLens maintainers/community for making this foundationally valuable contribution to the open-source interpretability ecosystem!
I'm using it extensively in a downstream analysis framework I'm building and couldn't appreciate your work more! I should also thank Claude¹ for its help (somewhat obviously) in creating the scaffolding of this PR and for generating/refining the unit tests.
Footnotes
¹ Transitively thanking the authors of TransformerLens again for my ability to thank Claude, they really deserve the attribution :)