[MoE][compile][full ac] weave torch.compile around the FSDP(GroupedExperts) graph break #1895
Stacked PRs:
This PR changes how we compile MoE layers to work around Compile + AC limitations. When you `AC(Compile(block))` or `Compile(AC(block))` and there is a graph break in `block`, the entire block falls back to eager. For llama3, we worked around this problem by eliminating all graph breaks. With MoE models, particularly dp2ep, we need to wrap `FSDP(block.moe.experts)`, which means we hit graph breaks when tracing `block.moe.experts.__call__`, and therefore whenever AC was enabled the entire MoE block fell back to eager: https://gist.github.com/xmfan/50f4de1e89d789cd63a21aca9e600132 (note in the tlparse that graph 0/1 is empty; it corresponds to the block containing the MoE).

The workaround in this PR is to avoid tracing `block.moe.experts.__call__`. This is done by individually wrapping torch.compile around the submodules of TransformerBlock. Note that we are leaving some perf on the table, as this may exclude some ops in TransformerBlock.forward and MoE.forward. This is an API limitation: we have no way to capture those ops while keeping the wrapper decoupled from the model code. This workaround will no longer be necessary when either:

This change introduces a small regression in the non-AC configuration. You can see a small perf dip between before this PR and after this PR. Given that AC is a necessity for running non-toy configurations of these models, I chose to stick with this implementation to make comparisons easier.
Validated on DSv3 debug model: