|
6 | 6 |
|
7 | 7 | import torch |
8 | 8 | import torch.nn as nn |
| 9 | +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( |
| 10 | + CheckpointWrapper, |
| 11 | +) |
9 | 12 | from torch.distributed.device_mesh import DeviceMesh |
10 | 13 | from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy |
11 | 14 | from torch.distributed.tensor import Partial, Replicate, Shard |
|
18 | 21 | SequenceParallel, |
19 | 22 | ) |
20 | 23 | from torchtitan.config import JobConfig, TORCH_DTYPE_MAP |
21 | | -from torchtitan.config.job_config import Compile as CompileConfig |
| 24 | +from torchtitan.config.job_config import ( |
| 25 | + ActivationCheckpoint as ACConfig, |
| 26 | + Compile as CompileConfig, |
| 27 | +) |
22 | 28 | from torchtitan.distributed import NoParallel, ParallelDims |
23 | 29 | from torchtitan.distributed.activation_checkpoint import apply_ac |
24 | 30 |
|
|
30 | 36 | ) |
31 | 37 | from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp |
32 | 38 | from torchtitan.models.llama3.infra.parallelize import apply_ddp |
| 39 | +from torchtitan.models.moe.moe import MoE |
33 | 40 | from torchtitan.tools.logging import logger |
34 | 41 |
|
35 | 42 |
|
@@ -125,7 +132,7 @@ def parallelize_llama( |
125 | 132 |
|
126 | 133 | # turn on per-TransformerBlock compile after AC wrapping and before FSDP |
127 | 134 | if model_compile_enabled: |
128 | | - apply_compile(model, job_config.compile) |
| 135 | + apply_compile(model, job_config.compile, job_config.activation_checkpoint) |
129 | 136 |
|
130 | 137 | dp_mesh: DeviceMesh | None = None |
131 | 138 | if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: |
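As the comment in this hunk notes, per-TransformerBlock compile is applied after activation-checkpoint wrapping and before FSDP sharding. A minimal sketch of that ordering outside torchtitan, using a stand-in nn.TransformerEncoderLayer instead of torchtitan's TransformerBlock, with the fully_shard step left commented out since it requires an initialized process group:

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

# Stand-in for a TransformerBlock (illustrative only).
block = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
block = checkpoint_wrapper(block)  # 1. activation-checkpoint wrapping first
block = torch.compile(block)       # 2. then per-block torch.compile
# fully_shard(block)               # 3. FSDP sharding is applied last in the real flow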
@@ -503,24 +510,60 @@ def apply_moe_ep_tp( |
503 | 510 | ) |
504 | 511 |
|
505 | 512 |
|
506 | | -def apply_compile(model: nn.Module, compile_config: CompileConfig): |
| 513 | +def apply_compile(model: nn.Module, compile_config: CompileConfig, ac_config: ACConfig): |
507 | 514 | """ |
508 | 515 | Apply torch.compile to each TransformerBlock, which makes compilation efficient due to |
509 | 516 | repeated structure. Alternatively one can compile the whole model (after applying DP). |
510 | 517 | """ |
511 | 518 | # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE |
512 | 519 | # but it is experimental. |
513 | | - # torch._dynamo.config.capture_scalar_outputs = True |
| 520 | + torch._dynamo.config.capture_scalar_outputs = True |
514 | 521 | for layer_id, transformer_block in model.layers.named_children(): |
515 | | - # TODO: remove when torch.compile supports fullgraph=True for MoE |
516 | | - fullgraph = True |
517 | 522 | if transformer_block.moe_enabled: |
518 | | - fullgraph = False |
519 | | - transformer_block = torch.compile( |
520 | | - transformer_block, |
521 | | - backend=compile_config.backend, |
522 | | - fullgraph=fullgraph, |
523 | | - ) |
| 523 | +            # In a MoE layer, FSDP(GroupedExperts) causes a graph break, so we |
| 524 | +            # weave compile wrappers around those FSDP hooks to keep AC from |
| 525 | +            # forcing the whole graph back to eager. |
| 526 | + assert ( |
| 527 | + ac_config.mode != "selective" |
| 528 | + ), "Selective Activation Checkpointing + Compile is not yet supported for MoE models." |
| 529 | + |
| 530 | + if isinstance(transformer_block, CheckpointWrapper): |
| 531 | +                # unwrap so .named_children() yields the block's real submodules |
| 532 | + block = transformer_block._checkpoint_wrapped_module |
| 533 | + else: |
| 534 | + block = transformer_block |
| 535 | + |
| 536 | +            for attr_name, submod in list(block.named_children()): |
| 537 | +                if isinstance(submod, MoE): |
| 538 | +                    # avoid graph breaks on the GroupedExperts' FSDP hooks by |
| 539 | +                    # compiling each of the MoE module's children individually, |
| 540 | +                    # i.e. wrapping each child's forward rather than its __call__ |
| 541 | +                    moe = submod |
| 542 | +                    for name, child in list(moe.named_children()): |
| 543 | +                        setattr( |
| 544 | +                            moe, |
| 545 | +                            name, |
| 546 | +                            torch.compile( |
| 547 | +                                child, backend=compile_config.backend, fullgraph=True |
| 548 | +                            ), |
| 549 | +                        ) |
| 550 | +                else: |
| 551 | +                    setattr( |
| 552 | +                        block, |
| 553 | +                        attr_name, |
| 554 | +                        torch.compile( |
| 555 | +                            submod, backend=compile_config.backend, fullgraph=True |
| 556 | +                        ), |
| 557 | +                    ) |
| 558 | + else: |
| 559 | +            # If it's not a MoE layer, there is no FSDP(GroupedExperts) inside, |
| 560 | +            # so we can compile the whole block as a single graph. |
| 561 | + transformer_block = torch.compile( |
| 562 | + transformer_block, |
| 563 | + backend=compile_config.backend, |
| 564 | + fullgraph=True, |
| 565 | + ) |
| 566 | + |
524 | 567 | model.layers.register_module(layer_id, transformer_block) |
525 | 568 |
|
526 | 569 | logger.info("Compiling each TransformerBlock with torch.compile") |
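For readers outside the torchtitan codebase, here is a minimal, self-contained sketch of the per-submodule compile approach used in the MoE branch above. ToyExperts, ToyMoE, compile_moe_children, and the "inductor" backend are illustrative assumptions rather than torchtitan APIs; only the technique mirrors the diff: compile each child of the MoE module separately with fullgraph=True instead of compiling the block as one graph.

import torch
import torch.nn as nn


class ToyExperts(nn.Module):
    # Stand-in for torchtitan's GroupedExperts (the submodule FSDP wraps).
    def __init__(self, dim: int, num_experts: int):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim))

    def forward(self, x: torch.Tensor, expert_idx: torch.Tensor) -> torch.Tensor:
        # Route each token through the weight matrix of its selected expert.
        return torch.einsum("td,tdo->to", x, self.w[expert_idx])


class ToyMoE(nn.Module):
    # Stand-in for torchtitan's MoE: a router plus an experts child module.
    def __init__(self, dim: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(dim, num_experts)
        self.experts = ToyExperts(dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expert_idx = self.router(x).argmax(dim=-1)
        return self.experts(x, expert_idx)


def compile_moe_children(moe: nn.Module, backend: str = "inductor") -> None:
    # Compile each direct child on its own rather than the MoE module as a single
    # graph; per the diff above, this keeps FSDP's module hooks outside the
    # compiled regions instead of graph-breaking one large graph.
    for name, child in list(moe.named_children()):
        setattr(moe, name, torch.compile(child, backend=backend, fullgraph=True))


moe = ToyMoE(dim=16, num_experts=4)
compile_moe_children(moe)
out = moe(torch.randn(8, 16))  # router and experts each run as their own compiled graph

Wrapping named_children() in list() avoids mutating the module's child registry while iterating over it.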