Skip to content

Conversation

@speediedan
Copy link

Resolves #1140

Firstly, thank you so much for building and maintaining TransformerLens - it's a seminal, foundational and invaluable component enabling increasingly vital open-source interpretability research!

This PR resolves device/dtype synchronization issues that caused test failures in mixed-precision (float16/bfloat16) and multi-GPU contexts. The changes ensure tensors maintain consistent dtypes during attention operations and that model components stay synchronized with their configured devices.

Note

This PR represents foundational fixes that are included in two companion PRs (non-leaf-tensor-fix and tlens-generate-enhancement) to ensure their test suites pass cleanly. Those PRs build on these device/dtype fixes to add additional functionality.

The Issues

Half-Precision Type Mismatches

Attention operations were casting patterns to cfg.dtype even when value tensors had been upcast to float32 for numerical stability, causing dtype mismatches in matrix operations.

Multi-Device Cache Failures

TransformerBridge.to() only moved the underlying HuggingFace model, leaving bridge configuration and components unsynchronized, causing cache tensors to remain on incorrect devices.

Output Projection Device Conflicts

Output projection weights and biases weren't consistently moved to match input activation devices.

The Solution

1. Dynamic dtype synchronization in attention (transformer_lens/components/abstract_attention.py)

# Before: Cast pattern to config dtype
pattern = pattern.to(self.cfg.dtype)
pattern = pattern.to(v.device)

# After: Match pattern dtype to value tensor dtype
pattern = pattern.to(device=v.device, dtype=v.dtype)

This change allows patterns to adapt to the actual dtype of value tensors, which may differ from cfg.dtype when operations require float32 for numerical stability.

Output projection device/dtype alignment:

# Move weights and bias to match input device
if w.device != z.device:
    w = w.to(z.device)
b_O = self.b_O
if b_O.device != z.device:
    b_O = b_O.to(z.device)

# Ensure activations match weight dtype before projection
if z.dtype != w.dtype:
    z = z.to(w.dtype)

out = F.linear(z.reshape(...), w, b_O)

2. Comprehensive device/dtype tracking in TransformerBridge (transformer_lens/model_bridge/bridge.py)

Enhanced the to() method to use move_to_and_update_config() utility:

def to(self, *args, **kwargs) -> "TransformerBridge":
    """Move model to device and/or change dtype."""
    # Extract and parse device/dtype from various call patterns
    target_device, target_dtype = None, None
    
    # Handle: to(device), to(dtype), to(device, dtype), 
    # to(device=...), to(dtype=...), to(device=..., dtype=...)
    if len(args) >= 1:
        first_arg = args[0]
        if isinstance(first_arg, (torch.device, str)):
            target_device = first_arg
        elif isinstance(first_arg, torch.dtype):
            target_dtype = first_arg
    # ... additional parsing logic
    
    # Synchronize config with actual device/dtype
    if target_device is not None:
        move_to_and_update_config(self, target_device, print_details)
    if target_dtype is not None:
        move_to_and_update_config(self, target_dtype, print_details)
    
    # Move the original HF model
    self.original_model = self.original_model.to(*args, **kwargs)
    return self

This ensures:

  • cfg.device and cfg.dtype stay synchronized with actual tensor locations
  • All bridge components (not just original_model) are properly moved
  • Multi-GPU cache tensors remain on correct devices

3. Test improvements

  • test_hooked_encoder.py: Fixed fixture name reference (tokens instead of mlm_tokens)
  • test_multi_gpu.py: Updated to pass torch.device("cpu") instead of string "cpu" for proper type validation
  • test_attention.py: Added test_attention_forward_half_precisions() to validate float16/bfloat16 compatibility
  • test_multiply_by_scalar.py: Added test IDs to parametrizations to avoid pytest cache issues with random values

4. Repository hygiene

  • Added .env to .gitignore to support local environment configuration (e.g., HF_TOKEN for test suite)
  • Updated VSCode test discovery path from "transformer_lens" to "tests"

Test Results

Before:

==== 8 failed, 1075 passed, 67 skipped, 154 warnings in 1005.09s (0:16:45) =====

After:

========= 1087 passed, 63 skipped, 152 warnings in 1014.55s (0:16:54) ==========

All previously failing tests now pass:

  • test_multi_gpu.py::test_cache_device
  • test_legacy_hooked_transformer_coverage.py::test_memory_efficiency[gpt2]
  • test_legacy_hooked_transformer_coverage.py::test_consistent_outputs[gpt2]
  • test_hooked_transformer.py::test_half_precision[dtype0]
  • test_hooked_transformer.py::test_half_precision[dtype1]
  • test_attention.py::test_attention_forward_half_precisions[dtype0]
  • test_attention.py::test_attention_forward_half_precisions[dtype1]
  • test_utils.py::test_device_compatibility[gpt2]

Benefits

Robustness: Mixed-precision workflows now work correctly without manual dtype management

Multi-GPU support: Proper device synchronization enables reliable multi-device inference

Maintainability: Configuration state now accurately reflects model state, reducing debugging complexity

Backward Compatibility

No breaking changes - all existing usage patterns remain unchanged. The fixes are purely internal implementation improvements that make the library more robust in edge cases.


Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)

Summary

This PR resolves device and dtype synchronization issues in TransformerLens v3's attention mechanism and model bridge architecture.

Thank you so much to the TransformerLens maintainers/community for making this foundationally valuable contribution to the open-source interpretability ecosystem!

I'm using it extensively in a downstream analysis framework I'm building and couldn't appreciate your work more! I should also thank Claude 1 for its help (somewhat obviously) creating the scaffolding of this PR and for generating/refining the unit tests.

Footnotes

  1. Transitively thanking the authors of TransformerLens again for my ability to thank Claude, they really deserve the attribution :)

…xts.

We also update .gitignore to exclude .env (commonly used local file exclution), e.g. to allow collaborators to add their on HF_TOKEN for test suite

Core Fixes:
-----------

transformer_lens/components/abstract_attention.py:
  - Replace pattern.to(self.cfg.dtype) with pattern.to(v.dtype) to handle cases
    where tensors are upcast to float32 for numerical stability while cfg.dtype
    remains float16/bfloat16
  - Add explicit device/dtype synchronization for output projection:
    * Move weights (W_O) and bias (b_O) to match input device (z.device)
    * Ensure z matches weight dtype before final linear operation

transformer_lens/model_bridge/bridge.py:
  - Replace direct original_model.to() call with move_to_and_update_config()
    utility to ensure:
    * All bridge components (not just original_model) are moved to target device
    * cfg.device and cfg.dtype stay synchronized with actual model state
    * Multi-GPU cache tensors remain on correct devices

Test Fixes:
-----------

tests/acceptance/test_hooked_encoder.py:
  - Fix test_cuda() to use correct fixture name 'tokens' instead of 'mlm_tokens'

tests/acceptance/test_multi_gpu.py:
  - Update test_cache_device() to pass torch.device("cpu") instead of string
    "cpu" for proper device type validation

tests/unit/components/test_attention.py:
  - Add test_attention_forward_half_precisions() to validate attention works
    correctly with bfloat16/float16 dtypes on CUDA devices

tests/unit/factored_matrix/test_multiply_by_scalar.py:
  - Add test IDs to parametrize decorators to avoid pytest cache issues when
    random numbers appear in test names

Tests Fixed by This Commit:
---------------------------
- tests/acceptance/test_multi_gpu.py::test_cache_device
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
- tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]
Enhance to() method to properly handle both device and dtype arguments in
all supported PyTorch formats (positional, keyword, combined). Separately
invoke move_to_and_update_config for device/dtype to update cfg while
delegating the actual tensor movement to original_model.to() with original
args/kwargs. This ensures TransformerBridge respects standard PyTorch
behavior for model.to() calls.
@speediedan speediedan marked this pull request as ready for review November 29, 2025 22:28
Copilot AI review requested due to automatic review settings November 29, 2025 22:28
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses device and dtype synchronization issues in TransformerLens that caused test failures in mixed-precision (float16/bfloat16) and multi-GPU contexts. The changes ensure consistent tensor dtypes during attention operations and proper device synchronization across model components.

Key Changes:

  • Enhanced attention operations to synchronize pattern dtype with value tensor dtype instead of config dtype
  • Improved output projection device/dtype alignment in attention forward pass
  • Extended TransformerBridge.to() to properly update configuration state when moving devices/dtypes
  • Fixed test issues including incorrect fixture references and device type parameters

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
transformer_lens/components/abstract_attention.py Updated pattern dtype synchronization to match value tensors; added device/dtype checks for output projection weights
transformer_lens/model_bridge/bridge.py Enhanced to() method to parse device/dtype arguments and update config using move_to_and_update_config utility
tests/unit/components/test_attention.py Added half-precision forward pass test for float16/bfloat16 validation
tests/acceptance/test_multi_gpu.py Fixed device parameter to use torch.device object instead of string
tests/acceptance/test_hooked_encoder.py Corrected fixture reference from mlm_tokens to tokens
tests/unit/factored_matrix/test_multiply_by_scalar.py Added test IDs to parametrize for better test identification
.vscode/settings.json Updated pytest path from "transformer_lens" to "tests" for proper test discovery
.gitignore Added .env file to ignore list for local environment configuration

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant