From 9a7f5f17dd52b896442629a603ab7b12376a3a4f Mon Sep 17 00:00:00 2001
From: Simon Fan
Date: Mon, 13 Oct 2025 13:35:15 -0700
Subject: [PATCH] [moe][compile] turn on capture_scalar_outputs and assert
 fullgraph on grouped experts

stack-info: PR: https://github.com/pytorch/torchtitan/pull/1861, branch: xmfan/stack/1
---
 torchtitan/experiments/llama4/infra/parallelize.py | 9 +++++----
 torchtitan/models/moe.py                           | 7 +++++++
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index 1f579ccd04..cc28484aec 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -510,16 +510,17 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
     """
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
-    # torch._dynamo.config.capture_scalar_outputs = True
+    torch._dynamo.config.capture_scalar_outputs = True
     for layer_id, transformer_block in model.layers.named_children():
         # TODO: remove when torch.compile supports fullgraph=True for MoE
-        fullgraph = True
         if transformer_block.moe_enabled:
-            fullgraph = False
+            transformer_block.moe.experts._forward = torch.compile(transformer_block.moe.experts._forward, backend=compile_config.backend, fullgraph=True)
+            transformer_block.moe.experts.forward = torch.compiler.disable(transformer_block.moe.experts.forward)
+
         transformer_block = torch.compile(
             transformer_block,
             backend=compile_config.backend,
-            fullgraph=fullgraph,
+            fullgraph=not transformer_block.moe_enabled,
         )
         model.layers.register_module(layer_id, transformer_block)
 
diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py
index 67fe626d34..9e43ea82c0 100644
--- a/torchtitan/models/moe.py
+++ b/torchtitan/models/moe.py
@@ -143,6 +143,13 @@ def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor,
+    ) -> torch.Tensor:
+        return self._forward(x, num_tokens_per_expert)
+
+    def _forward(
+        self,
+        x: torch.Tensor,
+        num_tokens_per_expert: torch.Tensor,
     ) -> torch.Tensor:
         if isinstance(self.w1, DTensor):
             # Convert parameters from DTensors to plain Tensors, to work with
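
Below is a minimal, self-contained sketch (not part of the patch; the names TinyExperts and Block are made up for illustration) of the compile-inner / disable-outer pattern the diff applies to GroupedExperts: the public forward() is kept out of any surrounding compiled region via torch.compiler.disable, while the new _forward() is compiled on its own with fullgraph=True.

import torch
import torch.nn as nn

# The patch enables this flag so data-dependent scalar reads in token-choice
# MoE (e.g. .item() on tokens-per-expert counts) don't graph-break; this toy
# example doesn't strictly need it.
torch._dynamo.config.capture_scalar_outputs = True


class TinyExperts(nn.Module):
    # Stand-in for GroupedExperts; illustrative only.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.w = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Thin wrapper mirroring the forward -> _forward split in moe.py.
        return self._forward(x)

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.w


class Block(nn.Module):
    # Stand-in for a transformer block with moe_enabled=True.
    def __init__(self):
        super().__init__()
        self.experts = TinyExperts()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.experts(x)


block = Block()
# Compile the inner method as its own fullgraph region...
block.experts._forward = torch.compile(block.experts._forward, fullgraph=True)
# ...and keep the outer wrapper out of the enclosing compiled graph.
block.experts.forward = torch.compiler.disable(block.experts.forward)
# The enclosing block can't be fullgraph'd, since the disabled wrapper forces a
# graph break; this is why the patch sets fullgraph=not transformer_block.moe_enabled.
block = torch.compile(block, fullgraph=False)

out = block(torch.randn(4, 8))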