Conversation

@beverlylytle
Collaborator

@beverlylytle beverlylytle commented Oct 16, 2025

When the symbolic values caching option is enabled, there are many duplicated calls to prims.eq and prims.shape that appear in any given bsym's subsymbols. DCE is currently applied before the descent into the subsymbols happens, so when that descent does happen, it results in a very ugly and hard-to-read trace. This PR applies DCE to the bsym's subsymbols earlier on to tidy things up.

Fixes #2728
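
For readers unfamiliar with the idea, here is a minimal sketch of dead code elimination over a flat list of subsymbols. It is a toy model, not Thunder's dce: the Op class, the toy_dce function, and the proxy names are all hypothetical and only illustrate why unused prims.shape and prims.eq calls disappear once nothing consumes their outputs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """Hypothetical stand-in for a bound symbol: output name, op name, input names."""
    out: str
    name: str
    args: tuple


def toy_dce(ops, needed):
    """Keep only the ops whose output is (transitively) needed."""
    live = set(needed)
    kept = []
    # Walk backwards so every consumer is seen before its producers.
    for op in reversed(ops):
        if op.out in live:
            kept.append(op)
            live.update(op.args)
    kept.reverse()
    return kept


# Assumed shape of a cluttered subsymbol list; the names are illustrative only.
subsymbols = [
    Op("i0", "prims.shape", ("a",)),
    Op("i1", "prims.shape", ("a",)),     # duplicate query, never consumed downstream
    Op("b0", "prims.eq", ("i0", "i1")),  # result is never used
    Op("t1", "prims.add", ("a", "b")),
]

cleaned = toy_dce(subsymbols, needed={"t1"})
print([op.name for op in cleaned])
```

Only the op that produces the needed output survives, which is the tidying effect this PR is after.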

@beverlylytle
Collaborator Author

The initial draft in e7c8bc9 applied dce to subsymbols within remove_duplicate_number_proxies, a function whose job it is to clean up after symbolic values. In cb31f7f, dce is now applied within Symbol.__call__ when the subsymbols are created. It unfortunately requires a local import of dce because of dependency hell. I rather prefer the idea of all the symbolic values tidying up happening in one method.

@beverlylytle beverlylytle changed the title WIP Clean up traces for symbolic values caching Nov 12, 2025
@IvanYashchuk IvanYashchuk requested a review from Copilot November 12, 2025 15:51
Copilot finished reviewing on behalf of IvanYashchuk November 12, 2025 15:52
Contributor

Copilot AI left a comment

Pull Request Overview

This PR applies Dead Code Elimination (DCE) earlier in the process to clean up duplicated calls to prims.eq and prims.shape in subsymbols when symbolic values caching is enabled. This improves trace readability by removing redundant operations before they accumulate.

  • Extended the dce function to accept either a Trace or a list of BoundSymbolInterface objects
  • Applied DCE to a bsym's subsymbols immediately after execution to remove duplicates
  • Refactored output handling to support both trace-level and subsymbol-level DCE

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| thunder/core/transform_common.py | Extended dce function to handle both traces and bound symbol lists, with conditional logic for each type |
| thunder/core/symbol.py | Added DCE call to clean up subsymbols immediately after execution, removed trailing whitespace |

Comments suppressed due to low confidence (1)

thunder/core/transform_common.py:157

  • This assignment assigns a variable to itself.
        output = output

Collaborator

@IvanYashchuk IvanYashchuk left a comment

Moving DCE to symbol creation results in much tidier outputs right where most contributors and users will inspect them. This is a targeted and elegant change, thank you for addressing technical debt so precisely!

@beverlylytle
Collaborator Author

beverlylytle commented Nov 12, 2025

I'm not convinced that Symbol.__call__ is the right place for this dce. The dynamo tests are failing non-trivially, and they weren't with the other version.

I don't have a minimal failing example yet, but I do see that in more complex tests, fusions are missing required bound symbols.
I think this has something to do with the fact that results is the output of a symbol's meta, and not the function itself. I am seeing things like check_len with a meta that should return None having its subsymbol prims.check_len eliminated because of that. Of course that doesn't explain why fusions are being gutted.

Things are still failing when I move the dce(subsymbols) line down to after the results are finalized.
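
The check_len observation can be reproduced in a toy model. This is a sketch under stated assumptions, not Thunder code: subsymbols are modeled as plain tuples, toy_dce is hypothetical, and the keep_always guard is just one conceivable remedy, not necessarily what this PR ends up doing. The point is that when a symbol's result is None, nothing is "needed", so a liveness-only pass drops a subsymbol that exists purely for its check.

```python
def toy_dce(ops, needed, keep_always=frozenset()):
    """ops is a list of (output_name_or_None, op_name, input_names) tuples."""
    live = set(needed)
    kept = []
    # Walk backwards so consumers are visited before producers.
    for out, name, args in reversed(ops):
        if out in live or name in keep_always:
            kept.append((out, name, args))
            live.update(args)
    return kept[::-1]


# Hypothetical subsymbols of a check_len-like symbol whose result is None.
subsymbols = [(None, "prims.check_len", ("a",))]

# The result is None, so the set of needed outputs is empty.
naive = toy_dce(subsymbols, needed=set())
guarded = toy_dce(subsymbols, needed=set(), keep_always={"prims.check_len"})

print(len(naive), len(guarded))
```

In this toy, the naive pass removes the check entirely, while the guarded pass keeps it because the op is marked as one that must never be eliminated.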

@beverlylytle
Collaborator Author

beverlylytle commented Nov 13, 2025

I don't have it fully figured out yet, but I'm narrowing in on the problem. Consider

import torch
from thunder.dynamo import thunderfx

def foo(x):
    y = torch.cos(x) + torch.sin(x)
    return y

x = torch.randn(10, 10, requires_grad=True)

jfoo = thunderfx(foo)
jfoo(x)

No error is surfaced when running this code. However, it does erroneously result in a split graph. When the splitter tests the node that contains the add operation, a KeyError about 't1' is thrown and caught, resulting in the split. Looking more carefully at this node, one sees that the bound symbol being executed is t1 = ltorch.add(t0, t1). Note the name collision between the output and the args. I can add a hack at the top of Symbol.__call__ that adds unseen arg names to the trace:

        flat_args, _ = tree_flatten((args, kwargs))
        for arg in flat_args:
            if not hasattr(arg, 'name') or trace.has_name(arg.name):
                continue
            trace.add_name(arg.name)

and the graph is no longer split. But obviously this isn't good. I can't figure out how dce'ing the subsymbols results in this name collision. The only time the dce'ing changes a list of subsymbols in the above example is for ltorch.add, where the list starts as [i0 = prims.ne(alpha, 1), t1 = prims.add(a, b)] and dce removes the prims.ne.

@IvanYashchuk What do you make of this? Is it worth pulling more on this thread? Or can we go back to having the dce happen in remove_duplicate_number_proxies?

@IvanYashchuk
Collaborator

I think it's worth looking into further. What part of the codebase is generating t1 = ltorch.add(t0, t1)? It shouldn't happen. For normal non-view non-inplace operations, the output name should be new. This might be surfacing a hidden bug that is worth fixing.

Comment on lines -378 to 381
  # We need to be under trace context to generate proxies.
- with thunder.core.trace.tracectx(TraceCtx()):
+ with thunder.core.trace.tracectx(tracectx):
      try:
          function_to_run(*proxy_args, **proxy_kwargs)
Collaborator Author

This was the source of all of my woes. This pattern implicitly binds the provided proxy_args to the symbols being generated, as opposed to mapping the input args to the input args an established trace expects. But the proxy_args were being created in a distinct TraceCtx, and this resulted in name collisions. This was revealed by an application of DCE that exists way, way deep down in the call stack when executing function_to_run. DCE creates a producer map, and in accessing this map, a KeyError was raised, triggering a graph split.
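
The collision mechanism can be sketched with a toy name registry. This is an illustration only: ToyTrace and make_name are hypothetical and do not mirror TraceCtx's real API. It assumes only what is described above, namely that each context hands out names it has not seen before, so a fresh context knows nothing about proxies named elsewhere.

```python
class ToyTrace:
    """Hypothetical stand-in for a trace context that hands out unique proxy names."""

    def __init__(self):
        self.names = set()
        self.counter = 0

    def make_name(self, prefix="t"):
        # Return the first name this context has not handed out yet.
        while True:
            name = f"{prefix}{self.counter}"
            self.counter += 1
            if name not in self.names:
                self.names.add(name)
                return name


# The input proxies are named in one context...
input_ctx = ToyTrace()
args = [input_ctx.make_name(), input_ctx.make_name()]

# ...but the op runs under a fresh context that has never seen those names,
# so the output name can duplicate one of the argument names.
fresh_ctx = ToyTrace()
out_fresh = fresh_ctx.make_name()

# Running under the context that created the inputs avoids the clash.
out_shared = input_ctx.make_name()

print(args, out_fresh, out_shared)
```

Reusing the context that created the proxies, as the diff above does, means the registry already contains the argument names and the output gets a new one.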

@beverlylytle beverlylytle marked this pull request as ready for review November 14, 2025 14:43
Collaborator

@shino16 shino16 left a comment

This is a desired one! Thank you @beverlylytle

Collaborator

@shino16 shino16 left a comment

Thank you!

@beverlylytle
Collaborator Author

@KaelanDt This is ready for your review. Thanks!

@IvanYashchuk IvanYashchuk enabled auto-merge (squash) December 2, 2025 19:03
Development

Successfully merging this pull request may close these issues.

Simplify trace representation with shape symbolic values enabled

3 participants