Commit b9efec7

[MoE][compile][full ac] weave torch.compile around the FSDP(GroupedExperts) graph break
1 parent 92ed8b3 commit b9efec7

File tree

  torchtitan/models/deepseek_v3/infra/parallelize.py
  torchtitan/models/llama4/infra/parallelize.py
  torchtitan/models/qwen3/infra/parallelize.py

3 files changed: +57 -14 lines
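
For context, a minimal self-contained sketch of the strategy named in the commit title (toy module names, not torchtitan's): instead of compiling a whole MoE TransformerBlock as one region, which graph-breaks on the FSDP hooks installed around the experts, each child submodule is compiled on its own with fullgraph=True, so the hooks run in eager between compiled regions.

# Sketch only: ToyBlock and its children are stand-ins, not torchtitan modules.
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attention = nn.Linear(16, 16)  # stand-in for the attention submodule
        self.moe = nn.Linear(16, 16)        # stand-in for the MoE (imagine FSDP hooks here)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.moe(self.attention(x))


block = ToyBlock()
# Per-child compile: each child becomes its own fullgraph region, so anything that
# runs between children (e.g. FSDP pre/post-forward hooks) stays outside the graphs.
for name, child in block.named_children():
    setattr(block, name, torch.compile(child, fullgraph=True))

out = block(torch.randn(2, 16))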

torchtitan/models/deepseek_v3/infra/parallelize.py
Lines changed: 1 addition & 1 deletion

@@ -118,7 +118,7 @@ def parallelize_deepseekv3(
         )
 
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:

torchtitan/models/llama4/infra/parallelize.py
Lines changed: 55 additions & 12 deletions

@@ -6,6 +6,9 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    CheckpointWrapper,
+)
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Partial, Replicate, Shard
@@ -18,7 +21,10 @@
     SequenceParallel,
 )
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.config.job_config import Compile as CompileConfig
+from torchtitan.config.job_config import (
+    ActivationCheckpoint as ACConfig,
+    Compile as CompileConfig,
+)
 from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 
@@ -30,6 +36,7 @@
 )
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.moe.moe import MoE
 from torchtitan.tools.logging import logger
 
 
@@ -125,7 +132,7 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
@@ -503,24 +510,60 @@ def apply_moe_ep_tp(
         )
 
 
-def apply_compile(model: nn.Module, compile_config: CompileConfig):
+def apply_compile(model: nn.Module, compile_config: CompileConfig, ac_config: ACConfig):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
-    # torch._dynamo.config.capture_scalar_outputs = True
+    torch._dynamo.config.capture_scalar_outputs = True
     for layer_id, transformer_block in model.layers.named_children():
-        # TODO: remove when torch.compile supports fullgraph=True for MoE
-        fullgraph = True
         if transformer_block.moe_enabled:
-            fullgraph = False
-        transformer_block = torch.compile(
-            transformer_block,
-            backend=compile_config.backend,
-            fullgraph=fullgraph,
-        )
+            # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break
+            # So we must weave compile wrappers around those FSDP hooks to
+            # prevent AC from falling back the whole graph to eager.
+            assert (
+                ac_config.mode != "selective"
+            ), "Selective Activation Checkpointing + Compile is not yet supported for MoE models."
+
+            if isinstance(transformer_block, CheckpointWrapper):
+                # unwrap so that .named_children() works
+                block = transformer_block._checkpoint_wrapped_module
+            else:
+                block = transformer_block
+
+            for attr_name, submod in block.named_children():
+                if isinstance(submod, MoE):
+                    # avoid graph breaking on the GroupedExperts' FSDP hooks
+                    # by wrapping each submod's forward instead of their __call__
+                    moe_key = attr_name
+                    moe = submod
+                    for attr_name, submod in moe.named_children():
+                        setattr(
+                            moe,
+                            attr_name,
+                            torch.compile(
+                                submod, backend=compile_config.backend, fullgraph=True
+                            ),
+                        )
+                else:
+                    setattr(
+                        transformer_block,
+                        attr_name,
+                        torch.compile(
+                            submod, backend=compile_config.backend, fullgraph=True
+                        ),
+                    )
+        else:
+            # If it's not a MoE layer, there is no FSDP(GroupedExperts)
+            # So we can compile the whole block
+            transformer_block = torch.compile(
+                transformer_block,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
+
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info("Compiling each TransformerBlock with torch.compile")
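
Two mechanics the new apply_compile relies on can be illustrated outside torchtitan; the snippet below is a hedged sketch with made-up names, not the project's code. First, enabling torch._dynamo.config.capture_scalar_outputs lets data-dependent Python scalars (as produced by token-choice routing via .item()) be captured as unbacked symbols rather than graph-breaking, which fullgraph=True would otherwise reject. Second, under full activation checkpointing each TransformerBlock is a CheckpointWrapper, so the real children live on _checkpoint_wrapped_module, which is why the function unwraps before iterating named_children().

# Sketch only: function and variable names below are illustrative assumptions.
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointWrapper,
)

# 1) Data-dependent scalars are captured instead of graph-breaking.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile(fullgraph=True)
def scale_by_capacity(scores: torch.Tensor, capacity: torch.Tensor) -> torch.Tensor:
    return scores * capacity.item()  # .item() would graph-break without the flag


scale_by_capacity(torch.randn(4, 8), torch.tensor(2))

# 2) The wrapper's own children are not the block's children; the wrapped module is.
wrapped = checkpoint_wrapper(nn.Sequential(nn.Linear(8, 8), nn.ReLU()))
assert isinstance(wrapped, CheckpointWrapper)
inner = wrapped._checkpoint_wrapped_module
print([name for name, _ in inner.named_children()])  # ['0', '1']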

torchtitan/models/qwen3/infra/parallelize.py
Lines changed: 1 addition & 1 deletion

@@ -119,7 +119,7 @@ def parallelize_qwen3(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     if parallel_dims.fsdp_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel
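
All three call sites keep the ordering their comment states: compile is applied after AC wrapping and before FSDP. A simplified sketch of that layering on a generic block follows (not torchtitan's helpers; fully_shard is left commented because it needs an initialized process group and device mesh).

# Sketch only: generic module, not a torchtitan TransformerBlock.
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

block = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
block = checkpoint_wrapper(block)   # 1) activation checkpointing wraps the block first
block = torch.compile(block)        # 2) torch.compile wraps the AC-wrapped block
# from torch.distributed.fsdp import fully_shard
# fully_shard(block)                # 3) FSDP is applied last, on the outside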
