9 changes: 5 additions & 4 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -510,16 +510,17 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
"""
# NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
# but it is experimental.
# torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_scalar_outputs = True
for layer_id, transformer_block in model.layers.named_children():
# TODO: remove when torch.compile supports fullgraph=True for MoE
fullgraph = True
if transformer_block.moe_enabled:
fullgraph = False
transformer_block.moe.experts._forward = torch.compile(transformer_block.moe.experts._forward, backend=compile_config.backend, fullgraph=True)
transformer_block.moe.experts.forward = torch.compiler.disable(transformer_block.moe.experts.forward)

transformer_block = torch.compile(
transformer_block,
backend=compile_config.backend,
fullgraph=fullgraph,
fullgraph=not transformer_block.moe_enabled,
)
model.layers.register_module(layer_id, transformer_block)

7 changes: 7 additions & 0 deletions torchtitan/models/moe.py
@@ -143,6 +143,13 @@ def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor,
     ) -> torch.Tensor:
+        return self._forward(x, num_tokens_per_expert)
+
+    def _forward(
+        self,
+        x: torch.Tensor,
+        num_tokens_per_expert: torch.Tensor,
+    ) -> torch.Tensor:
         if isinstance(self.w1, DTensor):
             # Convert parameters from DTensors to plain Tensors, to work with
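For context: the pattern applied in this diff is to compile the experts' inner `_forward` on its own (where `fullgraph=True` is feasible), hide the experts' public `forward` from the surrounding compile with `torch.compiler.disable`, and compile the enclosing transformer block with `fullgraph=False` only when MoE is enabled. Below is a minimal standalone sketch of that pattern; the `ToyExperts`/`ToyBlock` modules are hypothetical placeholders, not torchtitan code.

```python
# Hypothetical minimal sketch of the compile pattern in this PR (assumed toy modules).
import torch
import torch.nn as nn


class ToyExperts(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.w = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Thin wrapper, mirroring the forward/_forward split added in moe.py.
        return self._forward(x)

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.w


class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.experts = ToyExperts(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.experts(x)


block = ToyBlock(16)
# Compile the experts' inner method on its own, where fullgraph=True is feasible.
block.experts._forward = torch.compile(block.experts._forward, fullgraph=True)
# Keep the experts' public forward out of the outer graph.
block.experts.forward = torch.compiler.disable(block.experts.forward)
# Compile the surrounding block; it graph-breaks at the experts call and resumes after.
block = torch.compile(block, fullgraph=False)
out = block(torch.randn(4, 16))
```

Because the instance attributes `experts._forward` and `experts.forward` shadow the class methods, `nn.Module.__call__` picks up the disabled wrapper: the outer compile breaks cleanly at the experts call while the experts' own compiled `_forward` still runs with `fullgraph=True`.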