
Commit fd224b3

tohskaigithubsgi authored and committed
Add support for AC budget API (pytorch#1731)
Inspired by the blog post: https://pytorch.org/blog/activation-checkpointing-techniques/
1 parent bc0383d commit fd224b3

7 files changed: +48 additions, -11 deletions

torchtitan/config/job_config.py
Lines changed: 19 additions & 1 deletion

@@ -560,7 +560,7 @@ class Checkpoint:
 
 
 @dataclass
 class ActivationCheckpoint:
-    mode: Literal["selective", "full", "none"] = "selective"
+    mode: Literal["selective", "full", "memory_budget", "none"] = "selective"
     """Type of activation checkpointing to use"""
 
     selective_ac_option: str = "2"
@@ -589,6 +589,24 @@ class ActivationCheckpoint:
     rematerialized.
     """
 
+    memory_budget: float = 0.5
+    """
+    When mode is set to "memory_budget", this value determines how much the
+    partitioner in the compiler should trade off compute for memory.
+    0.0 corresponds to the activation memory from applying
+    activation checkpointing to the full compiled region, and 1.0 corresponds to
+    the activation memory from the default runtime-optimized strategy. Read more here:
+    https://pytorch.org/blog/activation-checkpointing-techniques/
+    """
+
+    visualize_memory_budget_pareto: bool = False
+    """
+    This dumps out an SVG visualization of the expected runtime vs. activation
+    memory tradeoffs for all memory budget values from 0 to 1 in increments of
+    0.05 in the {--job.dump_folder}/memory_budget_pareto folder. See an example here:
+    https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
+    """
+
 
 @dataclass
 class Compile:
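
For reference, a minimal sketch (not part of the commit) of how the new fields could be set programmatically; torchtitan normally populates this dataclass from a job TOML file, and the import path below is inferred from the file location above:

    # Hypothetical, hand-written configuration; values mirror the defaults in the diff.
    from torchtitan.config.job_config import ActivationCheckpoint

    ac_config = ActivationCheckpoint(
        # "memory_budget" hands AC decisions to the compiler's partitioner.
        mode="memory_budget",
        # 0.0 ~ activation memory of full AC over the compiled region,
        # 1.0 ~ activation memory of the default runtime-optimized strategy.
        memory_budget=0.5,
        # Optionally dump runtime-vs-memory pareto SVGs under
        # {dump_folder}/memory_budget_pareto.
        visualize_memory_budget_pareto=True,
    )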

torchtitan/distributed/activation_checkpoint.py
Lines changed: 24 additions & 10 deletions

@@ -7,6 +7,7 @@
 # This file provides the util functions to apply activation checkpointing to the model.
 # Technically, this is not a part of distributed, but distributed module is the best place to put it.
 
+import os
 from collections import defaultdict
 
 import torch
@@ -293,6 +294,7 @@ def apply_ac(
     model_compile_enabled: bool = False,
     use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
+    base_folder: str = "",
 ) -> None:
     """Apply activation checkpointing to the model.
 
@@ -311,15 +313,27 @@ def apply_ac(
         None
     """
 
-    for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = _apply_ac_to_transformer_block(
-            transformer_block,
-            ac_config,
-            base_fqn=f"layers.{layer_id}",
-            model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
-            op_sac_save_list=op_sac_save_list,
-        )
-        model.layers.register_module(layer_id, transformer_block)
+    if ac_config.mode == "memory_budget":
+        assert model_compile_enabled, "Memory budget mode requires model to be compiled"
+        if ac_config.visualize_memory_budget_pareto:
+            pareto_dir = os.path.join(base_folder, "memory_budget_pareto")
+            if not os.path.exists(pareto_dir):
+                os.makedirs(pareto_dir, exist_ok=True)
+            torch._functorch.config.memory_budget_pareto_dir = pareto_dir
+            torch._functorch.config.visualize_memory_budget_pareto = True
+
+        torch._functorch.config.activation_memory_budget = ac_config.memory_budget
+        logger.info(f"Selected {ac_config.memory_budget} budget option")
+    else:
+        for layer_id, transformer_block in model.layers.named_children():
+            transformer_block = _apply_ac_to_transformer_block(
+                transformer_block,
+                ac_config,
+                base_fqn=f"layers.{layer_id}",
+                model_compile_enabled=model_compile_enabled,
+                use_flex_attn=use_flex_attn,
+                op_sac_save_list=op_sac_save_list,
+            )
+            model.layers.register_module(layer_id, transformer_block)
 
     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
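
In "memory_budget" mode the function no longer wraps individual transformer blocks; it configures the functorch partitioner globally and relies on torch.compile to decide what to recompute. A hedged sketch of that flow, using a stand-in model rather than torchtitan's:

    import torch

    # Stand-in model purely for illustration (not torchtitan's Transformer).
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))

    # The knob apply_ac() sets in "memory_budget" mode: how aggressively the partitioner
    # trades recomputation for activation memory (0.0 = most recompute, 1.0 = default).
    torch._functorch.config.activation_memory_budget = 0.5

    # The budget only takes effect through the compiler's partitioner, which is why
    # apply_ac() asserts that model compilation is enabled.
    compiled_model = torch.compile(model)
    compiled_model(torch.randn(4, 8)).sum().backward()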

torchtitan/experiments/llama4/infra/parallelize.py
Lines changed: 1 addition & 0 deletions

@@ -120,6 +120,7 @@ def parallelize_llama(
         model_compile_enabled=model_compile_enabled,
         use_flex_attn=use_flex_attn,
         op_sac_save_list=_op_sac_save_list,
+        base_folder=job_config.job.dump_folder,
     )
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP

torchtitan/experiments/qwen3/infra/parallelize.py
Lines changed: 1 addition & 0 deletions

@@ -114,6 +114,7 @@ def parallelize_qwen3(
         model_compile_enabled=model_compile_enabled,
         use_flex_attn=use_flex_attn,
         op_sac_save_list=_op_sac_save_list,
+        base_folder=job_config.job.dump_folder,
     )
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP

torchtitan/experiments/simple_fsdp/llama3/parallelize.py
Lines changed: 1 addition & 0 deletions

@@ -85,6 +85,7 @@ def parallelize_llama(
         model_compile_enabled=model_compile_enabled,
         use_flex_attn=use_flex_attn,
         op_sac_save_list=_op_sac_save_list,
+        base_folder=job_config.job.dump_folder,
     )
 
     # apply data parallel

torchtitan/models/deepseek_v3/infra/parallelize.py
Lines changed: 1 addition & 0 deletions

@@ -113,6 +113,7 @@ def parallelize_deepseekv3(
         model_compile_enabled=model_compile_enabled,
         use_flex_attn=use_flex_attn,
         op_sac_save_list=_op_sac_save_list,
+        base_folder=job_config.job.dump_folder,
     )
 
     if model_compile_enabled:

torchtitan/models/llama3/infra/parallelize.py
Lines changed: 1 addition & 0 deletions

@@ -102,6 +102,7 @@ def parallelize_llama(
         model_compile_enabled=model_compile_enabled,
         use_flex_attn=use_flex_attn,
         op_sac_save_list=_op_sac_save_list,
+        base_folder=job_config.job.dump_folder,
     )
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
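
All five parallelize functions thread the new argument in the same way. A sketch of the shared call pattern follows; only the keyword arguments appear as diff context, so the positional arguments and the wrapper function here are assumptions:

    from torchtitan.distributed.activation_checkpoint import apply_ac

    def _apply_ac_for_model(model, job_config, model_compile_enabled, use_flex_attn, _op_sac_save_list):
        # Hypothetical wrapper mirroring the call sites touched by this commit.
        apply_ac(
            model,
            job_config.activation_checkpoint,
            model_compile_enabled=model_compile_enabled,
            use_flex_attn=use_flex_attn,
            op_sac_save_list=_op_sac_save_list,
            base_folder=job_config.job.dump_folder,  # new: root dir for the pareto SVG dump
        )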
