diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 8d13a3f31f..23a7c47999 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -118,7 +118,7 @@ def parallelize_deepseekv3(
         )
 
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 1f579ccd04..916104cb17 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -6,6 +6,9 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    CheckpointWrapper,
+)
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Partial, Replicate, Shard
@@ -30,6 +33,7 @@
 )
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.moe import moe as moe_module
 from torchtitan.tools.logging import logger
 
 
@@ -510,17 +514,65 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
     """
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
-    # torch._dynamo.config.capture_scalar_outputs = True
+    torch._dynamo.config.capture_scalar_outputs = True
     for layer_id, transformer_block in model.layers.named_children():
-        # TODO: remove when torch.compile supports fullgraph=True for MoE
-        fullgraph = True
         if transformer_block.moe_enabled:
-            fullgraph = False
-        transformer_block = torch.compile(
-            transformer_block,
-            backend=compile_config.backend,
-            fullgraph=fullgraph,
-        )
-        model.layers.register_module(layer_id, transformer_block)
+            # If this is a MoE layer, FSDP(GroupedExperts) would cause a graph break,
+            # so we weave compile wrappers around those FSDP hooks to prevent AC
+            # from falling back to eager for the whole graph.
+
+            if isinstance(transformer_block, CheckpointWrapper):
+                # unwrap so that .named_children() works
+                block = transformer_block._checkpoint_wrapped_module
+            else:
+                block = transformer_block
+
+            patch_experts = False
+            for attr_name, submod in block.named_children():
+                assert getattr(block, attr_name) == getattr(
+                    transformer_block, attr_name
+                )
+
+                if isinstance(submod, moe_module.MoE):
+                    # avoid graph breaking on the GroupedExperts' FSDP hooks
+                    # by wrapping each submodule's forward instead of its __call__
+                    moe = submod
+                    for moe_attr_name, moe_submod in moe.named_children():
+                        if moe_attr_name == "experts":
+                            # don't compile token dispatch and combine due to an issue on B200
+                            patch_experts = True
+                            continue
+                        setattr(
+                            moe,
+                            moe_attr_name,
+                            torch.compile(
+                                moe_submod, backend=compile_config.backend, fullgraph=True
+                            ),
+                        )
+                else:
+                    setattr(
+                        block,
+                        attr_name,
+                        torch.compile(
+                            submod, backend=compile_config.backend, fullgraph=True
+                        ),
+                    )
+
+            if patch_experts:
+                moe_module._run_experts_grouped_mm = torch.compile(
+                    moe_module._run_experts_grouped_mm,
+                    backend=compile_config.backend,
+                    fullgraph=True,
+                )
+        else:
+            # If it's not a MoE layer, there is no FSDP(GroupedExperts),
+            # so we can compile the whole block.
+            transformer_block = torch.compile(
+                transformer_block,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
+            model.layers.register_module(layer_id, transformer_block)
 
     logger.info("Compiling each TransformerBlock with torch.compile")
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 5fa8549e9f..a6c5c1424d 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -119,7 +119,7 @@ def parallelize_qwen3(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     if parallel_dims.fsdp_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel
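
Not part of the patch: a minimal, self-contained sketch of the per-submodule compile technique the llama4 hunk applies. The module names here (ToyBlock, attn, mlp) and the plain Linear layers are illustrative stand-ins, not torchtitan's attention/MoE modules. The sketch unwraps an activation-checkpointing CheckpointWrapper, compiles each child with fullgraph=True, and re-attaches it with setattr so AC still drives the now-compiled children.

# Illustrative sketch only; ToyBlock/attn/mlp are hypothetical names, not torchtitan code.
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointWrapper,
)


class ToyBlock(nn.Module):
    """Stand-in for a TransformerBlock; attn/mlp are placeholders for real submodules."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(torch.relu(self.attn(x)))


def compile_children(wrapped: nn.Module, backend: str = "inductor") -> None:
    # Unwrap activation checkpointing so named_children() sees the real submodules.
    block = (
        wrapped._checkpoint_wrapped_module
        if isinstance(wrapped, CheckpointWrapper)
        else wrapped
    )
    for name, child in block.named_children():
        # Compile each child separately so a graph break in one submodule
        # (e.g. at FSDP hooks) does not push the whole block back to eager.
        setattr(block, name, torch.compile(child, backend=backend, fullgraph=True))


if __name__ == "__main__":
    block = checkpoint_wrapper(ToyBlock())
    compile_children(block)
    x = torch.randn(4, 16, requires_grad=True)
    block(x).sum().backward()  # AC recompute runs through the compiled children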