2 changes: 1 addition & 1 deletion torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -118,7 +118,7 @@ def parallelize_deepseekv3(
     )
 
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
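
The same one-line call-site change recurs in the llama4 and qwen3 diffs below: because per-block compile runs after activation checkpointing has wrapped each TransformerBlock, `apply_compile` now also receives the AC config. A minimal sketch of the updated contract, using the aliases imported in the llama4 diff (sketch only, not the full implementation):

```python
# Sketch of the new apply_compile contract (aliases per the llama4 diff below).
import torch.nn as nn
from torchtitan.config.job_config import (
    ActivationCheckpoint as ACConfig,
    Compile as CompileConfig,
)

def apply_compile(model: nn.Module, compile_config: CompileConfig, ac_config: ACConfig) -> None:
    ...  # AC-aware per-block compile; full body in the llama4 diff below

# Call site, identical in deepseek_v3, llama4, and qwen3:
# apply_compile(model, job_config.compile, job_config.activation_checkpoint)
```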
79 changes: 67 additions & 12 deletions torchtitan/models/llama4/infra/parallelize.py
@@ -6,6 +6,9 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    CheckpointWrapper,
+)
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Partial, Replicate, Shard
@@ -18,7 +21,10 @@
     SequenceParallel,
 )
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.config.job_config import Compile as CompileConfig
+from torchtitan.config.job_config import (
+    ActivationCheckpoint as ACConfig,
+    Compile as CompileConfig,
+)
 from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 
@@ -30,6 +36,7 @@
 )
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.moe import moe as moe_module
 from torchtitan.tools.logging import logger


@@ -125,7 +132,7 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
@@ -503,24 +510,72 @@ def apply_moe_ep_tp(
     )
 
 
-def apply_compile(model: nn.Module, compile_config: CompileConfig):
+def apply_compile(model: nn.Module, compile_config: CompileConfig, ac_config: ACConfig):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
-    # torch._dynamo.config.capture_scalar_outputs = True
+    torch._dynamo.config.capture_scalar_outputs = True
     for layer_id, transformer_block in model.layers.named_children():
-        # TODO: remove when torch.compile supports fullgraph=True for MoE
-        fullgraph = True
         if transformer_block.moe_enabled:
-            fullgraph = False
-        transformer_block = torch.compile(
-            transformer_block,
-            backend=compile_config.backend,
-            fullgraph=fullgraph,
-        )
+            # In a MoE layer, FSDP(GroupedExperts) causes a graph break, so we
+            # weave compile wrappers around those FSDP hooks to prevent AC from
+            # falling back to eager for the whole graph.
+
+            if isinstance(transformer_block, CheckpointWrapper):
+                # unwrap so that .named_children() works
+                block = transformer_block._checkpoint_wrapped_module
+            else:
+                block = transformer_block
+
+            patch_experts = False
+            for attr_name, submod in block.named_children():
+                assert getattr(block, attr_name) == getattr(
+                    transformer_block, attr_name
+                )
+
+                if isinstance(submod, moe_module.MoE):
+                    # avoid graph breaks on the GroupedExperts' FSDP hooks by
+                    # wrapping each child's forward instead of its __call__
+                    moe = submod
+                    for attr_name, submod in moe.named_children():
+                        if attr_name == "experts":
+                            # don't compile token dispatch and combine due to some issue on b200
+                            patch_experts = True
+                            continue
+                        setattr(
+                            moe,
+                            attr_name,
+                            torch.compile(
+                                submod, backend=compile_config.backend, fullgraph=True
+                            ),
+                        )
+                else:
+                    setattr(
+                        block,
+                        attr_name,
+                        torch.compile(
+                            submod, backend=compile_config.backend, fullgraph=True
+                        ),
+                    )
+
+            if patch_experts:
+                moe_module._run_experts_grouped_mm = torch.compile(
+                    moe_module._run_experts_grouped_mm,
+                    backend=compile_config.backend,
+                    fullgraph=True,
+                )
+        else:
+            # If it's not a MoE layer, there is no FSDP(GroupedExperts),
+            # so we can compile the whole block.
+            transformer_block = torch.compile(
+                transformer_block,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
+
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info("Compiling each TransformerBlock with torch.compile")
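
For readers unfamiliar with the wrapper interplay above, here is a self-contained sketch (toy module names, default compile backend; not torchtitan code) of the core technique: unwrap the `CheckpointWrapper` that AC wrapping produces, then compile each child of the block separately, so work that runs at child boundaries (such as FSDP's pre/post-forward hooks on `GroupedExperts`) stays eager instead of graph-breaking inside one big compiled region:

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointWrapper,
)

class ToyBlock(nn.Module):
    """Stands in for a TransformerBlock; names are illustrative."""
    def __init__(self):
        super().__init__()
        self.attention = nn.Linear(16, 16)
        self.feed_forward = nn.Linear(16, 16)

    def forward(self, x):
        return self.feed_forward(self.attention(x))

wrapped = checkpoint_wrapper(ToyBlock())  # what AC wrapping produces per block

# The real block lives under ._checkpoint_wrapped_module; attribute access is
# forwarded to it, which is why the diff can assert that getattr(block, name)
# and getattr(transformer_block, name) agree.
block = (
    wrapped._checkpoint_wrapped_module
    if isinstance(wrapped, CheckpointWrapper)
    else wrapped
)

for name, child in list(block.named_children()):
    # Each child becomes its own fullgraph region; whatever runs between
    # children (module hooks, token dispatch/combine) stays eager.
    setattr(block, name, torch.compile(child, fullgraph=True))

out = wrapped(torch.randn(2, 16))  # AC still applies; children run compiled
```

Patching `moe_module._run_experts_grouped_mm` follows the same logic one level down: the grouped-mm compute runs compiled, while `experts.__call__`, where FSDP's hooks live, stays eager.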
2 changes: 1 addition & 1 deletion torchtitan/models/qwen3/infra/parallelize.py
@@ -119,7 +119,7 @@ def parallelize_qwen3(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     if parallel_dims.fsdp_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel