Skip to content

Conversation

@speediedan
Copy link

Resolves #1141

Firstly, thank you so much for building and maintaining TransformerLens - it's a seminal, foundational and invaluable component enabling increasingly vital open-source interpretability research!

This PR fixes an issue where TransformerBridge.parameters() returns non-leaf tensors that cannot be optimized, while preserving TransformerLens-style parameter access for analysis tools.

Note

This PR builds on the device-dtype-sync-fixes branch to ensure all tests pass cleanly before adding these optimizer compatibility enhancements.

The Issue

TransformerBridge's parameters() method returned non-leaf tensors (created by einops.rearrange()), breaking PyTorch's fundamental requirement that optimizer parameters be leaf tensors. This prevented users from fine-tuning or training TransformerBridge models:

bridge = TransformerBridge.boot_transformers("gpt2")
optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-4)
# ValueError: can't optimize a non-leaf Tensor

The Solution

Dual Parameter API - Adhere to standard PyTorch parameter()/named_parameters() semantics while providing explicit TransformerLens-style parameter access methods:

1. PyTorch-style API - Follows standard PyTorch semantics while delegating to underlying HuggingFace model:

def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
    """Returns parameters following standard PyTorch semantics.
    
    Delegates to the underlying HuggingFace model's parameters().
    For TransformerLens-style parameters, use tl_parameters() instead.
    """
    return self.original_model.parameters(recurse=recurse)

def named_parameters(self, prefix: str = "", recurse: bool = True,
                     remove_duplicate: bool = True) -> Iterator[tuple[str, nn.Parameter]]:
    """Returns named parameters following standard PyTorch semantics.
    
    Delegates to the underlying HuggingFace model's named_parameters().
    For TransformerLens-style names, use tl_named_parameters() instead.
    """
    return self.original_model.named_parameters(prefix, recurse, remove_duplicate)

2. TransformerLens-style API - preserves TL conventions:

def tl_parameters(self) -> dict[str, torch.Tensor]:
    """Returns TransformerLens-style parameter dictionary.
    
    Parameter names follow TL conventions (e.g., 'blocks.0.attn.W_Q').
    May include non-leaf tensors created by einops.rearrange().
    Expected by SVDInterpreter and other analysis tools.
    """
    return self.get_params()

def tl_named_parameters(self) -> Iterator[tuple[str, torch.Tensor]]:
    """Returns iterator of TransformerLens-style named parameters.
    
    Provides same content as tl_parameters() but as an iterator
    for consistency with PyTorch's named_parameters() API pattern.
    """
    return iter(self.get_params().items())

3. Update SVDInterpreter - Use tl_parameters() for TransformerBridge:

class SVDInterpreter:
    def __init__(self, model: Any):
        self.model = model
        self.cfg = model.cfg
        # Use tl_parameters() for TransformerBridge (returns TL-style dict)
        # Fall back to named_parameters() for HookedTransformer
        if hasattr(model, "tl_parameters"):
            self.params = model.tl_parameters()
        else:
            self.params = {name: param for name, param in model.named_parameters()}

Key Design Decisions

Rather than using the dual API approach, we considered several alternatives including monolithic APIs with flags or instance configuration modes. However, these had significant drawbacks, among them:

  • ❌ variable return types changing based on parameter
  • ❌ Easy to use wrong style by accident
  • ❌ Global state making it less flexible (in instance config case)
  • ❌ Loses parameter names in TL style (in instance config case)
  • ❌ Can't mix styles on same instance (in instance config case)

Test Coverage

New test files:

  • tests/unit/model_bridge/test_optimizer_compatibility.py - Unit tests for parameter API correctness
  • tests/integration/model_bridge/test_optimizer_compatibility.py - Integration tests for end-to-end workflows

Test coverage includes:

1. Basic Optimizer Compatibility

def test_adamw_accepts_parameters(small_bridge_model):
    """Test that AdamW optimizer accepts TransformerBridge parameters."""
    optimizer = torch.optim.AdamW(small_bridge_model.parameters(), lr=1e-4)
    assert optimizer is not None

2. Leaf Tensor Validation

def test_all_parameters_are_leaf_tensors(small_bridge_model):
    """Verify all parameters returned by parameters() are leaf tensors."""
    for i, param in enumerate(small_bridge_model.parameters()):
        assert param.is_leaf, f"Parameter {i} is non-leaf"
        assert isinstance(param, nn.Parameter)

3. Gradient Flow

def test_gradient_flow_after_backward(small_bridge_model):
    """Test that gradients flow correctly after backward pass."""
    logits = small_bridge_model(input_ids, return_type="logits")
    loss = logits.sum()
    loss.backward()
    
    params_with_grad = [p for p in small_bridge_model.parameters() 
                       if p.grad is not None]
    assert len(params_with_grad) > 0

4. Parameter Updates

def test_optimizer_step_updates_parameters(small_bridge_model):
    """Test that optimizer.step() actually updates model parameters."""
    optimizer = torch.optim.SGD(small_bridge_model.parameters(), lr=0.1)
    param_before = list(small_bridge_model.parameters())[0].clone()
    
    # Forward, backward, step
    loss.backward()
    optimizer.step()
    
    param_after = list(small_bridge_model.parameters())[0]
    assert not torch.allclose(param_before, param_after)

5. TL Parameter API

def test_tl_parameters_provides_tl_style_names(small_bridge_model):
    """Verify tl_parameters() provides TransformerLens-style names."""
    tl_params = small_bridge_model.tl_parameters()
    assert any("blocks." in name and ".attn." in name 
               for name in tl_params.keys())
    assert any(name.endswith(".W_E") for name in tl_params.keys())

6. Multi-Step Optimization Parity (integration test)

def test_bridge_hooked_parity_multi_step_optimization():
    """Test parity between Bridge and HookedTransformer across 
    multiple optimization steps (1, 10).
    
    Validates both architectures maintain comparable results over
    multiple steps, checking:
    - Initial forward pass: logits and loss alignment
    - Post-update forward pass: logits and loss remain close
    - Parameter updates: unembed weights remain close
    """
    # Tests at step 1 and step 10 with defined thresholds
    # Ensures Bridge maintains parity with HookedTransformer during training

7. Compatibility Mode

def test_parameters_still_leaf_after_compatibility_mode(small_bridge_model):
    """Verify parameters() returns leaf tensors even after 
    enabling compatibility mode."""
    small_bridge_model.enable_compatibility_mode(no_processing=True)
    for param in small_bridge_model.parameters():
        assert param.is_leaf

Test Results

All 255 unit tests and 334 integration tests pass, including the complete optimizer compatibility suite.

Unit tests validate:

  • All parameters are leaf tensors
  • Optimizers accept parameters without errors
  • Gradients flow correctly
  • Parameters update during optimization
  • TL-style parameters provide correct naming conventions
  • Named parameter iterators match dictionary methods
  • Compatibility mode doesn't break optimizer functionality

Integration tests validate:

  • End-to-end training workflows work correctly
  • Multi-step optimization maintains parity with HookedTransformer
  • Forward pass remains consistent after parameter updates
  • Loss convergence behaves as expected
  • Batch dimensions preserved through optimization steps

Benefits

✅ Optimizer Compatibility: Standard PyTorch optimizers work out of the box

✅ Analysis Tool Support: SVDInterpreter and other TL tools continue working unchanged

✅ Clear API Separation: Explicit methods for training (parameters()) vs analysis (tl_parameters())

✅ Zero Breaking Changes: Existing code using bridge.original_model.parameters() continues working

✅ Discoverable API: Clear naming convention (tl_*) signals TransformerLens-specific functionality

Backward Compatibility

For optimization workflows:

  • parameters() now returns optimizable leaf tensors (was broken, now fixed)
  • named_parameters() returns HF-style names (consistent with underlying model)

For analysis workflows:

  • tl_parameters() provides TL-style parameter dict (same as old get_params())
  • tl_named_parameters() provides iterator version (new convenience method)
  • SVDInterpreter updated to use tl_parameters() (transparent to users)

No breaking changes for users who:

  • Never tried to optimize (they weren't affected by the bug)
  • Used bridge.original_model.parameters() as a workaround (still works)
  • Only used TransformerBridge for inference (no change in behavior)

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Summary

This PR fixes a bug preventing TransformerBridge optimization while preserving full TransformerLens analysis functionality. By introducing a clear dual API—parameters() for standard PyTorch parameter semantics and tl_parameters() for analysis—we allow both standard PyTorch workflows and TransformerLens-specific tooling to coexist seamlessly.

Thank you so much to the TransformerLens maintainers/community for making this foundationally valuable contribution to the open-source interpretability ecosystem!

I'm using it extensively in a downstream analysis framework I'm building and couldn't appreciate your work more! I should also thank Claude 1 for its help (somewhat obviously) creating the scaffolding of this PR and for generating/refining the unit tests.

Footnotes

  1. Transitively thanking the authors of TransformerLens again for my ability to thank Claude, they really deserve the attribution :)

…xts.

We also update .gitignore to exclude .env (commonly used local file exclution), e.g. to allow collaborators to add their on HF_TOKEN for test suite

Core Fixes:
-----------

transformer_lens/components/abstract_attention.py:
  - Replace pattern.to(self.cfg.dtype) with pattern.to(v.dtype) to handle cases
    where tensors are upcast to float32 for numerical stability while cfg.dtype
    remains float16/bfloat16
  - Add explicit device/dtype synchronization for output projection:
    * Move weights (W_O) and bias (b_O) to match input device (z.device)
    * Ensure z matches weight dtype before final linear operation

transformer_lens/model_bridge/bridge.py:
  - Replace direct original_model.to() call with move_to_and_update_config()
    utility to ensure:
    * All bridge components (not just original_model) are moved to target device
    * cfg.device and cfg.dtype stay synchronized with actual model state
    * Multi-GPU cache tensors remain on correct devices

Test Fixes:
-----------

tests/acceptance/test_hooked_encoder.py:
  - Fix test_cuda() to use correct fixture name 'tokens' instead of 'mlm_tokens'

tests/acceptance/test_multi_gpu.py:
  - Update test_cache_device() to pass torch.device("cpu") instead of string
    "cpu" for proper device type validation

tests/unit/components/test_attention.py:
  - Add test_attention_forward_half_precisions() to validate attention works
    correctly with bfloat16/float16 dtypes on CUDA devices

tests/unit/factored_matrix/test_multiply_by_scalar.py:
  - Add test IDs to parametrize decorators to avoid pytest cache issues when
    random numbers appear in test names

Tests Fixed by This Commit:
---------------------------
- tests/acceptance/test_multi_gpu.py::test_cache_device
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_memory_efficiency[gpt2]
- tests/acceptance/model_bridge/compatibility/test_legacy_hooked_transformer_coverage.py::TestLegacyHookedTransformerCoverage::test_consistent_outputs[gpt2]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype0]
- tests/acceptance/test_hooked_transformer.py::test_half_precision[dtype1]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype0]
- tests/unit/components/test_attention.py::test_attention_forward_half_precisions[dtype1]
- tests/unit/model_bridge/compatibility/test_utils.py::TestUtilsWithTransformerBridge::test_device_compatibility[gpt2]
…rameter/named_parameter generator API

Add parameters()/named_parameters() returning leaf tensors for PyTorch optimizers
and tl_parameters()/tl_named_parameters() for TransformerLens-style format.
Update SVDInterpreter to use tl_parameters(). Add comprehensive tests.
Enhance to() method to properly handle both device and dtype arguments in
all supported PyTorch formats (positional, keyword, combined). Separately
invoke move_to_and_update_config for device/dtype to update cfg while
delegating the actual tensor movement to original_model.to() with original
args/kwargs. This ensures TransformerBridge respects standard PyTorch
behavior for model.to() calls.
Add comprehensive test validating optimizer parity between TransformerBridge
and HookedTransformer across optimization steps (1, 10). Test checks:
- Initial forward pass alignment (logits, loss)
- Parameter updates (unembed weights remain close)
- Post-update forward pass convergence

Uses NamedTuple for threshold definitions with magnitude-based validation
thresholds (1e-3, 1e-5, etc.) for cleaner readability.
@speediedan speediedan marked this pull request as ready for review November 29, 2025 22:30
Copilot AI review requested due to automatic review settings November 29, 2025 22:30
Copilot finished reviewing on behalf of speediedan November 29, 2025 22:32
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR resolves a critical bug where TransformerBridge.parameters() returned non-leaf tensors that couldn't be optimized by PyTorch optimizers. The solution introduces a dual API approach: standard PyTorch methods (parameters(), named_parameters()) delegate to the underlying HuggingFace model, while new TransformerLens-specific methods (tl_parameters(), tl_named_parameters()) provide processed parameters for analysis tools.

Key Changes:

  • Introduced dual parameter access API to support both optimization and analysis workflows
  • Updated SVDInterpreter to use tl_parameters() for TransformerBridge compatibility
  • Enhanced device/dtype handling in TransformerBridge.to() method and attention components

Reviewed changes

Copilot reviewed 10 out of 12 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
transformer_lens/model_bridge/bridge.py Added parameters(), named_parameters(), tl_parameters(), and tl_named_parameters() methods; enhanced to() method for improved device/dtype handling
transformer_lens/SVDInterpreter.py Updated to use tl_parameters() for TransformerBridge models while maintaining backward compatibility with HookedTransformer
transformer_lens/components/abstract_attention.py Improved device and dtype consistency for attention operations by consolidating device/dtype conversions
tests/unit/model_bridge/test_optimizer_compatibility.py Comprehensive unit tests validating leaf tensor guarantees, optimizer compatibility, and dual API correctness
tests/integration/model_bridge/test_optimizer_compatibility.py Integration tests validating end-to-end optimization workflows and multi-step training parity
tests/unit/factored_matrix/test_multiply_by_scalar.py Added test IDs for better test identification
tests/unit/components/test_attention.py Added half-precision (bfloat16/float16) attention tests for CUDA
tests/acceptance/test_multi_gpu.py Fixed device specification to use torch.device object
tests/acceptance/test_hooked_encoder.py Fixed test fixture reference from mlm_tokens to tokens
.vscode/settings.json Updated pytest args to test tests directory instead of transformer_lens
.gitignore Added .env to ignore list

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@bryce13950
Copy link
Collaborator

Thanks so much for this! I'll review your prs relatively quickly, and likely get it all folded into the coming beta release

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants