Clean up traces for symbolic values caching #2662
base: main
The initial draft in e7c8bc9 applied dce to subsymbols within
Pull Request Overview
This PR applies Dead Code Elimination (DCE) earlier in the process to clean up duplicated calls to prims.eq and prims.shape in subsymbols when symbolic values caching is enabled. This improves trace readability by removing redundant operations before they accumulate.
- Extended the dce function to accept either a Trace or a list of BoundSymbolInterface objects
- Applied DCE to a bsym's subsymbols immediately after execution to remove duplicates
- Refactored output handling to support both trace-level and subsymbol-level DCE
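To make the idea of running DCE directly on a list of bound symbols concrete, here is a minimal toy sketch. It does not use Thunder's actual `dce` or `BoundSymbolInterface`; `ToyBoundSymbol` and `toy_dce` are hypothetical names introduced only for illustration, and the operand names are made up.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBoundSymbol:
    """Hypothetical stand-in for a bound symbol: an op with named inputs and outputs."""
    op: str
    inputs: tuple
    outputs: tuple


def toy_dce(bsyms, needed_outputs):
    """Drop bound symbols whose outputs are never consumed.

    The list is walked backwards so every consumer is visited before its producers.
    """
    needed = set(needed_outputs)
    kept = []
    for bsym in reversed(bsyms):
        if any(out in needed for out in bsym.outputs):
            kept.append(bsym)
            needed.update(bsym.inputs)
    kept.reverse()
    return kept


# A toy list of subsymbols with a redundant shape query and an unused comparison.
subsymbols = [
    ToyBoundSymbol("prims.shape", ("a",), ("s0",)),
    ToyBoundSymbol("prims.shape", ("a",), ("s1",)),     # output never used
    ToyBoundSymbol("prims.eq", ("s0", "s0"), ("b0",)),  # output never used
    ToyBoundSymbol("prims.broadcast_in_dim", ("b", "s0"), ("t0",)),
    ToyBoundSymbol("prims.add", ("a", "t0"), ("t1",)),
]

cleaned = toy_dce(subsymbols, needed_outputs=("t1",))
cleaned_ops = [bsym.op for bsym in cleaned]
print(cleaned_ops)
```

The sketch only needs the list and the names that must stay live, which is why the same pass can in principle serve both a whole trace and a single bsym's subsymbols.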
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| thunder/core/transform_common.py | Extended dce function to handle both traces and bound symbol lists, with conditional logic for each type |
| thunder/core/symbol.py | Added DCE call to clean up subsymbols immediately after execution, removed trailing whitespace |
Comments suppressed due to low confidence (1)
thunder/core/transform_common.py:157
- This assignment assigns a variable to itself.
output = output
IvanYashchuk
left a comment
Moving DCE to symbol creation results in much tidier outputs right where most contributors and users will inspect them. This is a targeted and elegant change, thank you for addressing technical debt so precisely!
I'm not convinced. I don't have a minimal failing example yet, but I do see in more complex tests that fusions are missing required bound symbols. Things are still failing when I move the
I don't have it fully figured out yet, but I'm narrowing in on the problem. Consider

No error is surfaced when running this code. However, it does erroneously result in a split graph. When the splitter is testing the node that has the add operation, a KeyError about 't1' is thrown and caught, resulting in the split

Looking more carefully at this node, one sees that the bound symbol that is attempting to be executed is

and the graph is no longer split. But obviously this isn't good. I can't figure out how dce'ing the subsymbols results in this name collision. The only time the dce'ing changes a list of subsymbols in the above example is for ltorch.add where it starts with

@IvanYashchuk What do you make of this? Is it worth pulling more on this thread? Or can we go back to having the dce happen in
I think it's worth looking into further. What part of the codebase is generating
      # We need to be under trace context to generate proxies.
    - with thunder.core.trace.tracectx(TraceCtx()):
    + with thunder.core.trace.tracectx(tracectx):
          try:
              function_to_run(*proxy_args, **proxy_kwargs)
This was the source of all of my woes. This pattern implicitly binds the provided proxy_args to the symbols being generated, as opposed to mapping the input args to the input args an established trace expects. But the proxy_args were being created in a distinct TraceCtx, and this resulted in name collisions. This was revealed by an application of DCE that sits way, way deep down in the call stack when executing function_to_run. DCE creates a producer map, and in accessing this map, a KeyError was raised, triggering a graph split.
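The collision described here can be illustrated with a small toy model. This is only a sketch under the assumption that each trace context hands out names from its own counter; `ToyTraceCtx` and `trace_add` are hypothetical and are not Thunder's `TraceCtx` or its real naming scheme.

```python
class ToyTraceCtx:
    """Hypothetical trace context that hands out names t0, t1, ... from its own counter."""

    def __init__(self):
        self._counter = 0

    def make_name(self):
        name = f"t{self._counter}"
        self._counter += 1
        return name


def trace_add(arg_ctx, run_ctx):
    """Name two inputs in arg_ctx, then name the result of an add in run_ctx."""
    lhs, rhs = arg_ctx.make_name(), arg_ctx.make_name()
    out = run_ctx.make_name()
    return (lhs, rhs), out


# Inputs built in one context, function run in a fresh one: the counters overlap.
inputs_distinct, out_distinct = trace_add(ToyTraceCtx(), ToyTraceCtx())
collision = out_distinct in inputs_distinct

# Inputs built and function run in the same context: names stay unique.
shared = ToyTraceCtx()
inputs_shared, out_shared = trace_add(shared, shared)
no_collision = out_shared not in inputs_shared

print(inputs_distinct, out_distinct, collision)
print(inputs_shared, out_shared, no_collision)
```

In the toy model, any pass that maps names to their producers would see one name standing for two different values in the first case, which is the kind of inconsistency that can surface later as a lookup failure.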
This is a desired one! Thank you @beverlylytle
shino16
left a comment
Thank you!
@KaelanDt This is ready for your review. Thanks!
When the symbolic values caching option is enabled, there are many duplicated calls to prims.eq and prims.shape that appear in any given bsym's subsymbols. DCE is currently applied before the descent to the subsymbols happens. When the descent to the subsymbols happens, it results in a very ugly and hard-to-read trace. This PR applies DCE to the bsym's subsymbols earlier on to tidy things up.
Fixes #2728