Fix gradient mismatch in checkpointing by preserving node metadata #2553
Description
Fixes gradient computation errors in `test_checkpointing_thunderfx` by preserving PyTorch FX node metadata during checkpoint function conversion.
Problem
The test `test_networks.py::test_checkpointing_thunderfx` was failing with significant gradient mismatches between eager PyTorch and thunderfx, particularly on B200/GB200 GPUs:
Root Cause
In `thunder/dynamo/utils.py`, the `_checkpoint_function_converter` function converts PyTorch operators to Thunder operators inside checkpointed functions. However, it was not copying the FX node metadata (`node.meta`) from the original nodes to the new Thunder nodes.
This metadata contains critical information for gradient computation:
- `example_value`: Tensor with shape, dtype, and device information
- `requires_grad`: Whether gradients should be computed
- `tensor_meta`: Additional tensor properties (stride, layout, etc.)
Without this metadata, the Thunder operators lacked the information needed to participate properly in the autograd graph during the backward pass of activation checkpointing, resulting in incorrect gradient computation.
Solution
Added a single line to copy metadata when creating Thunder nodes:
This ensures all tensor properties are preserved, allowing the Thunder operators to correctly compute gradients during the recomputation phase of activation checkpointing.
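The general pattern can be sketched with plain `torch.fx`, outside of Thunder. This is an illustration under stated assumptions, not the actual change to `_checkpoint_function_converter`: the helper name `replace_target_preserving_meta` is hypothetical, `operator.add` stands in for a Thunder operator, and `example_value` is filled in by hand because `symbolic_trace` does not populate it the way Dynamo does.

```python
import operator

import torch
import torch.fx


def add_one(x):
    return torch.add(x, 1.0)


def replace_target_preserving_meta(gm, old_target, new_target):
    """Hypothetical helper: swap calls to old_target for new_target, keeping node.meta."""
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == old_target:
            with gm.graph.inserting_before(node):
                new_node = gm.graph.call_function(new_target, node.args, node.kwargs)
            # Copy the metadata so the replacement node keeps the tensor
            # information recorded on the original node.
            new_node.meta = node.meta.copy()
            node.replace_all_uses_with(new_node)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


gm = torch.fx.symbolic_trace(add_one)
for node in gm.graph.nodes:
    if node.op == "call_function":
        # Assumption: hand-filled stand-in for the metadata Dynamo would record.
        node.meta["example_value"] = torch.zeros(2)

gm = replace_target_preserving_meta(gm, torch.add, operator.add)
new_nodes = [n for n in gm.graph.nodes if n.op == "call_function"]
```

Without the `new_node.meta = node.meta.copy()` line, the replacement node would start with an empty `meta` dictionary, which mirrors the situation described under Root Cause.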
Impact
Fixes #[issue_number_if_available]