From 7ce8eeda4ff128297a61cbfe0e5774a095eb4fa2 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Tue, 7 Oct 2025 13:48:52 -0700 Subject: [PATCH 01/46] Second version of degub/deterinistic configs. Incorporating input from converastion in https://github.com/pytorch/torchtitan/pull/1761 --- torchtitan/config/job_config.py | 18 ++++++++++++ .../distributed/activation_checkpoint.py | 28 ++++++++++++++----- torchtitan/distributed/utils.py | 2 ++ .../llama4/train_configs/llama4_17bx16e.toml | 7 +++++ torchtitan/train.py | 3 +- 5 files changed, 50 insertions(+), 8 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 7137579f18..7a7d0b9b2a 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -880,6 +880,23 @@ def __post_init__(self): ), "validation steps must be positive or -1" +@dataclass +class Debug: + torch_deterministic: bool = False + """Use deterministic algorithms wherever possible, may be slower""" + + torch_deterministic_warn_only: bool = False + """Only warns about ops without deterministic implementations rather than erroring out """ + + torch_preserve_rng_state: bool = False + """If deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower""" + + ac_determinism_check: str = "default" + """A string specifying the determinism function. """ + + ac_debug: bool = False + """ Capture ac debug information. Will be slower. """ + @dataclass class JobConfig: """ @@ -905,6 +922,7 @@ class JobConfig: fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance) experimental: Experimental = field(default_factory=Experimental) validation: Validation = field(default_factory=Validation) + debug: Debug = field(default_factory=Debug) def to_dict(self) -> dict[str, Any]: return asdict(self) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 57809c45f9..885ad94636 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -17,13 +17,14 @@ ) from torchtitan.config.job_config import ActivationCheckpoint as ACConfig +from torchtitan.config.job_config import Debug as DebugConfig from torchtitan.tools.logging import logger, warn_once _layer_sac_count = 0 -def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module: +def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, dbg_config:DebugConfig) -> nn.Module: """Apply layer selective activation checkpointing to the module. 
Args: @@ -38,7 +39,11 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module: ac_freq = int(ac_config.selective_ac_option) if not ac_freq or _layer_sac_count % ac_freq == 0: return ptd_checkpoint_wrapper( - module, preserve_rng_state=False, early_stop=ac_config.early_stop + module, + preserve_rng_state=dbg_config.torch_preserve_rng_state, + determinism_check=dbg_config.ac_determinism_check, + early_stop=ac_config.early_stop, + debug=dbg_config.ac_debug ) else: return module @@ -123,11 +128,13 @@ def selective_checkpointing_context_fn(): return create_selective_checkpoint_contexts(_get_custom_policy(meta)) return ptd_checkpoint_wrapper( - module, - context_fn=selective_checkpointing_context_fn, - preserve_rng_state=False, - early_stop=ac_config.early_stop, - ) + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=dbg_config.torch_preserve_rng_state, + determinism_check=dbg_config.ac_determinism_check, + early_stop=ac_config.early_stop, + debug=dbg_config.ac_debug + ) def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: @@ -143,6 +150,13 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: return ptd_checkpoint_wrapper( module, preserve_rng_state=False, early_stop=ac_config.early_stop ) + return ptd_checkpoint_wrapper( + module, + preserve_rng_state=dbg_config.torch_preserve_rng_state, + determinism_check=dbg_config.ac_determinism_check, + early_stop=ac_config.early_stop, + debug=dbg_config.ac_debug + ) def _apply_op_sac_to_transformer_block_with_flex( diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 67eb41280f..10518190b4 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -85,6 +85,7 @@ def set_determinism( device: torch.device, seed: int | None = None, deterministic: bool = False, + deterministic_warn_only: bool = False, distinct_seed_mesh_dim: str = "pp", ) -> None: """ @@ -100,6 +101,7 @@ def set_determinism( if deterministic: logger.info("Deterministic algorithm enabled (expect perf degradation).") torch.use_deterministic_algorithms(True) + torch.use_deterministic_algorithms(True, warn_only=deterministic_warn_only) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # env var for deterministic CuBLAS diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml index 78e210c10a..13198bf7c0 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml @@ -68,3 +68,10 @@ filter_fqns = ["output", "router.gate"] [quantize.linear.mx] filter_fqns = ["output", "router.gate"] + +[debug] +torch_deterministic = false +torch_deterministic_warn_only = false +torch_preserve_rng_state = false +ac_determinism_check = "default" +ac_debug = false diff --git a/torchtitan/train.py b/torchtitan/train.py index 1d5e0e500a..fb505c20df 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -119,7 +119,8 @@ def __init__(self, job_config: JobConfig): world_mesh, self.device, job_config.training.seed, - job_config.training.deterministic, + job_config.debug.torch_deterministic, + job_config.debug.torch_deterministic_warn_only, ) self.train_spec = train_spec_module.get_train_spec(job_config.model.name) From e06e1e90095bfda27589f68566ed715c10b00200 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Thu, 9 Oct 2025 19:12:22 -0700 Subject: [PATCH 02/46] Review relaeted 
updates. --- docs/debugging.md | 17 +++++++++-- torchtitan/config/__init__.py | 2 ++ torchtitan/config/job_config.py | 29 +++++++++---------- .../distributed/activation_checkpoint.py | 8 ++--- torchtitan/distributed/utils.py | 10 +++---- torchtitan/models/llama4/model/args.py | 2 +- .../llama4/train_configs/llama4_17bx16e.toml | 8 +++-- torchtitan/train.py | 4 +-- 8 files changed, 46 insertions(+), 34 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index f7758cbde5..5f45a49037 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -70,7 +70,7 @@ When debugging issues with multi-dimensional parallelism (combinations of FSDP, Set consistent random seeds across all parallelism dimensions: ```bash -CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.seed 42 +CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.seed 42 ``` **Seed behavior with parallelism:** @@ -84,7 +84,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr Enable deterministic algorithms to ensure bit-for-bit reproducibility across runs: ```bash -CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.deterministic +CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.deterministic ``` **What it does:** @@ -93,6 +93,19 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr - Sets deterministic workspace configuration for CuBLAS operations - **Note:** This will significantly reduce training performance but ensures exact reproducibility +Use --debug.deterministic.warn_only to only warn about (not stop running) kernel without deterministic implementation. + +### Activation Checkipointing Debugging ### + +The following debug configs are available for AC. + +ac_preserve_rng_state - if deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. + +ac_determinism_check - A string specifying the determinism function + +ac_debug - capture ac debug information. Will be slower. + +See https://docs.pytorch.org/docs/stable/checkpoint.html for details. ### Seed-Checkpoint-based Reproducibility diff --git a/torchtitan/config/__init__.py b/torchtitan/config/__init__.py index ba2795a601..e70d7fb622 100644 --- a/torchtitan/config/__init__.py +++ b/torchtitan/config/__init__.py @@ -28,6 +28,7 @@ Quantize, Training, Validation, + Debug ) from .manager import ConfigManager @@ -49,4 +50,5 @@ "Profiling", "Training", "Validation", + "Debug" ] diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 7a7d0b9b2a..4baccc0866 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -253,15 +253,6 @@ class Training: many temporary files. """ - seed: int | None = None - """Choose the base RNG seed used for training""" - - deterministic: bool = False - """Use deterministic algorithms wherever possible, may be slower""" - - debug_moe_force_load_balance: bool = False - """If True, we force each experts to get the same amount of tokens via round-robin. 
This option is for debugging usage only.""" - @dataclass class Parallelism: @@ -882,20 +873,26 @@ def __post_init__(self): @dataclass class Debug: - torch_deterministic: bool = False - """Use deterministic algorithms wherever possible, may be slower""" + deterministic: bool = False + """Use deterministic algorithms wherever possible, may be slower""" - torch_deterministic_warn_only: bool = False + deterministic_warn_only: bool = False """Only warns about ops without deterministic implementations rather than erroring out """ - torch_preserve_rng_state: bool = False - """If deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower""" + seed: int | None = None + """Choose the base RNG seed used for training""" + + ac_preserve_rng_state: bool = False + """If deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details.""" ac_determinism_check: str = "default" - """A string specifying the determinism function. """ + """A string specifying the determinism function. See https://docs.pytorch.org/docs/stable/checkpoint.html for details.""" ac_debug: bool = False - """ Capture ac debug information. Will be slower. """ + """ Capture ac debug information. Will be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details.""" + + moe_force_load_balance: bool = False + """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" @dataclass class JobConfig: diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 885ad94636..5c23b521ad 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -24,7 +24,7 @@ _layer_sac_count = 0 -def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, dbg_config:DebugConfig) -> nn.Module: +def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, debug_config:DebugConfig) -> nn.Module: """Apply layer selective activation checkpointing to the module. 
Args: @@ -40,10 +40,10 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, dbg_config:DebugCon if not ac_freq or _layer_sac_count % ac_freq == 0: return ptd_checkpoint_wrapper( module, - preserve_rng_state=dbg_config.torch_preserve_rng_state, - determinism_check=dbg_config.ac_determinism_check, + preserve_rng_state=debug_config.torch_preserve_rng_state, + determinism_check=debug_config.ac_determinism_check, early_stop=ac_config.early_stop, - debug=dbg_config.ac_debug + debug=debug_config.ac_debug ) else: return module diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 10518190b4..8787519d86 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -18,6 +18,7 @@ from torch.distributed.tensor import DTensor from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP +from torchtitan.config import Debug as DebugConfig from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.tools.logging import logger from torchtitan.tools.utils import device_module, device_type @@ -83,9 +84,7 @@ def dist_mean( def set_determinism( world_mesh: DeviceMesh | None, device: torch.device, - seed: int | None = None, - deterministic: bool = False, - deterministic_warn_only: bool = False, + debug_config: DebugConfig, distinct_seed_mesh_dim: str = "pp", ) -> None: """ @@ -98,10 +97,10 @@ def set_determinism( Set Determinism flags for increased reproducibility with loss of performance. """ - if deterministic: + if debug_config.deterministic: logger.info("Deterministic algorithm enabled (expect perf degradation).") torch.use_deterministic_algorithms(True) - torch.use_deterministic_algorithms(True, warn_only=deterministic_warn_only) + torch.use_deterministic_algorithms(True, warn_only=debug_config.deterministic_warn_only) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # env var for deterministic CuBLAS @@ -116,6 +115,7 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) + seed = debug_config.seed if not world_mesh: if seed is not None: torch.manual_seed(seed) diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py index faeb60aadf..53043a1d02 100644 --- a/torchtitan/models/llama4/model/args.py +++ b/torchtitan/models/llama4/model/args.py @@ -82,7 +82,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args._debug_force_load_balance = ( - job_config.training.debug_moe_force_load_balance + job_config.debug.moe_force_load_balance ) def get_nparams_and_flops( diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml index 13198bf7c0..c374616f91 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml @@ -70,8 +70,10 @@ filter_fqns = ["output", "router.gate"] filter_fqns = ["output", "router.gate"] [debug] -torch_deterministic = false -torch_deterministic_warn_only = false -torch_preserve_rng_state = false +#seed = +deterministic = false +deterministic_warn_only = false +ac_preserve_rng_state = false ac_determinism_check = "default" ac_debug = false +moe_force_load_balance = false diff --git a/torchtitan/train.py b/torchtitan/train.py index fb505c20df..290f2368b2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -118,9 +118,7 @@ def __init__(self, job_config: JobConfig): dist_utils.set_determinism( world_mesh, 
self.device, - job_config.training.seed, - job_config.debug.torch_deterministic, - job_config.debug.torch_deterministic_warn_only, + job_config.debug, ) self.train_spec = train_spec_module.get_train_spec(job_config.model.name) From b23e5bb099233b56a0a7cb8e6e18396faa80c538 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Fri, 10 Oct 2025 14:54:11 -0700 Subject: [PATCH 03/46] Review 2 related changes. --- docs/debugging.md | 8 ++++---- torchtitan/distributed/activation_checkpoint.py | 6 +++--- torchtitan/experiments/flux/train.py | 4 ++-- torchtitan/experiments/forge/engine.py | 4 ++-- torchtitan/experiments/forge/job_config.py | 2 ++ .../models/llama4/train_configs/llama4_17bx16e.toml | 9 --------- torchtitan/train.py | 7 +++++++ 7 files changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index 5f45a49037..fd436367e9 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -93,17 +93,17 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr - Sets deterministic workspace configuration for CuBLAS operations - **Note:** This will significantly reduce training performance but ensures exact reproducibility -Use --debug.deterministic.warn_only to only warn about (not stop running) kernel without deterministic implementation. +Use `--debug.deterministic_warn_only` to only warn about (not stop running) kernel without deterministic implementation. ### Activation Checkipointing Debugging ### The following debug configs are available for AC. -ac_preserve_rng_state - if deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. +`ac_preserve_rng_state` - if deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. -ac_determinism_check - A string specifying the determinism function +`ac_determinism_check` - A string specifying the determinism function -ac_debug - capture ac debug information. Will be slower. +`ac_debug` - capture ac debug information. Will be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details. 
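For example, a single run that exercises all of the debug options above might look like the following (flag spellings assume the `--debug.<field>` CLI mapping used in the earlier examples; adjust to your own config):

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --debug.seed 42 \
  --debug.deterministic \
  --debug.deterministic_warn_only \
  --debug.ac_preserve_rng_state \
  --debug.ac_determinism_check default \
  --debug.ac_debug
```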
diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 5c23b521ad..131c7f7f58 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -40,7 +40,7 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, debug_config:DebugC if not ac_freq or _layer_sac_count % ac_freq == 0: return ptd_checkpoint_wrapper( module, - preserve_rng_state=debug_config.torch_preserve_rng_state, + preserve_rng_state=debug_config.ac_preserve_rng_state, determinism_check=debug_config.ac_determinism_check, early_stop=ac_config.early_stop, debug=debug_config.ac_debug @@ -130,7 +130,7 @@ def selective_checkpointing_context_fn(): return ptd_checkpoint_wrapper( module, context_fn=selective_checkpointing_context_fn, - preserve_rng_state=dbg_config.torch_preserve_rng_state, + preserve_rng_state=dbg_config.ac_preserve_rng_state, determinism_check=dbg_config.ac_determinism_check, early_stop=ac_config.early_stop, debug=dbg_config.ac_debug @@ -152,7 +152,7 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: ) return ptd_checkpoint_wrapper( module, - preserve_rng_state=dbg_config.torch_preserve_rng_state, + preserve_rng_state=dbg_config.ac_preserve_rng_state, determinism_check=dbg_config.ac_determinism_check, early_stop=ac_config.early_stop, debug=dbg_config.ac_debug diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index 624792e83e..59a40c147b 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -35,8 +35,8 @@ def __init__(self, job_config: JobConfig): dist_utils.set_determinism( self.parallel_dims.world_mesh, self.device, - job_config.training.seed, - job_config.training.deterministic, + job_config.debug.seed, + job_config.debug.deterministic, distinct_seed_mesh_dim="dp_shard", ) diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index f8b1412959..add19073f7 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -104,8 +104,8 @@ def __init__(self, job_config: ForgeJobConfig): dist_utils.set_determinism( world_mesh, self.device, - job_config.training.seed, - job_config.training.deterministic, + job_config.debug.seed, + job_config.debug.deterministic, ) self.train_spec = get_train_spec(job_config.model.name) diff --git a/torchtitan/experiments/forge/job_config.py b/torchtitan/experiments/forge/job_config.py index b1c014cc1d..d255bc0b72 100644 --- a/torchtitan/experiments/forge/job_config.py +++ b/torchtitan/experiments/forge/job_config.py @@ -20,6 +20,7 @@ Parallelism, Quantize, Training, + Debug, ) @@ -45,6 +46,7 @@ class ForgeJobConfig: # fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance) # experimental: Experimental = field(default_factory=Experimental) # validation: Validation = field(default_factory=Validation) + debug: Debug = field(default_factory=Debug) def to_dict(self) -> dict[str, Any]: return asdict(self) diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml index c374616f91..78e210c10a 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml @@ -68,12 +68,3 @@ filter_fqns = ["output", "router.gate"] [quantize.linear.mx] filter_fqns = ["output", "router.gate"] - -[debug] -#seed = -deterministic = false 
-deterministic_warn_only = false -ac_preserve_rng_state = false -ac_determinism_check = "default" -ac_debug = false -moe_force_load_balance = false diff --git a/torchtitan/train.py b/torchtitan/train.py index 290f2368b2..a5cff0a7ee 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,6 +11,13 @@ from typing import Any, Generator, Iterable, Optional import torch + +try: + import intel_extension_for_pytorch as ipex + print ( f"IPEX found - hence using IPEX") +except: + print ( f"IPEX not found, hence not using") + from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module From bcb08945173346636a02eb0409ddf8eb52482930 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Wed, 8 Oct 2025 10:19:37 -0700 Subject: [PATCH 04/46] [DSV3] Offload dequantization process to DCP QuantizedHFReader (#1804) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Benchmarking
Step | time | log
-- | -- | --
to_hf() | 0.1103s | [trainer0\|0]:[titan] 2025-10-03 17:07:45,697 - root - INFO - Completed to_hf conversion, generated 189 keys, duration: 0.1103s
Split local GroupedExperts DTensor to individual experts’ weight | 0.008s per layer per matrix (total 58 MoE Layers * 3 weight matrices per layer) | [trainer0\|0]:[titan] 2025-10-03 17:07:45,697 - root - INFO - Completed _get_local_experts_weights for layer 6, abstract_key: model.layers.{}.mlp.experts.{}.up_proj.weight, duration: 0.0082s
dcp.load() (threads count=4) | 193.20s | [trainer0\|0]:[titan] 2025-10-03 17:10:58,899 - root - INFO - dcp.load with HuggingFaceStorageReader completed in 193.20 seconds
from_hf() | 0.48s | [trainer0\|0]:[titan] 2025-10-03 17:10:59,378 - root - INFO - Completed from_hf conversion, processed 189 keys, duration: 0.4787s
Concatenate individual experts weight into GroupedExperts weight | 0.01s per layer per matrix (total 58 MoE Layers * 3 weight matrices) | [trainer0\|0]:[titan] 2025-10-03 17:10:59,120 - root - INFO - Completed _concatenate_expert_weights_dtensor for layer 5, abstract_key: layers.{}.moe.experts.w2, duration: 0.0142s
Total | 193.87s | [trainer0\|0]:[titan] 2025-10-03 17:10:59,458 - root - INFO - Finished loading the checkpoint in 193.87 seconds.
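In code, the flow being timed is roughly the following (a hedged sketch, not the PR's exact implementation; `adapter` stands in for the model's state-dict adapter and `storage_reader` for the (Quantized)HuggingFaceStorageReader referenced in the logs):

```python
# Hedged sketch of the benchmarked load path; names above are placeholders.
import torch.distributed.checkpoint as dcp

def load_hf_checkpoint(model, adapter, storage_reader) -> None:
    hf_sd = adapter.to_hf(model.state_dict())        # titan keys -> HF keys (row 1, ~0.1s)
    dcp.load(hf_sd, storage_reader=storage_reader)   # row 3, dominates the total time
    model.load_state_dict(adapter.from_hf(hf_sd))    # HF keys -> titan keys (row 4, ~0.5s)
```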
## End-to-End verification for 671B model Parallelsim: FSDP=32, PP=8, 1F1B, EP=32 Screenshot 2025-10-06 at 8 32 37 PM Screenshot 2025-10-06 at 8 32 54 PM --- torchtitan/models/deepseek_v3/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 51bd7ea922..1b9c5f9500 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -132,7 +132,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=61, + n_layers=4, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( From 93cd4c6a67ff9f94fb528c99754c4decd4b8c247 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 8 Oct 2025 10:45:40 -0700 Subject: [PATCH 05/46] Disable FlexAttention max-autotune when deterministic is used (#1808) With max-autotune, FlexAttention is not deterministic even if torch.use_deterministic_algorithms is True. When deterministic mode is set, we should also remove the usage of `max-autotune`. --- torchtitan/distributed/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 8787519d86..0796c36081 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -116,6 +116,14 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) seed = debug_config.seed + # Ensure flex_attention is compiled without max-autotune. This is needed to ensure + # reproducibility, since the autotune results may not be deterministic. + from torch.nn.attention.flex_attention import flex_attention + + from torchtitan.models.attention import FlexAttention + + FlexAttention.flex_attn = torch.compile(flex_attention) + if not world_mesh: if seed is not None: torch.manual_seed(seed) From 44acac84f1bc998a25691cfb56c11fb99ab24bb2 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Thu, 9 Oct 2025 10:37:43 -0700 Subject: [PATCH 06/46] Fix num of layers for deepseek-v3 (#1845) Fix the number of layer issue introduced by #1804 --- torchtitan/models/deepseek_v3/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 1b9c5f9500..51bd7ea922 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -132,7 +132,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=4, + n_layers=61, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( From 702966147b3b6349ee9a9a0d0625ce49001a03d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ph=C3=BAc=20H=2E=20L=C3=AA=20Kh=E1=BA=AFc?= Date: Fri, 10 Oct 2025 04:39:28 +0700 Subject: [PATCH 07/46] [VLM] Add token-imbalance loss (#1803) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In VLM interleaved training, with native resolution and aspect ratio, the number of tokens participating in loss computation differ per rank. Naive FSDP gradient averaging across data ranks can causes tokens on ranks with fewer valid tokens to contribute more to the loss than on other ranks. This PR address this via loss balancing, which incur an additional comm in the loss computation. In practice, I haven't notice any impacts from this comm. 
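A minimal sketch of such a balanced loss (assuming ignored tokens are labeled -100 and that FSDP later averages gradients across data ranks; illustrative, not necessarily the PR's exact code) is shown below. The sanity check that follows explains why dividing by the *average* per-rank token count is the correct scaling.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def token_imbalance_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Sum of per-token losses on this rank.
    loss_sum = F.cross_entropy(
        logits.flatten(0, 1), labels.flatten(0, 1), ignore_index=-100, reduction="sum"
    )
    # Average number of valid tokens per rank (N_total / R); this all_reduce is the extra comm.
    num_tokens = (labels != -100).sum().to(loss_sum.dtype)
    dist.all_reduce(num_tokens, op=dist.ReduceOp.AVG)
    return loss_sum / num_tokens
```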
#### Quick sanity check Let have a sum loss of all tokens on each rank i, with $N_i$ number of tokens $L_i = \sum_{j=1}^{N_i}\ell_{ij}$ and its gradient $g_i = \sum_{j=1}^{N_i}\nabla\ell_{ij}$ If we multiply the *loss* on each rank by a constant factor **c** (the same for all ranks), then after `backward()`: $$ \tilde g_i = c \cdot g_i . $$ FSDP will *average* these gradients across ranks: $$ g_{\text{FSDP}}=\frac{1}{R}\sum_{i=1}^{R} \tilde g_i =\frac{c}{R}\sum_{i=1}^{R} g_i . $$ We want this to equal the **global‑sample average**: $$ g_{\text{true}} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R}\sum_{j=1}^{N_i}\nabla \ell_{ij} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R} g_i . $$ Thus for FSDP gradient to be correct, we need $$ \frac{c}{R}= \frac{1}{N_{\text{total}}}\quad\Longrightarrow\quad c=\frac{R}{N_{\text{total}}}. $$ So the *right* scaling factor is $R/N_{\text{total}}$, which mean divide the per-rank sum loss with $N_{\text{total}}/R$, which is **average number of tokens per rank**. Intuitively, this is the same as default cross-entropy loss, but instead of diving sum loss on a rank by the number of tokens **on that rank**, we now divide by the **average number of tokens across all rank** P/s: sorry this PR is based on #1802 but I couldn't choose that as the base branch. Maybe it will be easier to review once that PR is merged. --- torchtitan/experiments/vlm/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py index 49e3d89e72..910a4f1bb8 100644 --- a/torchtitan/experiments/vlm/__init__.py +++ b/torchtitan/experiments/vlm/__init__.py @@ -7,7 +7,6 @@ from dataclasses import fields from typing import Any -from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers from torchtitan.components.tokenizer import build_hf_tokenizer From 594fe9c9a27d70fbabb531b6031b97b6b279af06 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Thu, 9 Oct 2025 18:22:41 -0700 Subject: [PATCH 08/46] refactor TrainSpec to remove the name field (#1850) --- torchtitan/experiments/vlm/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py index 910a4f1bb8..49e3d89e72 100644 --- a/torchtitan/experiments/vlm/__init__.py +++ b/torchtitan/experiments/vlm/__init__.py @@ -7,6 +7,7 @@ from dataclasses import fields from typing import Any +from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers from torchtitan.components.tokenizer import build_hf_tokenizer From ee49f75075beff4d1c59029857acc38c9050cb69 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 9 Oct 2025 23:03:57 -0700 Subject: [PATCH 09/46] Refactor attention and make attention mask an argument to the model (#1776) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #1797 * __->__ #1776 **Status** 1. Change all models, including the experimental ones. 2. E2E loss verification. 3. We should add an unittest for attention. But since we don't have GPU unittest, this can be done in a separate PR. **Summary** This PR aims to refactor how TorchTitan build the attention masks and pass to model. 
Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one particular one is https://github.com/pytorch/torchtitan/issues/1723. https://github.com/pytorch/pytorch/pull/164111/ proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward(). The new design: 1. Model needs to provide `get_attention_masks()` that accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask. Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks. 2. `get_attention_masks()` will be called from the trainer and the resulting masks are passed to the model.forward(). Justification: this will allow us to fix https://github.com/pytorch/torchtitan/issues/1723 with https://github.com/pytorch/pytorch/pull/164111/ and this PR. 3. Now SDPA and FlexAttention are wrapped in two different classes. ~~Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certainly can be confusion for Titan's users. I'm opnn to merge them to AttentionOp.~~ See the discussion in https://github.com/pytorch/torchtitan/issues/1723. **Verification** *llama3* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" ``` *llama3 flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` *llama4* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *llama4 irope* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *deepseek* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ``` *deepseek flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` --- torchtitan/distributed/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 0796c36081..977f189528 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -120,9 +120,9 @@ def set_determinism( # reproducibility, since the autotune results may not be deterministic. 
from torch.nn.attention.flex_attention import flex_attention - from torchtitan.models.attention import FlexAttention + from torchtitan.models.attention import FlexAttentionWrapper - FlexAttention.flex_attn = torch.compile(flex_attention) + FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) if not world_mesh: if seed is not None: From 5ba348881958465aa15fa7fbb304c3fec775ee77 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Sat, 11 Oct 2025 22:49:54 -0700 Subject: [PATCH 10/46] minor refactor over EP (#1854) This PR: - let `ExpertParallel` handles indices permute / unpermute when EP is used - move `to_local` to model code to be more explicit - rename the `expert_parallel` wrapper which does permute / unpermute to `indices_permutation_wrapper` to be more accurate --- torchtitan/distributed/expert_parallel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index e9986b9974..5ed36a6317 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -227,7 +227,6 @@ def __init__(self): def _prepare_inputput_fn(self, mod, inputs, device_mesh): # shape (batch_size*seq_len, top_k) top_scores, selected_experts_indices = inputs - num_tokens, _ = top_scores.shape # NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree # if top_scores.shape[0] % device_mesh.size() != 0: From e11ea4be7b6ea0d69ccc3310231b33976f9a1276 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Sun, 12 Oct 2025 22:30:46 -0700 Subject: [PATCH 11/46] Graduate qwen3 from experiment to core (#1860) As titled. Added CI for test, fix minor TP issue after adding attention_mask --- tests/integration_tests/models.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index 37f588765b..be076d4684 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -110,6 +110,34 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "llama4_pp+fsdp+tp+ep+compile", ngpu=8, ), + # Integration Test Cases for Qwen3 dense and MoE model + OverrideDefinitions( + [ + [ + "--model.name qwen3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + ], + ], + "Qwen3 FSDP+TP", + "qwen3_fsdp+tp", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name qwen3", + "--model.flavor debugmodel_moe", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 2", + "--parallelism.expert_tensor_parallel_degree 2", + ], + ], + "Qwen3 FSDP+TP+EP+ETP", + "qwen3_fsdp+tp+ep+etp", + ngpu=4, + ), ] return model_tests From a92059badef4ea5d64395e943f7eb096e923acea Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Thu, 9 Oct 2025 19:12:22 -0700 Subject: [PATCH 12/46] Review related updates. 
--- torchtitan/distributed/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 977f189528..e716de2d5a 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -124,6 +124,7 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) + seed = debug_config.seed if not world_mesh: if seed is not None: torch.manual_seed(seed) From e53255aa85873d581bab9e1d81c91ee0a992a09e Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Mon, 13 Oct 2025 12:50:17 -0700 Subject: [PATCH 13/46] Rebasing and adding MATH attention kernel. --- torchtitan/models/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index bf963a5b5f..2309c844a4 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -90,6 +90,7 @@ def __init__(self) -> None: SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH ] def forward( From f30caf64cbb8381f0b7e543b02189da819c4dc78 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Mon, 13 Oct 2025 13:40:35 -0700 Subject: [PATCH 14/46] Indent issue fix. --- torchtitan/distributed/utils.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index e716de2d5a..8787519d86 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -115,15 +115,6 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) - seed = debug_config.seed - # Ensure flex_attention is compiled without max-autotune. This is needed to ensure - # reproducibility, since the autotune results may not be deterministic. - from torch.nn.attention.flex_attention import flex_attention - - from torchtitan.models.attention import FlexAttentionWrapper - - FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) - seed = debug_config.seed if not world_mesh: if seed is not None: From f4cbf9dc00e59c54b40254df2e073c1adf24d62f Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Mon, 13 Oct 2025 15:13:30 -0700 Subject: [PATCH 15/46] Removing ipex. --- torchtitan/train.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index a5cff0a7ee..8a4bcf759e 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,12 +12,6 @@ import torch -try: - import intel_extension_for_pytorch as ipex - print ( f"IPEX found - hence using IPEX") -except: - print ( f"IPEX not found, hence not using") - from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module From 00c316548c273b1ee4c41377a7525164babfee00 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Mon, 13 Oct 2025 19:47:43 -0700 Subject: [PATCH 16/46] Review updates. 
--- torchtitan/config/job_config.py | 6 +++--- torchtitan/experiments/flux/train.py | 3 +-- torchtitan/experiments/forge/engine.py | 3 +-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 4baccc0866..7fb3f013c3 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -873,15 +873,15 @@ def __post_init__(self): @dataclass class Debug: + seed: int | None = None + """Choose the base RNG seed used for training""" + deterministic: bool = False """Use deterministic algorithms wherever possible, may be slower""" deterministic_warn_only: bool = False """Only warns about ops without deterministic implementations rather than erroring out """ - seed: int | None = None - """Choose the base RNG seed used for training""" - ac_preserve_rng_state: bool = False """If deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details.""" diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index 59a40c147b..33018f95fc 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -35,8 +35,7 @@ def __init__(self, job_config: JobConfig): dist_utils.set_determinism( self.parallel_dims.world_mesh, self.device, - job_config.debug.seed, - job_config.debug.deterministic, + job_config.debug, distinct_seed_mesh_dim="dp_shard", ) diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index add19073f7..1d3a420b9d 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -104,8 +104,7 @@ def __init__(self, job_config: ForgeJobConfig): dist_utils.set_determinism( world_mesh, self.device, - job_config.debug.seed, - job_config.debug.deterministic, + job_config.debug, ) self.train_spec = get_train_spec(job_config.model.name) From db187ff9fc4165772517747319bdff4410f97d41 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Tue, 14 Oct 2025 13:01:13 -0700 Subject: [PATCH 17/46] Fixing linter error. --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 8a4bcf759e..c0951063ec 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -119,7 +119,7 @@ def __init__(self, job_config: JobConfig): dist_utils.set_determinism( world_mesh, self.device, - job_config.debug, + job_config.debug, ) self.train_spec = train_spec_module.get_train_spec(job_config.model.name) From 315ea57c6a72f5ec3431318665dbad4d24c5d595 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 13 Oct 2025 20:55:21 -0700 Subject: [PATCH 18/46] graduate llama4 to core (#1865) This PR also - updates the README of `deepseek_v3` folder. 
- move the `generate_permute_indices` triton kernel to `torchtitan/models/moe` so that core doesn't depend on `experiments` - deprecate the gpu checkpoint conversion scripts as now we natively support loading checkpoint from HF using GPUs (although it is only using GPU when doing online conversion right before training starts) --- tests/integration_tests/models.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index be076d4684..0b69ef806a 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -138,6 +138,24 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "qwen3_fsdp+tp+ep+etp", ngpu=4, ), + # Integration Test Cases for Llama 4 + OverrideDefinitions( + [ + [ + "--model.name llama4", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule Interleaved1F1B", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 4", + "--parallelism.expert_tensor_parallel_degree 1", + "--compile.enable", + ], + ], + "Llama 4 PP+FSDP+TP+EP+compile", + "llama4_pp+fsdp+tp+ep+compile", + ngpu=8, + ), ] return model_tests From 64b77de1e0549d1a682bd765d43f38db5ba0b809 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 13 Oct 2025 22:48:15 -0700 Subject: [PATCH 19/46] consolidate experiments/deepseek_v3 (#1869) As titled, given the `experiments/deepseek_v3` has been out of maintenance for long time. The folder could still be valuable, so I'm keeping the content in the branch `experiments/deepseek_v3` as reference https://github.com/pytorch/torchtitan/tree/experiments/deepseek_v3/torchtitan/experiments/deepseek_v3 This PR keeps the symmetric memory kernels for EP communication, whose integration will come later. --- .../experiments/simple_fsdp/job_config.py | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 torchtitan/experiments/simple_fsdp/job_config.py diff --git a/torchtitan/experiments/simple_fsdp/job_config.py b/torchtitan/experiments/simple_fsdp/job_config.py deleted file mode 100644 index a7e7c4c22f..0000000000 --- a/torchtitan/experiments/simple_fsdp/job_config.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass, field - - -@dataclass -class Compile: - model_backend_override: str | None = None - """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" - - -@dataclass -class JobConfig: - compile: Compile = field(default_factory=Compile) From fa5084083302b23e99c251e82e76bad06f6ea8fe Mon Sep 17 00:00:00 2001 From: Ruisi Zhang Date: Tue, 14 Oct 2025 14:16:34 -0700 Subject: [PATCH 20/46] add auto_eager_graph_pass (#1813) This pr adds the autobucketing pass at aten-level to simplefsdp. It runs autobucketing + aot_eager backend without inductor. The aten fx autobucketing pass can be find in this PR: https://github.com/pytorch/pytorch/pull/163960. Key updates are: 1. Support customized `aot_eger_autobucketing` backend to perform autobucketing optimization. 2. In simplefsdp, the model_backend can be replaced by user's customized passes using `compile.model_backend_override`. 
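For reference, selecting the new backend from a run config might look like this (a sketch assuming the extended SimpleFSDP `Compile` options added in the diff below are registered):

```toml
[compile]
enable = true
model_backend_override = "aot_eager_autobucketing"
```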
--- .../experiments/simple_fsdp/job_config.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 torchtitan/experiments/simple_fsdp/job_config.py diff --git a/torchtitan/experiments/simple_fsdp/job_config.py b/torchtitan/experiments/simple_fsdp/job_config.py new file mode 100644 index 0000000000..a7e7c4c22f --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/job_config.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class Compile: + model_backend_override: str | None = None + """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" + + +@dataclass +class JobConfig: + compile: Compile = field(default_factory=Compile) From 00dbd5a1537bdb6d0ac5fbad4b1d876e145008a9 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Tue, 7 Oct 2025 13:48:52 -0700 Subject: [PATCH 21/46] Second version of degub/deterinistic configs. Incorporating input from converastion in https://github.com/pytorch/torchtitan/pull/1761 --- torchtitan/models/llama4/train_configs/llama4_17bx16e.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml index 78e210c10a..13198bf7c0 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml @@ -68,3 +68,10 @@ filter_fqns = ["output", "router.gate"] [quantize.linear.mx] filter_fqns = ["output", "router.gate"] + +[debug] +torch_deterministic = false +torch_deterministic_warn_only = false +torch_preserve_rng_state = false +ac_determinism_check = "default" +ac_debug = false From 8bdb11dda81efd9992fcbd8209f8ad608d382a0d Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Thu, 9 Oct 2025 19:12:22 -0700 Subject: [PATCH 22/46] Review relaeted updates. --- .../models/llama4/train_configs/llama4_17bx16e.toml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml index 13198bf7c0..c374616f91 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml @@ -70,8 +70,10 @@ filter_fqns = ["output", "router.gate"] filter_fqns = ["output", "router.gate"] [debug] -torch_deterministic = false -torch_deterministic_warn_only = false -torch_preserve_rng_state = false +#seed = +deterministic = false +deterministic_warn_only = false +ac_preserve_rng_state = false ac_determinism_check = "default" ac_debug = false +moe_force_load_balance = false From a6a1bab04686b310c10495b9c4ee5807599aab5a Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Tue, 14 Oct 2025 17:55:15 -0700 Subject: [PATCH 23/46] Review 2 related changes. 
--- .../models/llama4/train_configs/llama4_17bx16e.toml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml index c374616f91..78e210c10a 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml @@ -68,12 +68,3 @@ filter_fqns = ["output", "router.gate"] [quantize.linear.mx] filter_fqns = ["output", "router.gate"] - -[debug] -#seed = -deterministic = false -deterministic_warn_only = false -ac_preserve_rng_state = false -ac_determinism_check = "default" -ac_debug = false -moe_force_load_balance = false From 2e8585c6dbd8c708d1e4836703f87c0cceeecb54 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Wed, 8 Oct 2025 10:19:37 -0700 Subject: [PATCH 24/46] [DSV3] Offload dequantization process to DCP QuantizedHFReader (#1804) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Benchmarking
Step | time | log
-- | -- | --
to_hf() | 0.1103s | [trainer0\|0]:[titan] 2025-10-03 17:07:45,697 - root - INFO - Completed to_hf conversion, generated 189 keys, duration: 0.1103s
Split local GroupedExperts DTensor to individual experts’ weight | 0.008s per layer per matrix (total 58 MoE Layers * 3 weight matrices per layer) | [trainer0\|0]:[titan] 2025-10-03 17:07:45,697 - root - INFO - Completed _get_local_experts_weights for layer 6, abstract_key: model.layers.{}.mlp.experts.{}.up_proj.weight, duration: 0.0082s
dcp.load() (threads count=4) | 193.20s | [trainer0\|0]:[titan] 2025-10-03 17:10:58,899 - root - INFO - dcp.load with HuggingFaceStorageReader completed in 193.20 seconds
from_hf() | 0.48s | [trainer0\|0]:[titan] 2025-10-03 17:10:59,378 - root - INFO - Completed from_hf conversion, processed 189 keys, duration: 0.4787s
Concatenate individual experts weight into GroupedExperts weight | 0.01s per layer per matrix (total 58 MoE Layers * 3 weight matrices) | [trainer0\|0]:[titan] 2025-10-03 17:10:59,120 - root - INFO - Completed _concatenate_expert_weights_dtensor for layer 5, abstract_key: layers.{}.moe.experts.w2, duration: 0.0142s
Total | 193.87s | [trainer0\|0]:[titan] 2025-10-03 17:10:59,458 - root - INFO - Finished loading the checkpoint in 193.87 seconds.
## End-to-End verification for 671B model Parallelsim: FSDP=32, PP=8, 1F1B, EP=32 Screenshot 2025-10-06 at 8 32 37 PM Screenshot 2025-10-06 at 8 32 54 PM --- torchtitan/models/deepseek_v3/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 51bd7ea922..1b9c5f9500 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -132,7 +132,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=61, + n_layers=4, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( From 2795956692d6321f83ca56891dc264f5943e2d20 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 8 Oct 2025 10:45:40 -0700 Subject: [PATCH 25/46] Disable FlexAttention max-autotune when deterministic is used (#1808) With max-autotune, FlexAttention is not deterministic even if torch.use_deterministic_algorithms is True. When deterministic mode is set, we should also remove the usage of `max-autotune`. --- torchtitan/distributed/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 8787519d86..0796c36081 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -116,6 +116,14 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) seed = debug_config.seed + # Ensure flex_attention is compiled without max-autotune. This is needed to ensure + # reproducibility, since the autotune results may not be deterministic. + from torch.nn.attention.flex_attention import flex_attention + + from torchtitan.models.attention import FlexAttention + + FlexAttention.flex_attn = torch.compile(flex_attention) + if not world_mesh: if seed is not None: torch.manual_seed(seed) From 1432c099b8a70190c327353fc4de2d6d64110e67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ph=C3=BAc=20H=2E=20L=C3=AA=20Kh=E1=BA=AFc?= Date: Fri, 10 Oct 2025 04:39:28 +0700 Subject: [PATCH 26/46] [VLM] Add token-imbalance loss (#1803) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In VLM interleaved training, with native resolution and aspect ratio, the number of tokens participating in loss computation differ per rank. Naive FSDP gradient averaging across data ranks can causes tokens on ranks with fewer valid tokens to contribute more to the loss than on other ranks. This PR address this via loss balancing, which incur an additional comm in the loss computation. In practice, I haven't notice any impacts from this comm. #### Quick sanity check Let have a sum loss of all tokens on each rank i, with $N_i$ number of tokens $L_i = \sum_{j=1}^{N_i}\ell_{ij}$ and its gradient $g_i = \sum_{j=1}^{N_i}\nabla\ell_{ij}$ If we multiply the *loss* on each rank by a constant factor **c** (the same for all ranks), then after `backward()`: $$ \tilde g_i = c \cdot g_i . $$ FSDP will *average* these gradients across ranks: $$ g_{\text{FSDP}}=\frac{1}{R}\sum_{i=1}^{R} \tilde g_i =\frac{c}{R}\sum_{i=1}^{R} g_i . $$ We want this to equal the **global‑sample average**: $$ g_{\text{true}} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R}\sum_{j=1}^{N_i}\nabla \ell_{ij} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R} g_i . $$ Thus for FSDP gradient to be correct, we need $$ \frac{c}{R}= \frac{1}{N_{\text{total}}}\quad\Longrightarrow\quad c=\frac{R}{N_{\text{total}}}. 
$$ So the *right* scaling factor is $R/N_{\text{total}}$, which mean divide the per-rank sum loss with $N_{\text{total}}/R$, which is **average number of tokens per rank**. Intuitively, this is the same as default cross-entropy loss, but instead of diving sum loss on a rank by the number of tokens **on that rank**, we now divide by the **average number of tokens across all rank** P/s: sorry this PR is based on #1802 but I couldn't choose that as the base branch. Maybe it will be easier to review once that PR is merged. --- torchtitan/experiments/vlm/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py index 49e3d89e72..910a4f1bb8 100644 --- a/torchtitan/experiments/vlm/__init__.py +++ b/torchtitan/experiments/vlm/__init__.py @@ -7,7 +7,6 @@ from dataclasses import fields from typing import Any -from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers from torchtitan.components.tokenizer import build_hf_tokenizer From 1b88f57de4f71a734a43e0e0e081adb4aad835a3 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Thu, 9 Oct 2025 18:22:41 -0700 Subject: [PATCH 27/46] refactor TrainSpec to remove the name field (#1850) --- torchtitan/experiments/vlm/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/experiments/vlm/__init__.py b/torchtitan/experiments/vlm/__init__.py index 910a4f1bb8..49e3d89e72 100644 --- a/torchtitan/experiments/vlm/__init__.py +++ b/torchtitan/experiments/vlm/__init__.py @@ -7,6 +7,7 @@ from dataclasses import fields from typing import Any +from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers from torchtitan.components.tokenizer import build_hf_tokenizer From 139926b96af686a26b9af017afc1a919a7b8b9b8 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 9 Oct 2025 23:03:57 -0700 Subject: [PATCH 28/46] Refactor attention and make attention mask an argument to the model (#1776) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #1797 * __->__ #1776 **Status** 1. Change all models, including the experimental ones. 2. E2E loss verification. 3. We should add an unittest for attention. But since we don't have GPU unittest, this can be done in a separate PR. **Summary** This PR aims to refactor how TorchTitan build the attention masks and pass to model. Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one particular one is https://github.com/pytorch/torchtitan/issues/1723. https://github.com/pytorch/pytorch/pull/164111/ proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward(). The new design: 1. Model needs to provide `get_attention_masks()` that accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask. 
Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks. 2. `get_attention_masks()` will be called from the trainer and the resulting masks are passed to the model.forward(). Justification: this will allow us to fix https://github.com/pytorch/torchtitan/issues/1723 with https://github.com/pytorch/pytorch/pull/164111/ and this PR. 3. Now SDPA and FlexAttention are wrapped in two different classes. ~~Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certainly can be confusion for Titan's users. I'm opnn to merge them to AttentionOp.~~ See the discussion in https://github.com/pytorch/torchtitan/issues/1723. **Verification** *llama3* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" ``` *llama3 flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` *llama4* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *llama4 irope* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *deepseek* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ``` *deepseek flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` --- torchtitan/distributed/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 0796c36081..977f189528 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -120,9 +120,9 @@ def set_determinism( # reproducibility, since the autotune results may not be deterministic. from torch.nn.attention.flex_attention import flex_attention - from torchtitan.models.attention import FlexAttention + from torchtitan.models.attention import FlexAttentionWrapper - FlexAttention.flex_attn = torch.compile(flex_attention) + FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) if not world_mesh: if seed is not None: From 087dc88f0113f2c7d39da0c671d5f7183c80c982 Mon Sep 17 00:00:00 2001 From: Tushar Jain <8455015+tushar00jain@users.noreply.github.com> Date: Fri, 10 Oct 2025 12:59:57 -0400 Subject: [PATCH 29/46] add script to train with ft (#1812) Summary: the script adds configuration options to run training locally with ft enabled --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1812). 
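Tying the determinism pieces together (#1808 drops max-autotune for FlexAttention, and the `--debug.deterministic_warn_only` flag documented further down relaxes `torch.use_deterministic_algorithms` to warnings), a rough sketch of the knobs involved is shown below. It is illustrative only, not torchtitan's `set_determinism`, and the helper name is an assumption.

```python
# Illustrative only; torchtitan wires these through JobConfig rather than a helper like this.
import os

import torch
from torch.nn.attention.flex_attention import flex_attention


def enable_determinism(warn_only: bool = False):
    # Error out on ops without deterministic kernels, or merely warn if warn_only=True.
    torch.use_deterministic_algorithms(True, warn_only=warn_only)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Deterministic cuBLAS workspace configuration.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    # FlexAttention: compile without max-autotune, since autotuning results
    # are not guaranteed to be reproducible across runs.
    return torch.compile(flex_attention)
```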
* #1840 * #1811 * #1810 * __->__ #1812 * #1809 --------- Co-authored-by: Tushar Jain --- torchtitan/models/llama3_ft/train_configs/debug_model.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchtitan/models/llama3_ft/train_configs/debug_model.toml b/torchtitan/models/llama3_ft/train_configs/debug_model.toml index 883d08bced..a9a6b87aa2 100644 --- a/torchtitan/models/llama3_ft/train_configs/debug_model.toml +++ b/torchtitan/models/llama3_ft/train_configs/debug_model.toml @@ -91,3 +91,9 @@ num_fragments = 2 semi_sync_method = "diloco" process_group = "nccl" process_group_timeout_ms = 10000 +<<<<<<< HEAD +======= + +[experimental] +custom_args_module = "torchtitan.components.ft.config" +>>>>>>> f63037ff (add script to train with ft (#1812)) From 5032db6a884e9c13fe2d61a0118cfb8fa164e4b9 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Mon, 13 Oct 2025 13:40:35 -0700 Subject: [PATCH 30/46] Indent issue fix. --- torchtitan/distributed/utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 977f189528..8787519d86 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -116,14 +116,6 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) seed = debug_config.seed - # Ensure flex_attention is compiled without max-autotune. This is needed to ensure - # reproducibility, since the autotune results may not be deterministic. - from torch.nn.attention.flex_attention import flex_attention - - from torchtitan.models.attention import FlexAttentionWrapper - - FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) - if not world_mesh: if seed is not None: torch.manual_seed(seed) From 409da11295489da407acd58c1bbafec98caea45e Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Tue, 14 Oct 2025 17:45:17 -0700 Subject: [PATCH 31/46] Post rebase changes. --- torchtitan/models/llama3_ft/train_configs/debug_model.toml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchtitan/models/llama3_ft/train_configs/debug_model.toml b/torchtitan/models/llama3_ft/train_configs/debug_model.toml index a9a6b87aa2..883d08bced 100644 --- a/torchtitan/models/llama3_ft/train_configs/debug_model.toml +++ b/torchtitan/models/llama3_ft/train_configs/debug_model.toml @@ -91,9 +91,3 @@ num_fragments = 2 semi_sync_method = "diloco" process_group = "nccl" process_group_timeout_ms = 10000 -<<<<<<< HEAD -======= - -[experimental] -custom_args_module = "torchtitan.components.ft.config" ->>>>>>> f63037ff (add script to train with ft (#1812)) From 2c35b957fcac0f26f042886aa95962f470f91808 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Tue, 7 Oct 2025 13:48:52 -0700 Subject: [PATCH 32/46] Second version of degub/deterinistic configs. 
Incorporating input from converastion in https://github.com/pytorch/torchtitan/pull/1761 --- torchtitan/models/llama4/train_configs/llama4_17bx16e.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml index 78e210c10a..13198bf7c0 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml @@ -68,3 +68,10 @@ filter_fqns = ["output", "router.gate"] [quantize.linear.mx] filter_fqns = ["output", "router.gate"] + +[debug] +torch_deterministic = false +torch_deterministic_warn_only = false +torch_preserve_rng_state = false +ac_determinism_check = "default" +ac_debug = false From 27a942fb54c758c4e8b64a3fb701cdf67e68d492 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Thu, 9 Oct 2025 19:12:22 -0700 Subject: [PATCH 33/46] Review relaeted updates. --- docs/debugging.md | 4 ++++ .../models/llama4/train_configs/llama4_17bx16e.toml | 8 +++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index fd436367e9..03879703ea 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -93,7 +93,11 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr - Sets deterministic workspace configuration for CuBLAS operations - **Note:** This will significantly reduce training performance but ensures exact reproducibility +<<<<<<< HEAD Use `--debug.deterministic_warn_only` to only warn about (not stop running) kernel without deterministic implementation. +======= +Use --debug.deterministic.warn_only to only warn about (not stop running) kernel without deterministic implementation. +>>>>>>> 0b9a2b71 (Review relaeted updates.) ### Activation Checkipointing Debugging ### diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml index 13198bf7c0..c374616f91 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml @@ -70,8 +70,10 @@ filter_fqns = ["output", "router.gate"] filter_fqns = ["output", "router.gate"] [debug] -torch_deterministic = false -torch_deterministic_warn_only = false -torch_preserve_rng_state = false +#seed = +deterministic = false +deterministic_warn_only = false +ac_preserve_rng_state = false ac_determinism_check = "default" ac_debug = false +moe_force_load_balance = false From 93e6d5e69e2e3dee704194d9369014d18f685af9 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Fri, 10 Oct 2025 14:54:11 -0700 Subject: [PATCH 34/46] Review 2 related changes. --- docs/debugging.md | 4 ---- .../models/llama4/train_configs/llama4_17bx16e.toml | 9 --------- 2 files changed, 13 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index 03879703ea..fd436367e9 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -93,11 +93,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr - Sets deterministic workspace configuration for CuBLAS operations - **Note:** This will significantly reduce training performance but ensures exact reproducibility -<<<<<<< HEAD Use `--debug.deterministic_warn_only` to only warn about (not stop running) kernel without deterministic implementation. -======= -Use --debug.deterministic.warn_only to only warn about (not stop running) kernel without deterministic implementation. 
->>>>>>> 0b9a2b71 (Review relaeted updates.) ### Activation Checkipointing Debugging ### diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml index c374616f91..78e210c10a 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml @@ -68,12 +68,3 @@ filter_fqns = ["output", "router.gate"] [quantize.linear.mx] filter_fqns = ["output", "router.gate"] - -[debug] -#seed = -deterministic = false -deterministic_warn_only = false -ac_preserve_rng_state = false -ac_determinism_check = "default" -ac_debug = false -moe_force_load_balance = false From ff832d2f0d039c287026d753d86d47dd059442b4 Mon Sep 17 00:00:00 2001 From: Tushar Jain <8455015+tushar00jain@users.noreply.github.com> Date: Fri, 10 Oct 2025 12:59:57 -0400 Subject: [PATCH 35/46] add script to train with ft (#1812) Summary: the script adds configuration options to run training locally with ft enabled --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1812). * #1840 * #1811 * #1810 * __->__ #1812 * #1809 --------- Co-authored-by: Tushar Jain --- torchtitan/models/llama3_ft/train_configs/debug_model.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchtitan/models/llama3_ft/train_configs/debug_model.toml b/torchtitan/models/llama3_ft/train_configs/debug_model.toml index 883d08bced..b8f2d7989d 100644 --- a/torchtitan/models/llama3_ft/train_configs/debug_model.toml +++ b/torchtitan/models/llama3_ft/train_configs/debug_model.toml @@ -1,8 +1,13 @@ [job] dump_folder = "./outputs" +<<<<<<< HEAD description = "Llama 3 fault-tolerant debug training" print_config = false custom_config_module = "torchtitan.components.ft.config" +======= +description = "Llama 3 debug training" +print_args = false +>>>>>>> f63037ff (add script to train with ft (#1812)) [profiling] enable_profiling = true From 6bb6254a58dda6cb1bda4164ad8b79f9cc2ec7fe Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Sat, 11 Oct 2025 22:49:54 -0700 Subject: [PATCH 36/46] minor refactor over EP (#1854) This PR: - let `ExpertParallel` handles indices permute / unpermute when EP is used - move `to_local` to model code to be more explicit - rename the `expert_parallel` wrapper which does permute / unpermute to `indices_permutation_wrapper` to be more accurate --- torchtitan/distributed/expert_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 5ed36a6317..64017f97d0 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -223,6 +223,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: class ReordererSequenceParallel(ParallelStyle): def __init__(self): super().__init__() + self.top_k = None def _prepare_inputput_fn(self, mod, inputs, device_mesh): # shape (batch_size*seq_len, top_k) From f5dbc0f6801f92f2cbaf9efa008849632a851f2d Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Sun, 12 Oct 2025 13:32:27 -0700 Subject: [PATCH 37/46] [vlm] Add light-weight CI for experimental models (#1848) Add one light-weight CI for VLM --- .github/workflows/integration_test_8gpu_vlm.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/.github/workflows/integration_test_8gpu_vlm.yaml b/.github/workflows/integration_test_8gpu_vlm.yaml index 5e378d6597..089897043d 100644 --- a/.github/workflows/integration_test_8gpu_vlm.yaml +++ b/.github/workflows/integration_test_8gpu_vlm.yaml @@ -1,4 +1,8 @@ +<<<<<<< HEAD name: VLM 8 GPU Integration Tests +======= +name: 8 GPU Vision Language Model Tests +>>>>>>> f7f225eb ([vlm] Add light-weight CI for experimental models (#1848)) on: push: From 1eb5f8eaa2f98832dd5f193f319ee8bdcdeacb2b Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Sun, 12 Oct 2025 21:41:36 -0700 Subject: [PATCH 38/46] add owners and CI status for experiments (#1859) Next is step is to move `qwen3` and `llama4` to core, and remove outdated experiments. --- .github/workflows/integration_test_8gpu_vlm.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/integration_test_8gpu_vlm.yaml b/.github/workflows/integration_test_8gpu_vlm.yaml index 089897043d..5e378d6597 100644 --- a/.github/workflows/integration_test_8gpu_vlm.yaml +++ b/.github/workflows/integration_test_8gpu_vlm.yaml @@ -1,8 +1,4 @@ -<<<<<<< HEAD name: VLM 8 GPU Integration Tests -======= -name: 8 GPU Vision Language Model Tests ->>>>>>> f7f225eb ([vlm] Add light-weight CI for experimental models (#1848)) on: push: From aba26b4e837b321530bbaa3d3622d042bdb9c2c4 Mon Sep 17 00:00:00 2001 From: yifanmao Date: Mon, 13 Oct 2025 20:03:33 -0700 Subject: [PATCH 39/46] TorchTitan e2e test on torchcomms device mesh (#1847) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Composability testing with TorchComms and distributed training in TorchTitan. - Training with `torchcomms.new_comm` - Device mesh initialization with `torchcomms.init_device_mesh` - Integration and testing with `fully_shard` Differential Revision: D82171763 Test plan: TEST_BACKEND=nccl TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms Loss curve: running 1000 steps on llama3_8b.toml Screenshot 2025-10-13 at 4 14 46 PM --- .../experiments/torchcomms/parallel_dims.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/torchtitan/experiments/torchcomms/parallel_dims.py b/torchtitan/experiments/torchcomms/parallel_dims.py index d4cbdaa0ee..7ec13456b0 100644 --- a/torchtitan/experiments/torchcomms/parallel_dims.py +++ b/torchtitan/experiments/torchcomms/parallel_dims.py @@ -20,6 +20,7 @@ __all__ = ["TorchCommsParallelDims"] +<<<<<<< HEAD def _calculate_ranks_per_dimension( meshes: List[torch.Tensor], dim_names: List[str], @@ -53,6 +54,11 @@ def _calculate_ranks_per_dimension( class TorchCommsParallelDims(ParallelDims): def _build_mesh_without_ep(self) -> DeviceMesh: # TODO: support EP +======= +@dataclass +class TorchCommsParallelDims(ParallelDims): + def _build_mesh_without_ep(self) -> DeviceMesh: +>>>>>>> e8c73aed (TorchTitan e2e test on torchcomms device mesh (#1847)) dims = [] names = [] for d, name in zip( @@ -66,6 +72,7 @@ def _build_mesh_without_ep(self) -> DeviceMesh: logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") backend = os.environ["TEST_BACKEND"] device = torch.device("cuda") +<<<<<<< HEAD mesh = torch.arange(self.world_size, dtype=torch.int, device="cpu").view( self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp ) @@ -156,3 +163,27 @@ def _build_mesh_without_ep(self) -> DeviceMesh: self.comms = [*comm_per_dim.values(), comm] return device_mesh +======= + # TODO: + # - Extend support for 
additional parallelism strategies (e.g., pipeline, context) + # - Refactor and modularize initialization logic for communication objects and device mesh construction. + if ( + self.dp_shard > 1 + and self.pp == 1 + and self.dp_replicate == 1 + and self.cp == 1 + and self.tp == 1 + ): + self.comms = [] + comm = torchcomms.new_comm(backend, device, name="main") + # TODO: it's a hacky solution for now and we will update it in a week + mesh = init_device_mesh( + mesh_dim_comms=(comm, comm, comm, comm), + mesh_dim_names=("dp_shard", "dp", "dp_cp", "dp_shard_cp"), + _global_comm=comm, + ) + self.comms.append(comm) + return mesh + else: + raise NotImplementedError("Only support FSDP 1D parallelism for now.") +>>>>>>> e8c73aed (TorchTitan e2e test on torchcomms device mesh (#1847)) From 09db1fe2fb0eaa8e1bb6fc6f488fa9534bb851b7 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 13 Oct 2025 20:55:21 -0700 Subject: [PATCH 40/46] graduate llama4 to core (#1865) This PR also - updates the README of `deepseek_v3` folder. - move the `generate_permute_indices` triton kernel to `torchtitan/models/moe` so that core doesn't depend on `experiments` - deprecate the gpu checkpoint conversion scripts as now we natively support loading checkpoint from HF using GPUs (although it is only using GPU when doing online conversion right before training starts) --- .../experiments/torchcomms/parallel_dims.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/torchtitan/experiments/torchcomms/parallel_dims.py b/torchtitan/experiments/torchcomms/parallel_dims.py index 7ec13456b0..d4cbdaa0ee 100644 --- a/torchtitan/experiments/torchcomms/parallel_dims.py +++ b/torchtitan/experiments/torchcomms/parallel_dims.py @@ -20,7 +20,6 @@ __all__ = ["TorchCommsParallelDims"] -<<<<<<< HEAD def _calculate_ranks_per_dimension( meshes: List[torch.Tensor], dim_names: List[str], @@ -54,11 +53,6 @@ def _calculate_ranks_per_dimension( class TorchCommsParallelDims(ParallelDims): def _build_mesh_without_ep(self) -> DeviceMesh: # TODO: support EP -======= -@dataclass -class TorchCommsParallelDims(ParallelDims): - def _build_mesh_without_ep(self) -> DeviceMesh: ->>>>>>> e8c73aed (TorchTitan e2e test on torchcomms device mesh (#1847)) dims = [] names = [] for d, name in zip( @@ -72,7 +66,6 @@ def _build_mesh_without_ep(self) -> DeviceMesh: logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") backend = os.environ["TEST_BACKEND"] device = torch.device("cuda") -<<<<<<< HEAD mesh = torch.arange(self.world_size, dtype=torch.int, device="cpu").view( self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp ) @@ -163,27 +156,3 @@ def _build_mesh_without_ep(self) -> DeviceMesh: self.comms = [*comm_per_dim.values(), comm] return device_mesh -======= - # TODO: - # - Extend support for additional parallelism strategies (e.g., pipeline, context) - # - Refactor and modularize initialization logic for communication objects and device mesh construction. 
- if ( - self.dp_shard > 1 - and self.pp == 1 - and self.dp_replicate == 1 - and self.cp == 1 - and self.tp == 1 - ): - self.comms = [] - comm = torchcomms.new_comm(backend, device, name="main") - # TODO: it's a hacky solution for now and we will update it in a week - mesh = init_device_mesh( - mesh_dim_comms=(comm, comm, comm, comm), - mesh_dim_names=("dp_shard", "dp", "dp_cp", "dp_shard_cp"), - _global_comm=comm, - ) - self.comms.append(comm) - return mesh - else: - raise NotImplementedError("Only support FSDP 1D parallelism for now.") ->>>>>>> e8c73aed (TorchTitan e2e test on torchcomms device mesh (#1847)) From d0b19874b1f63d66dc3ea3e4ad1628a6f2394c5f Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 13 Oct 2025 22:18:17 -0700 Subject: [PATCH 41/46] move PP API to model agnostic file (#1868) We originally thought each model should have its own `pipeline.py` function. However, for most LLMs, it turns out a single function would suffice, and all models which needs PP are reusing `pipeline_llama.py` originally written for llama3. (For diffusion models, the model size doesn't justify the usage of PP.) This PR consolidates them and moves `pipeline_llm` into `torchtitan/distributed/pipeline_parallel.py`. We can refactor later if things change. --- torchtitan/experiments/simple_fsdp/llama3/__init__.py | 4 ++++ torchtitan/models/llama3_ft/__init__.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/torchtitan/experiments/simple_fsdp/llama3/__init__.py b/torchtitan/experiments/simple_fsdp/llama3/__init__.py index 30b71797f6..0306963eda 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/__init__.py +++ b/torchtitan/experiments/simple_fsdp/llama3/__init__.py @@ -10,7 +10,11 @@ from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.distributed.pipeline_parallel import pipeline_llm +<<<<<<< HEAD from torchtitan.models.llama3 import llama3_args +======= +from torchtitan.models.llama3 import llama3_configs +>>>>>>> cd16507b (move PP API to model agnostic file (#1868)) from torchtitan.protocols.train_spec import TrainSpec from .model import SimpleFSDPTransformer diff --git a/torchtitan/models/llama3_ft/__init__.py b/torchtitan/models/llama3_ft/__init__.py index c914c235e9..e0d76de554 100644 --- a/torchtitan/models/llama3_ft/__init__.py +++ b/torchtitan/models/llama3_ft/__init__.py @@ -16,6 +16,14 @@ from ..llama3 import llama3_args, Llama3StateDictAdapter, parallelize_llama, Transformer +__all__ = [ + "parallelize_llama", + "TransformerModelArgs", + "Transformer", + "llama3_configs", +] +>>>>>>> cd16507b (move PP API to model agnostic file (#1868)) + def get_train_spec() -> TrainSpec: return FaultTolerantTrainSpec( From 66487079945501c659a23aa3c0aabfa714d28740 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Tue, 14 Oct 2025 12:52:16 -0700 Subject: [PATCH 42/46] [refactor] graduate custom_config_module and unify args/config naming (#1871) In the past, the terms "args" and "config" have been used in a mix. 
To make it unambiguous, in torchtitan we use - "args" as in `ModelArgs` to refer to parameters used to define a model in model code - "config" as in `JobConfig` to refer to configurable training job commands used in training script This also PR also moves `custom_args_module` (which should be `custom_config_module` according to the naming rule above) from `Experimental` to `Job`, as it has been extensively used by various models in torchtitan, especially those in the `experiments` folder. --- torchtitan/experiments/simple_fsdp/llama3/__init__.py | 4 ---- torchtitan/models/llama3_ft/__init__.py | 1 - torchtitan/models/llama3_ft/train_configs/debug_model.toml | 5 ----- 3 files changed, 10 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/llama3/__init__.py b/torchtitan/experiments/simple_fsdp/llama3/__init__.py index 0306963eda..30b71797f6 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/__init__.py +++ b/torchtitan/experiments/simple_fsdp/llama3/__init__.py @@ -10,11 +10,7 @@ from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.distributed.pipeline_parallel import pipeline_llm -<<<<<<< HEAD from torchtitan.models.llama3 import llama3_args -======= -from torchtitan.models.llama3 import llama3_configs ->>>>>>> cd16507b (move PP API to model agnostic file (#1868)) from torchtitan.protocols.train_spec import TrainSpec from .model import SimpleFSDPTransformer diff --git a/torchtitan/models/llama3_ft/__init__.py b/torchtitan/models/llama3_ft/__init__.py index e0d76de554..ae1eef8922 100644 --- a/torchtitan/models/llama3_ft/__init__.py +++ b/torchtitan/models/llama3_ft/__init__.py @@ -22,7 +22,6 @@ "Transformer", "llama3_configs", ] ->>>>>>> cd16507b (move PP API to model agnostic file (#1868)) def get_train_spec() -> TrainSpec: diff --git a/torchtitan/models/llama3_ft/train_configs/debug_model.toml b/torchtitan/models/llama3_ft/train_configs/debug_model.toml index b8f2d7989d..883d08bced 100644 --- a/torchtitan/models/llama3_ft/train_configs/debug_model.toml +++ b/torchtitan/models/llama3_ft/train_configs/debug_model.toml @@ -1,13 +1,8 @@ [job] dump_folder = "./outputs" -<<<<<<< HEAD description = "Llama 3 fault-tolerant debug training" print_config = false custom_config_module = "torchtitan.components.ft.config" -======= -description = "Llama 3 debug training" -print_args = false ->>>>>>> f63037ff (add script to train with ft (#1812)) [profiling] enable_profiling = true From e57adb71b82e50f931328f7234c85441215eec6a Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Wed, 15 Oct 2025 15:41:08 -0700 Subject: [PATCH 43/46] Rebase misses. 
--- torchtitan/distributed/expert_parallel.py | 1 - torchtitan/models/deepseek_v3/__init__.py | 2 +- torchtitan/models/llama3_ft/__init__.py | 7 ------- 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 64017f97d0..5ed36a6317 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -223,7 +223,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: class ReordererSequenceParallel(ParallelStyle): def __init__(self): super().__init__() - self.top_k = None def _prepare_inputput_fn(self, mod, inputs, device_mesh): # shape (batch_size*seq_len, top_k) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 1b9c5f9500..51bd7ea922 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -132,7 +132,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=4, + n_layers=61, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( diff --git a/torchtitan/models/llama3_ft/__init__.py b/torchtitan/models/llama3_ft/__init__.py index ae1eef8922..c914c235e9 100644 --- a/torchtitan/models/llama3_ft/__init__.py +++ b/torchtitan/models/llama3_ft/__init__.py @@ -16,13 +16,6 @@ from ..llama3 import llama3_args, Llama3StateDictAdapter, parallelize_llama, Transformer -__all__ = [ - "parallelize_llama", - "TransformerModelArgs", - "Transformer", - "llama3_configs", -] - def get_train_spec() -> TrainSpec: return FaultTolerantTrainSpec( From d68127a8bdc2d045bbf4959296dfcada222f7146 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Wed, 15 Oct 2025 16:22:58 -0700 Subject: [PATCH 44/46] Rebase mistakes. 
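The final two patches in the series route the activation-checkpoint debug options into `ptd_checkpoint_wrapper`. As context for those diffs, here is a minimal sketch of the underlying `torch.utils.checkpoint` arguments that the config's `preserve_rng_state`, `determinism_check`, and `debug` fields map onto; the toy block and input are placeholders.

```python
# Placeholder module/input; shows the torch.utils.checkpoint options behind the AC debug knobs.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(64, 64), nn.Dropout(0.1), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

out = checkpoint(
    block,
    x,
    use_reentrant=False,
    preserve_rng_state=True,      # stash/restore RNG state so the recompute sees the same dropout mask
    determinism_check="default",  # compare recomputed tensors' metadata against the saved originals
    debug=True,                   # collect extra context to report if recomputation diverges
)
out.sum().backward()
```

With `use_reentrant=False`, `determinism_check="default"` checks the shapes, dtypes, and devices of recomputed tensors against the originals, and `debug=True` adds tracing information to the error raised when they diverge.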
--- tests/integration_tests/models.py | 46 ----------------------- torchtitan/distributed/expert_parallel.py | 1 + 2 files changed, 1 insertion(+), 46 deletions(-) diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index 0b69ef806a..37f588765b 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -110,52 +110,6 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "llama4_pp+fsdp+tp+ep+compile", ngpu=8, ), - # Integration Test Cases for Qwen3 dense and MoE model - OverrideDefinitions( - [ - [ - "--model.name qwen3", - "--parallelism.data_parallel_shard_degree 2", - "--parallelism.tensor_parallel_degree 2", - ], - ], - "Qwen3 FSDP+TP", - "qwen3_fsdp+tp", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--model.name qwen3", - "--model.flavor debugmodel_moe", - "--parallelism.data_parallel_shard_degree 2", - "--parallelism.tensor_parallel_degree 2", - "--parallelism.expert_parallel_degree 2", - "--parallelism.expert_tensor_parallel_degree 2", - ], - ], - "Qwen3 FSDP+TP+EP+ETP", - "qwen3_fsdp+tp+ep+etp", - ngpu=4, - ), - # Integration Test Cases for Llama 4 - OverrideDefinitions( - [ - [ - "--model.name llama4", - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule Interleaved1F1B", - "--parallelism.data_parallel_shard_degree 2", - "--parallelism.tensor_parallel_degree 2", - "--parallelism.expert_parallel_degree 4", - "--parallelism.expert_tensor_parallel_degree 1", - "--compile.enable", - ], - ], - "Llama 4 PP+FSDP+TP+EP+compile", - "llama4_pp+fsdp+tp+ep+compile", - ngpu=8, - ), ] return model_tests diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 5ed36a6317..e9986b9974 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -227,6 +227,7 @@ def __init__(self): def _prepare_inputput_fn(self, mod, inputs, device_mesh): # shape (batch_size*seq_len, top_k) top_scores, selected_experts_indices = inputs + num_tokens, _ = top_scores.shape # NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree # if top_scores.shape[0] % device_mesh.size() != 0: From a251fd48cf62565d848f106258bcd8a5760ac1f8 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Thu, 16 Oct 2025 16:58:21 -0700 Subject: [PATCH 45/46] Lint'er error fixes. --- .../unit_tests/test_activation_checkpoint.py | 62 ++++++++------- .../distributed/activation_checkpoint.py | 76 ++++++++++--------- .../simple_fsdp/deepseek_v3/parallelize.py | 2 +- .../simple_fsdp/llama3/parallelize.py | 2 +- .../experiments/vlm/infra/parallelize.py | 2 +- .../models/deepseek_v3/infra/parallelize.py | 2 +- torchtitan/models/deepseek_v3/model/args.py | 2 +- torchtitan/models/llama3/infra/parallelize.py | 2 +- torchtitan/models/llama4/infra/parallelize.py | 2 +- torchtitan/models/qwen3/infra/parallelize.py | 2 +- 10 files changed, 81 insertions(+), 73 deletions(-) diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py index 202f7b1e48..7a91bfd475 100644 --- a/tests/unit_tests/test_activation_checkpoint.py +++ b/tests/unit_tests/test_activation_checkpoint.py @@ -11,6 +11,7 @@ from torch.utils.flop_counter import FlopCounterMode from torchtitan.config.job_config import ActivationCheckpoint as ACConfig +from torchtitan.config.job_config import JobConfig from torchtitan.distributed.activation_checkpoint import apply_ac @@ -74,7 +75,8 @@ def get_bw_flops(model_fn): # 2. 
SAC # Per-op SAC's policy is to save every other mm model_selective_ac = ToyModule() - ac_config_no_force = ACConfig( + job_config = JobConfig() + job_config.activation_checkpoint = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list @@ -82,7 +84,7 @@ def get_bw_flops(model_fn): ) apply_ac( model_selective_ac, - ac_config_no_force, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -92,7 +94,7 @@ def get_bw_flops(model_fn): # 3. Per-op SAC with force recompute "moe.router.gate" # This leads to two mms being recomputed since they share the same shape! model_with_force_first = ToyModule() - ac_config_with_force_first = ACConfig( + job_config.activation_checkpoint = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], @@ -100,7 +102,7 @@ def get_bw_flops(model_fn): ) apply_ac( model_with_force_first, - ac_config_with_force_first, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -109,7 +111,7 @@ def get_bw_flops(model_fn): # 4. Per-op SAC with force recompute "output" model_with_force_last = ToyModule() - ac_config_with_force_last = ACConfig( + job_config.activation_checkpoint = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], @@ -117,7 +119,7 @@ def get_bw_flops(model_fn): ) apply_ac( model_with_force_last, - ac_config_with_force_last, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -126,13 +128,13 @@ def get_bw_flops(model_fn): # 5. Full AC model_with_full_ac = ToyModule() - ac_config_full_ac = ACConfig( + job_config.activation_checkpoint = ACConfig( mode="full", early_stop=False, ) apply_ac( model_with_full_ac, - ac_config_full_ac, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -168,14 +170,14 @@ def get_act_mem(model_fn): # 2. SAC # Per-op SAC's policy is to save every other mm model_selective_ac = ToyModule().cuda() - ac_config_no_force = ACConfig( + job_config.activation_checkpoint = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list ) apply_ac( model_selective_ac, - ac_config_no_force, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -185,14 +187,14 @@ def get_act_mem(model_fn): # 3. Per-op SAC with force recompute "moe.router.gate" # This leads to two mms being recomputed since they share the same shape! model_with_force_first = ToyModule().cuda() - ac_config_with_force_first = ACConfig( + job_config.activation_checkpoint = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], ) apply_ac( model_with_force_first, - ac_config_with_force_first, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -201,14 +203,14 @@ def get_act_mem(model_fn): # 4. 
Per-op SAC with force recompute "output" model_with_force_last = ToyModule().cuda() - ac_config_with_force_last = ACConfig( + job_config.activation_checkpoint = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], ) apply_ac( model_with_force_last, - ac_config_with_force_last, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -217,12 +219,12 @@ def get_act_mem(model_fn): # 5. Full AC model_with_full_ac = ToyModule().cuda() - ac_config_full_ac = ACConfig( + job_config.activation_checkpoint = ACConfig( mode="full", ) apply_ac( model_with_full_ac, - ac_config_full_ac, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -243,26 +245,29 @@ def test_correctness(self): model_selective_ac = ToyModule() model_selective_ac.load_state_dict(model_no_ac.state_dict()) - apply_ac( - model_selective_ac, - ACConfig( + job_config = JobConfig() + job_config.activation_checkpoint = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=[], - ), + ) + apply_ac( + model_selective_ac, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) model_force_first = ToyModule() model_force_first.load_state_dict(model_no_ac.state_dict()) - apply_ac( - model_force_first, - ACConfig( + job_config.activation_checkpoint = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], - ), + ) + apply_ac( + model_force_first, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -270,13 +275,14 @@ def test_correctness(self): model_force_last = ToyModule() model_force_last.load_state_dict(model_no_ac.state_dict()) - apply_ac( - model_force_last, - ACConfig( + job_config.activation_checkpoint = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], - ), + ) + apply_ac( + model_force_last, + job_config, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 131c7f7f58..7ea25405ac 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -18,13 +18,14 @@ from torchtitan.config.job_config import ActivationCheckpoint as ACConfig from torchtitan.config.job_config import Debug as DebugConfig +from torchtitan.config.job_config import JobConfig from torchtitan.tools.logging import logger, warn_once _layer_sac_count = 0 -def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, debug_config:DebugConfig) -> nn.Module: +def _apply_layer_sac(module: nn.Module, job_config: JobConfig) -> nn.Module: """Apply layer selective activation checkpointing to the module. 
Args: @@ -36,14 +37,14 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, debug_config:DebugC """ global _layer_sac_count _layer_sac_count += 1 - ac_freq = int(ac_config.selective_ac_option) + ac_freq = int(job_config.activation_checkpoint.selective_ac_option) if not ac_freq or _layer_sac_count % ac_freq == 0: return ptd_checkpoint_wrapper( module, - preserve_rng_state=debug_config.ac_preserve_rng_state, - determinism_check=debug_config.ac_determinism_check, - early_stop=ac_config.early_stop, - debug=debug_config.ac_debug + preserve_rng_state=job_config.debug.ac_preserve_rng_state, + determinism_check=job_config.debug.ac_determinism_check, + early_stop=job_config.activation_checkpoint.early_stop, + debug=job_config.debug.ac_debug ) else: return module @@ -51,7 +52,7 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, debug_config:DebugC def _apply_op_sac( module: nn.Module, - ac_config: ACConfig, + job_config: JobConfig, *, base_fqn: str | None = None, op_sac_save_list: set[torch._ops.OpOverload], @@ -60,7 +61,7 @@ def _apply_op_sac( Args: module (nn.Module): The module to apply selective activation checkpointing to. - ac_config (ACConfig): The activation checkpointing config. + job_config (JobConfig): The job config. base_fqn (str, optional): The base fqn of the module. Defaults to None. op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead of recomputing. @@ -74,6 +75,7 @@ def _apply_op_sac( ) mm_recompute_shapes = set() + ac_config = job_config.activation_checkpoint if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: for module_fqn, submod in module.named_modules(): fqn = module_fqn @@ -128,16 +130,16 @@ def selective_checkpointing_context_fn(): return create_selective_checkpoint_contexts(_get_custom_policy(meta)) return ptd_checkpoint_wrapper( - module, - context_fn=selective_checkpointing_context_fn, - preserve_rng_state=dbg_config.ac_preserve_rng_state, - determinism_check=dbg_config.ac_determinism_check, - early_stop=ac_config.early_stop, - debug=dbg_config.ac_debug - ) + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=job_config.debug.ac_preserve_rng_state, + determinism_check=job_config.debug.ac_determinism_check, + early_stop=ac_config.early_stop, + debug=job_config.debug.ac_debug + ) -def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: +def _apply_full_ac(module: nn.Module, job_config: JobConfig ) -> nn.Module: """Apply full activation checkpointing to the module. Args: @@ -148,20 +150,17 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: nn.Module: The module with full activation checkpointing applied. 
""" return ptd_checkpoint_wrapper( - module, preserve_rng_state=False, early_stop=ac_config.early_stop + module, + preserve_rng_state=job_config.debug.ac_preserve_rng_state, + determinism_check=job_config.debug.ac_determinism_check, + early_stop=job_config.activation_checkpoint.early_stop, + debug=job_config.debug.ac_debug ) - return ptd_checkpoint_wrapper( - module, - preserve_rng_state=dbg_config.ac_preserve_rng_state, - determinism_check=dbg_config.ac_determinism_check, - early_stop=ac_config.early_stop, - debug=dbg_config.ac_debug - ) def _apply_op_sac_to_transformer_block_with_flex( module: nn.Module, - ac_config: ACConfig, + job_config: JobConfig, *, base_fqn: str | None = None, model_compile_enabled: bool = False, @@ -171,7 +170,7 @@ def _apply_op_sac_to_transformer_block_with_flex( Args: module (nn.Module): The transformer block to apply SAC to. - ac_config (ACConfig): The activation checkpointing config. + job_config (JobConfig): The job config. base_fqn (str, optional): The base fqn of the module. Defaults to None. model_compile_enabled (bool): Whether model compilation is enabled. Defaults to False. @@ -201,14 +200,15 @@ def _apply_op_sac_to_transformer_block_with_flex( ), ) + def wrap_submodule(name: str, full_ac: bool = False) -> None: submodule = getattr(module, name) if full_ac: - submodule = _apply_full_ac(submodule, ac_config) + submodule = _apply_full_ac(submodule, job_config) else: submodule = _apply_op_sac( submodule, - ac_config, + job_config, base_fqn=f"{base_fqn}.{name}" if base_fqn else name, op_sac_save_list=op_sac_save_list, ) @@ -224,7 +224,7 @@ def wrap_submodule(name: str, full_ac: bool = False) -> None: if model_compile_enabled: module = _apply_op_sac( module, - ac_config, + job_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list, ) @@ -236,7 +236,7 @@ def wrap_submodule(name: str, full_ac: bool = False) -> None: def _apply_ac_to_transformer_block( module: nn.Module, - ac_config: ACConfig, + job_config: JobConfig, *, base_fqn: str | None = None, model_compile_enabled: bool = False, @@ -244,13 +244,14 @@ def _apply_ac_to_transformer_block( op_sac_save_list: set[torch._ops.OpOverload] | None = None, ) -> nn.Module: valid_ac_modes = ("full", "selective") + ac_config = job_config.activation_checkpoint if ac_config.mode not in valid_ac_modes: raise ValueError( f"Invalid AC mode: {ac_config.mode}. 
Valid modes: {valid_ac_modes}" ) if ac_config.mode == "full": - return _apply_full_ac(module, ac_config) + return _apply_full_ac(module, job_config) assert ac_config.mode == "selective", f"{ac_config.mode}" use_op_sac = ac_config.selective_ac_option == "op" @@ -274,22 +275,22 @@ def _apply_ac_to_transformer_block( """ return _apply_op_sac_to_transformer_block_with_flex( module, - ac_config, + job_config, base_fqn=base_fqn, model_compile_enabled=model_compile_enabled, op_sac_save_list=op_sac_save_list, ) else: return _apply_op_sac( - module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list + module, job_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list ) - return _apply_layer_sac(module, ac_config) + return _apply_layer_sac(module, job_config) def apply_ac( model: nn.Module, - ac_config: ACConfig, + job_config: JobConfig, *, model_compile_enabled: bool = False, use_flex_attn: bool = False, @@ -312,7 +313,8 @@ def apply_ac( Returns: None """ - + ac_config = job_config.activation_checkpoint + debug_config = job_config.debug if ac_config.mode == "memory_budget": assert model_compile_enabled, "Memory budget mode requires model to be compiled" if ac_config.visualize_memory_budget_pareto: @@ -328,7 +330,7 @@ def apply_ac( for layer_id, transformer_block in model.layers.named_children(): transformer_block = _apply_ac_to_transformer_block( transformer_block, - ac_config, + job_config, base_fqn=f"layers.{layer_id}", model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 9ffbc8f76d..09f90afe45 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -84,7 +84,7 @@ def parallelize_deepseekv3( ) if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) + apply_ac(model, job_config) mp_policy = MixedPrecisionPolicy( param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index fdc40031f5..8c90e659e6 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -83,7 +83,7 @@ def parallelize_llama( ) apply_ac( model, - job_config.activation_checkpoint, + job_config, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py index a8095c7621..aff9afb020 100644 --- a/torchtitan/experiments/vlm/infra/parallelize.py +++ b/torchtitan/experiments/vlm/infra/parallelize.py @@ -61,7 +61,7 @@ def parallelize_vlm( if job_config.activation_checkpoint.mode != "none": apply_ac( model, - job_config.activation_checkpoint, + job_config, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 8d13a3f31f..87c55efea4 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -110,7 +110,7 @@ def parallelize_deepseekv3( if job_config.activation_checkpoint.mode != "none": apply_ac( model, - 
job_config.activation_checkpoint, + job_config, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 3bac6e82f1..0328c52334 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -107,7 +107,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args._debug_force_load_balance = ( - job_config.training.debug_moe_force_load_balance + job_config.debug.moe_force_load_balance ) def get_nparams_and_flops( diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 4944af569e..01617e1e91 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -98,7 +98,7 @@ def parallelize_llama( if job_config.activation_checkpoint.mode != "none": apply_ac( model, - job_config.activation_checkpoint, + job_config, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 1f579ccd04..22edcc847a 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -116,7 +116,7 @@ def parallelize_llama( if job_config.activation_checkpoint.mode != "none": apply_ac( model, - job_config.activation_checkpoint, + job_config, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 5fa8549e9f..2dfa8bb970 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -110,7 +110,7 @@ def parallelize_qwen3( if job_config.activation_checkpoint.mode != "none": apply_ac( model, - job_config.activation_checkpoint, + job_config, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, From ede15a7646eb521c82633e004a949f72e42a6e28 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Tue, 21 Oct 2025 19:12:23 -0700 Subject: [PATCH 46/46] Review updates. --- docs/debugging.md | 6 +- .../unit_tests/test_activation_checkpoint.py | 62 +++++++++---------- torchtitan/config/job_config.py | 29 ++++++--- .../distributed/activation_checkpoint.py | 59 ++++++++---------- .../simple_fsdp/deepseek_v3/parallelize.py | 2 +- .../simple_fsdp/llama3/parallelize.py | 2 +- .../experiments/vlm/infra/parallelize.py | 2 +- .../models/deepseek_v3/infra/parallelize.py | 2 +- torchtitan/models/llama3/infra/parallelize.py | 2 +- torchtitan/models/llama4/infra/parallelize.py | 2 +- torchtitan/models/qwen3/infra/parallelize.py | 2 +- 11 files changed, 85 insertions(+), 85 deletions(-) diff --git a/docs/debugging.md b/docs/debugging.md index fd436367e9..4deb20bbac 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -99,11 +99,11 @@ Use `--debug.deterministic_warn_only` to only warn about (not stop running) kern The following debug configs are available for AC. -`ac_preserve_rng_state` - if deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. 
+`preserve_rng_state` - if deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. -`ac_determinism_check` - A string specifying the determinism function +`determinism_check` - A string specifying the determinism function -`ac_debug` - capture ac debug information. Will be slower. +`debug` - capture ac debug information. Will be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details. diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py index 7a91bfd475..202f7b1e48 100644 --- a/tests/unit_tests/test_activation_checkpoint.py +++ b/tests/unit_tests/test_activation_checkpoint.py @@ -11,7 +11,6 @@ from torch.utils.flop_counter import FlopCounterMode from torchtitan.config.job_config import ActivationCheckpoint as ACConfig -from torchtitan.config.job_config import JobConfig from torchtitan.distributed.activation_checkpoint import apply_ac @@ -75,8 +74,7 @@ def get_bw_flops(model_fn): # 2. SAC # Per-op SAC's policy is to save every other mm model_selective_ac = ToyModule() - job_config = JobConfig() - job_config.activation_checkpoint = ACConfig( + ac_config_no_force = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list @@ -84,7 +82,7 @@ def get_bw_flops(model_fn): ) apply_ac( model_selective_ac, - job_config, + ac_config_no_force, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -94,7 +92,7 @@ def get_bw_flops(model_fn): # 3. Per-op SAC with force recompute "moe.router.gate" # This leads to two mms being recomputed since they share the same shape! model_with_force_first = ToyModule() - job_config.activation_checkpoint = ACConfig( + ac_config_with_force_first = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], @@ -102,7 +100,7 @@ def get_bw_flops(model_fn): ) apply_ac( model_with_force_first, - job_config, + ac_config_with_force_first, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -111,7 +109,7 @@ def get_bw_flops(model_fn): # 4. Per-op SAC with force recompute "output" model_with_force_last = ToyModule() - job_config.activation_checkpoint = ACConfig( + ac_config_with_force_last = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], @@ -119,7 +117,7 @@ def get_bw_flops(model_fn): ) apply_ac( model_with_force_last, - job_config, + ac_config_with_force_last, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -128,13 +126,13 @@ def get_bw_flops(model_fn): # 5. Full AC model_with_full_ac = ToyModule() - job_config.activation_checkpoint = ACConfig( + ac_config_full_ac = ACConfig( mode="full", early_stop=False, ) apply_ac( model_with_full_ac, - job_config, + ac_config_full_ac, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -170,14 +168,14 @@ def get_act_mem(model_fn): # 2. 
SAC # Per-op SAC's policy is to save every other mm model_selective_ac = ToyModule().cuda() - job_config.activation_checkpoint = ACConfig( + ac_config_no_force = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list ) apply_ac( model_selective_ac, - job_config, + ac_config_no_force, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -187,14 +185,14 @@ def get_act_mem(model_fn): # 3. Per-op SAC with force recompute "moe.router.gate" # This leads to two mms being recomputed since they share the same shape! model_with_force_first = ToyModule().cuda() - job_config.activation_checkpoint = ACConfig( + ac_config_with_force_first = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], ) apply_ac( model_with_force_first, - job_config, + ac_config_with_force_first, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -203,14 +201,14 @@ def get_act_mem(model_fn): # 4. Per-op SAC with force recompute "output" model_with_force_last = ToyModule().cuda() - job_config.activation_checkpoint = ACConfig( + ac_config_with_force_last = ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], ) apply_ac( model_with_force_last, - job_config, + ac_config_with_force_last, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -219,12 +217,12 @@ def get_act_mem(model_fn): # 5. Full AC model_with_full_ac = ToyModule().cuda() - job_config.activation_checkpoint = ACConfig( + ac_config_full_ac = ACConfig( mode="full", ) apply_ac( model_with_full_ac, - job_config, + ac_config_full_ac, model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -245,29 +243,26 @@ def test_correctness(self): model_selective_ac = ToyModule() model_selective_ac.load_state_dict(model_no_ac.state_dict()) - job_config = JobConfig() - job_config.activation_checkpoint = ACConfig( + apply_ac( + model_selective_ac, + ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=[], - ) - apply_ac( - model_selective_ac, - job_config, + ), model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) model_force_first = ToyModule() model_force_first.load_state_dict(model_no_ac.state_dict()) - job_config.activation_checkpoint = ACConfig( + apply_ac( + model_force_first, + ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], - ) - apply_ac( - model_force_first, - job_config, + ), model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, @@ -275,14 +270,13 @@ def test_correctness(self): model_force_last = ToyModule() model_force_last.load_state_dict(model_no_ac.state_dict()) - job_config.activation_checkpoint = ACConfig( + apply_ac( + model_force_last, + ACConfig( mode="selective", selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], - ) - apply_ac( - model_force_last, - job_config, + ), model_compile_enabled=False, use_flex_attn=False, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 7fb3f013c3..5837c5c2cb 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -623,6 +623,26 @@ class ActivationCheckpoint: 
https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 """ + preserve_rng_state: bool = False + """ + If deterministic output compared to non-checkpointed passes is required, set + to true. Results in stashing and restoring the RNG state during each checkpoint, + may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html + for details. + """ + + determinism_check: str = "default" + """ + A string specifying the determinism function. See + https://docs.pytorch.org/docs/stable/checkpoint.html for details. + """ + + debug: bool = False + """ + Capture ac debug information. Will be slower. See + https://docs.pytorch.org/docs/stable/checkpoint.html for details. + """ + @dataclass class Compile: @@ -882,15 +902,6 @@ class Debug: deterministic_warn_only: bool = False """Only warns about ops without deterministic implementations rather than erroring out """ - ac_preserve_rng_state: bool = False - """If deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details.""" - - ac_determinism_check: str = "default" - """A string specifying the determinism function. See https://docs.pytorch.org/docs/stable/checkpoint.html for details.""" - - ac_debug: bool = False - """ Capture ac debug information. Will be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details.""" - moe_force_load_balance: bool = False """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 7ea25405ac..a07d06c19b 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -18,14 +18,13 @@ from torchtitan.config.job_config import ActivationCheckpoint as ACConfig from torchtitan.config.job_config import Debug as DebugConfig -from torchtitan.config.job_config import JobConfig from torchtitan.tools.logging import logger, warn_once _layer_sac_count = 0 -def _apply_layer_sac(module: nn.Module, job_config: JobConfig) -> nn.Module: +def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, debug_config: DebugConfig) -> nn.Module: """Apply layer selective activation checkpointing to the module. 
Args: @@ -37,14 +36,14 @@ def _apply_layer_sac(module: nn.Module, job_config: JobConfig) -> nn.Module: """ global _layer_sac_count _layer_sac_count += 1 - ac_freq = int(job_config.activation_checkpoint.selective_ac_option) + ac_freq = int(ac_config.selective_ac_option) if not ac_freq or _layer_sac_count % ac_freq == 0: return ptd_checkpoint_wrapper( module, - preserve_rng_state=job_config.debug.ac_preserve_rng_state, - determinism_check=job_config.debug.ac_determinism_check, - early_stop=job_config.activation_checkpoint.early_stop, - debug=job_config.debug.ac_debug + preserve_rng_state=ac_config.preserve_rng_state, + determinism_check=ac_config.determinism_check, + early_stop=ac_config.early_stop, + debug=ac_config.debug ) else: return module @@ -52,7 +51,7 @@ def _apply_layer_sac(module: nn.Module, job_config: JobConfig) -> nn.Module: def _apply_op_sac( module: nn.Module, - job_config: JobConfig, + ac_config: ACConfig, *, base_fqn: str | None = None, op_sac_save_list: set[torch._ops.OpOverload], @@ -61,7 +60,7 @@ def _apply_op_sac( Args: module (nn.Module): The module to apply selective activation checkpointing to. - job_config (JobConfig): The job config. + ac_config (ACConfig): The activation checkpointing config. base_fqn (str, optional): The base fqn of the module. Defaults to None. op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead of recomputing. @@ -75,7 +74,6 @@ def _apply_op_sac( ) mm_recompute_shapes = set() - ac_config = job_config.activation_checkpoint if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: for module_fqn, submod in module.named_modules(): fqn = module_fqn @@ -132,14 +130,14 @@ def selective_checkpointing_context_fn(): return ptd_checkpoint_wrapper( module, context_fn=selective_checkpointing_context_fn, - preserve_rng_state=job_config.debug.ac_preserve_rng_state, - determinism_check=job_config.debug.ac_determinism_check, + preserve_rng_state=ac_config.preserve_rng_state, + determinism_check=ac_config.determinism_check, early_stop=ac_config.early_stop, - debug=job_config.debug.ac_debug + debug=ac_config.debug ) -def _apply_full_ac(module: nn.Module, job_config: JobConfig ) -> nn.Module: +def _apply_full_ac(module: nn.Module, ac_config: ACConfig ) -> nn.Module: """Apply full activation checkpointing to the module. Args: @@ -151,16 +149,16 @@ def _apply_full_ac(module: nn.Module, job_config: JobConfig ) -> nn.Module: """ return ptd_checkpoint_wrapper( module, - preserve_rng_state=job_config.debug.ac_preserve_rng_state, - determinism_check=job_config.debug.ac_determinism_check, - early_stop=job_config.activation_checkpoint.early_stop, - debug=job_config.debug.ac_debug + preserve_rng_state=ac_config.preserve_rng_state, + determinism_check=ac_config.determinism_check, + early_stop=ac_config.early_stop, + debug=ac_config.debug ) def _apply_op_sac_to_transformer_block_with_flex( module: nn.Module, - job_config: JobConfig, + ac_config: ACConfig, *, base_fqn: str | None = None, model_compile_enabled: bool = False, @@ -170,7 +168,7 @@ def _apply_op_sac_to_transformer_block_with_flex( Args: module (nn.Module): The transformer block to apply SAC to. - job_config (JobConfig): The job config. + ac_config (ACConfig): The Activation Checkpoint config. base_fqn (str, optional): The base fqn of the module. Defaults to None. model_compile_enabled (bool): Whether model compilation is enabled. Defaults to False. 
@@ -204,11 +202,11 @@ def _apply_op_sac_to_transformer_block_with_flex(
     def wrap_submodule(name: str, full_ac: bool = False) -> None:
         submodule = getattr(module, name)
         if full_ac:
-            submodule = _apply_full_ac(submodule, job_config)
+            submodule = _apply_full_ac(submodule, ac_config)
         else:
             submodule = _apply_op_sac(
                 submodule,
-                job_config,
+                ac_config,
                 base_fqn=f"{base_fqn}.{name}" if base_fqn else name,
                 op_sac_save_list=op_sac_save_list,
             )
@@ -224,7 +222,7 @@ def wrap_submodule(name: str, full_ac: bool = False) -> None:
     if model_compile_enabled:
         module = _apply_op_sac(
             module,
-            job_config,
+            ac_config,
             base_fqn=base_fqn,
             op_sac_save_list=op_sac_save_list,
         )
@@ -236,7 +234,7 @@ def wrap_submodule(name: str, full_ac: bool = False) -> None:
 
 def _apply_ac_to_transformer_block(
     module: nn.Module,
-    job_config: JobConfig,
+    ac_config: ACConfig,
     *,
     base_fqn: str | None = None,
     model_compile_enabled: bool = False,
@@ -244,14 +242,13 @@ def _apply_ac_to_transformer_block(
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
 ) -> nn.Module:
     valid_ac_modes = ("full", "selective")
-    ac_config = job_config.activation_checkpoint
     if ac_config.mode not in valid_ac_modes:
         raise ValueError(
             f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
         )
 
     if ac_config.mode == "full":
-        return _apply_full_ac(module, job_config)
+        return _apply_full_ac(module, ac_config)
 
     assert ac_config.mode == "selective", f"{ac_config.mode}"
     use_op_sac = ac_config.selective_ac_option == "op"
@@ -275,14 +272,14 @@ def _apply_ac_to_transformer_block(
             """
             return _apply_op_sac_to_transformer_block_with_flex(
                 module,
-                job_config,
+                ac_config,
                 base_fqn=base_fqn,
                 model_compile_enabled=model_compile_enabled,
                 op_sac_save_list=op_sac_save_list,
             )
         else:
             return _apply_op_sac(
-                module, job_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list
+                module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list
             )
 
-    return _apply_layer_sac(module, job_config)
+    return _apply_layer_sac(module, ac_config)
@@ -290,7 +287,7 @@ def _apply_ac_to_transformer_block(
 
 def apply_ac(
     model: nn.Module,
-    job_config: JobConfig,
+    ac_config: ACConfig,
     *,
     model_compile_enabled: bool = False,
     use_flex_attn: bool = False,
@@ -313,8 +310,6 @@ def apply_ac(
     Returns:
         None
     """
-    ac_config = job_config.activation_checkpoint
-    debug_config = job_config.debug
     if ac_config.mode == "memory_budget":
         assert model_compile_enabled, "Memory budget mode requires model to be compiled"
         if ac_config.visualize_memory_budget_pareto:
@@ -330,7 +325,7 @@ def apply_ac(
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = _apply_ac_to_transformer_block(
             transformer_block,
-            job_config,
+            ac_config,
             base_fqn=f"layers.{layer_id}",
             model_compile_enabled=model_compile_enabled,
             use_flex_attn=use_flex_attn,
diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 09f90afe45..9ffbc8f76d 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -84,7 +84,7 @@ def parallelize_deepseekv3(
         )
 
     if job_config.activation_checkpoint.mode != "none":
-        apply_ac(model, job_config)
+        apply_ac(model, job_config.activation_checkpoint)
 
     mp_policy = MixedPrecisionPolicy(
         param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index 8c90e659e6..fdc40031f5 100644 ---
a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -83,7 +83,7 @@ def parallelize_llama( ) apply_ac( model, - job_config, + job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py index aff9afb020..a8095c7621 100644 --- a/torchtitan/experiments/vlm/infra/parallelize.py +++ b/torchtitan/experiments/vlm/infra/parallelize.py @@ -61,7 +61,7 @@ def parallelize_vlm( if job_config.activation_checkpoint.mode != "none": apply_ac( model, - job_config, + job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 87c55efea4..8d13a3f31f 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -110,7 +110,7 @@ def parallelize_deepseekv3( if job_config.activation_checkpoint.mode != "none": apply_ac( model, - job_config, + job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 01617e1e91..4944af569e 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -98,7 +98,7 @@ def parallelize_llama( if job_config.activation_checkpoint.mode != "none": apply_ac( model, - job_config, + job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 22edcc847a..1f579ccd04 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -116,7 +116,7 @@ def parallelize_llama( if job_config.activation_checkpoint.mode != "none": apply_ac( model, - job_config, + job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 2dfa8bb970..5fa8549e9f 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -110,7 +110,7 @@ def parallelize_qwen3( if job_config.activation_checkpoint.mode != "none": apply_ac( model, - job_config, + job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list,