Skip to content

Commit 6ffcdf6

Browse files
committed
allow disabling ft checkpoints
Summary: Allows disabling the storage of checkpoints related to torchft. Users don't really have to rely on any external storage. So it reduces set up time to get things up and running. Since we also don't really need model checkpoints when we have torchft. And if checkpoint storage has issues, this can work as a killswitch to completely disable the storage so it doesn't impact training.
1 parent d203b1a commit 6ffcdf6

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

torchtitan/components/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def __init__(
193193
self.load_only = checkpoint_config.load_only
194194

195195
self.ft_manager = (
196-
ft_manager.manager if ft_manager and ft_manager.enabled else None
196+
ft_manager.manager if ft_manager and ft_manager.enabled and checkpoint_config.enable_ft_checkpoints else None
197197
)
198198
if self.ft_manager:
199199
optimizers.init_cache_state_dict()

torchtitan/config/job_config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,17 @@ class Checkpoint:
414414
enable: bool = False
415415
"""Whether to enable checkpoint"""
416416

417+
enable_ft_checkpoints: bool = True
418+
"""
419+
Used to enable checkpointing state that's used for fault tolerant training with torchft.
420+
421+
Fault tolerant training stores data loader index in the checkpoints, so that training
422+
can resume without going over the same batch twice.
423+
424+
If enabled, data loader state is checkpointed. Otherwise the data loader index
425+
will be infered from the step count.
426+
"""
427+
417428
folder: str = "checkpoint"
418429
"""
419430
The folder to store the checkpoints.

torchtitan/train.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,12 @@ def state_dict(self) -> dict[str, Any]:
638638
def load_state_dict(self, state_dict: dict[str, Any]):
639639
self.step = state_dict["step"]
640640
self.ntokens_seen = state_dict["ntokens_seen"]
641+
if (
642+
self.job_config.fault_tolerance.enable
643+
and not self.job_config.checkpoint.enable_ft_checkpoints
644+
):
645+
# TODO: set index in the dataloader
646+
pass
641647

642648
def close(self) -> None:
643649
if self.checkpointer:

0 commit comments

Comments
 (0)