diff --git a/docs/debugging.md b/docs/debugging.md index f7758cbde5..4deb20bbac 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -70,7 +70,7 @@ When debugging issues with multi-dimensional parallelism (combinations of FSDP, Set consistent random seeds across all parallelism dimensions: ```bash -CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.seed 42 +CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.seed 42 ``` **Seed behavior with parallelism:** @@ -84,7 +84,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr Enable deterministic algorithms to ensure bit-for-bit reproducibility across runs: ```bash -CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.deterministic +CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.deterministic ``` **What it does:** @@ -93,6 +93,19 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr - Sets deterministic workspace configuration for CuBLAS operations - **Note:** This will significantly reduce training performance but ensures exact reproducibility +Use `--debug.deterministic_warn_only` to only warn about (not stop running) kernel without deterministic implementation. + +### Activation Checkipointing Debugging ### + +The following debug configs are available for AC. + +`preserve_rng_state` - if deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. + +`determinism_check` - A string specifying the determinism function + +`debug` - capture ac debug information. Will be slower. + +See https://docs.pytorch.org/docs/stable/checkpoint.html for details. ### Seed-Checkpoint-based Reproducibility diff --git a/torchtitan/config/__init__.py b/torchtitan/config/__init__.py index ba2795a601..e70d7fb622 100644 --- a/torchtitan/config/__init__.py +++ b/torchtitan/config/__init__.py @@ -28,6 +28,7 @@ Quantize, Training, Validation, + Debug ) from .manager import ConfigManager @@ -49,4 +50,5 @@ "Profiling", "Training", "Validation", + "Debug" ] diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 7137579f18..5837c5c2cb 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -253,15 +253,6 @@ class Training: many temporary files. """ - seed: int | None = None - """Choose the base RNG seed used for training""" - - deterministic: bool = False - """Use deterministic algorithms wherever possible, may be slower""" - - debug_moe_force_load_balance: bool = False - """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" - @dataclass class Parallelism: @@ -632,6 +623,26 @@ class ActivationCheckpoint: https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 """ + preserve_rng_state: bool = False + """ + If deterministic output compared to non-checkpointed passes is required, set + to true. Results in stashing and restoring the RNG state during each checkpoint, + may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html + for details. + """ + + determinism_check: str = "default" + """ + A string specifying the determinism function. See + https://docs.pytorch.org/docs/stable/checkpoint.html for details. + """ + + debug: bool = False + """ + Capture ac debug information. Will be slower. See + https://docs.pytorch.org/docs/stable/checkpoint.html for details. + """ + @dataclass class Compile: @@ -880,6 +891,20 @@ def __post_init__(self): ), "validation steps must be positive or -1" +@dataclass +class Debug: + seed: int | None = None + """Choose the base RNG seed used for training""" + + deterministic: bool = False + """Use deterministic algorithms wherever possible, may be slower""" + + deterministic_warn_only: bool = False + """Only warns about ops without deterministic implementations rather than erroring out """ + + moe_force_load_balance: bool = False + """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" + @dataclass class JobConfig: """ @@ -905,6 +930,7 @@ class JobConfig: fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance) experimental: Experimental = field(default_factory=Experimental) validation: Validation = field(default_factory=Validation) + debug: Debug = field(default_factory=Debug) def to_dict(self) -> dict[str, Any]: return asdict(self) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 57809c45f9..a07d06c19b 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -17,13 +17,14 @@ ) from torchtitan.config.job_config import ActivationCheckpoint as ACConfig +from torchtitan.config.job_config import Debug as DebugConfig from torchtitan.tools.logging import logger, warn_once _layer_sac_count = 0 -def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module: +def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, debug_config: DebugConfig) -> nn.Module: """Apply layer selective activation checkpointing to the module. Args: @@ -38,7 +39,11 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module: ac_freq = int(ac_config.selective_ac_option) if not ac_freq or _layer_sac_count % ac_freq == 0: return ptd_checkpoint_wrapper( - module, preserve_rng_state=False, early_stop=ac_config.early_stop + module, + preserve_rng_state=ac_config.preserve_rng_state, + determinism_check=ac_config.determinism_check, + early_stop=ac_config.early_stop, + debug=ac_config.debug ) else: return module @@ -125,12 +130,14 @@ def selective_checkpointing_context_fn(): return ptd_checkpoint_wrapper( module, context_fn=selective_checkpointing_context_fn, - preserve_rng_state=False, + preserve_rng_state=ac_config.preserve_rng_state, + determinism_check=ac_config.determinism_check, early_stop=ac_config.early_stop, + debug=ac_config.debug ) -def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: +def _apply_full_ac(module: nn.Module, ac_config: ACConfig ) -> nn.Module: """Apply full activation checkpointing to the module. Args: @@ -141,7 +148,11 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: nn.Module: The module with full activation checkpointing applied. """ return ptd_checkpoint_wrapper( - module, preserve_rng_state=False, early_stop=ac_config.early_stop + module, + preserve_rng_state=ac_config.preserve_rng_state, + determinism_check=ac_config.determinism_check, + early_stop=ac_config.early_stop, + debug=ac_config.debug ) @@ -157,7 +168,7 @@ def _apply_op_sac_to_transformer_block_with_flex( Args: module (nn.Module): The transformer block to apply SAC to. - ac_config (ACConfig): The activation checkpointing config. + ac_config (ACConfig): The Activation Checkpoint config. base_fqn (str, optional): The base fqn of the module. Defaults to None. model_compile_enabled (bool): Whether model compilation is enabled. Defaults to False. @@ -187,6 +198,7 @@ def _apply_op_sac_to_transformer_block_with_flex( ), ) + def wrap_submodule(name: str, full_ac: bool = False) -> None: submodule = getattr(module, name) if full_ac: @@ -270,7 +282,7 @@ def _apply_ac_to_transformer_block( module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list ) - return _apply_layer_sac(module, ac_config) + return _apply_layer_sac(module, job_config) def apply_ac( @@ -298,7 +310,6 @@ def apply_ac( Returns: None """ - if ac_config.mode == "memory_budget": assert model_compile_enabled, "Memory budget mode requires model to be compiled" if ac_config.visualize_memory_budget_pareto: diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 67eb41280f..8787519d86 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -18,6 +18,7 @@ from torch.distributed.tensor import DTensor from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP +from torchtitan.config import Debug as DebugConfig from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.tools.logging import logger from torchtitan.tools.utils import device_module, device_type @@ -83,8 +84,7 @@ def dist_mean( def set_determinism( world_mesh: DeviceMesh | None, device: torch.device, - seed: int | None = None, - deterministic: bool = False, + debug_config: DebugConfig, distinct_seed_mesh_dim: str = "pp", ) -> None: """ @@ -97,9 +97,10 @@ def set_determinism( Set Determinism flags for increased reproducibility with loss of performance. """ - if deterministic: + if debug_config.deterministic: logger.info("Deterministic algorithm enabled (expect perf degradation).") torch.use_deterministic_algorithms(True) + torch.use_deterministic_algorithms(True, warn_only=debug_config.deterministic_warn_only) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # env var for deterministic CuBLAS @@ -114,6 +115,7 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) + seed = debug_config.seed if not world_mesh: if seed is not None: torch.manual_seed(seed) diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index 624792e83e..33018f95fc 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -35,8 +35,7 @@ def __init__(self, job_config: JobConfig): dist_utils.set_determinism( self.parallel_dims.world_mesh, self.device, - job_config.training.seed, - job_config.training.deterministic, + job_config.debug, distinct_seed_mesh_dim="dp_shard", ) diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index f8b1412959..1d3a420b9d 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -104,8 +104,7 @@ def __init__(self, job_config: ForgeJobConfig): dist_utils.set_determinism( world_mesh, self.device, - job_config.training.seed, - job_config.training.deterministic, + job_config.debug, ) self.train_spec = get_train_spec(job_config.model.name) diff --git a/torchtitan/experiments/forge/job_config.py b/torchtitan/experiments/forge/job_config.py index b1c014cc1d..d255bc0b72 100644 --- a/torchtitan/experiments/forge/job_config.py +++ b/torchtitan/experiments/forge/job_config.py @@ -20,6 +20,7 @@ Parallelism, Quantize, Training, + Debug, ) @@ -45,6 +46,7 @@ class ForgeJobConfig: # fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance) # experimental: Experimental = field(default_factory=Experimental) # validation: Validation = field(default_factory=Validation) + debug: Debug = field(default_factory=Debug) def to_dict(self) -> dict[str, Any]: return asdict(self) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index bf963a5b5f..2309c844a4 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -90,6 +90,7 @@ def __init__(self) -> None: SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH ] def forward( diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 3bac6e82f1..0328c52334 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -107,7 +107,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args._debug_force_load_balance = ( - job_config.training.debug_moe_force_load_balance + job_config.debug.moe_force_load_balance ) def get_nparams_and_flops( diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py index faeb60aadf..53043a1d02 100644 --- a/torchtitan/models/llama4/model/args.py +++ b/torchtitan/models/llama4/model/args.py @@ -82,7 +82,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args._debug_force_load_balance = ( - job_config.training.debug_moe_force_load_balance + job_config.debug.moe_force_load_balance ) def get_nparams_and_flops( diff --git a/torchtitan/train.py b/torchtitan/train.py index 1d5e0e500a..c0951063ec 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,6 +11,7 @@ from typing import Any, Generator, Iterable, Optional import torch + from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module @@ -118,8 +119,7 @@ def __init__(self, job_config: JobConfig): dist_utils.set_determinism( world_mesh, self.device, - job_config.training.seed, - job_config.training.deterministic, + job_config.debug, ) self.train_spec = train_spec_module.get_train_spec(job_config.model.name)