Commit a82b77a

allow disabling ft checkpoints (#1810)
Summary: Allows disabling the storage of torchft-related checkpoints. Users then don't have to rely on any external storage, which reduces the setup time needed to get things up and running, and model checkpoints are not strictly needed when torchft is in use. If checkpoint storage has issues, this also works as a kill switch to completely disable the storage so it doesn't impact training.

---

[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1810).
* #1856
* #1811
* __->__ #1810

Co-authored-by: Tushar Jain <[email protected]>
1 parent 9603872 commit a82b77a
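As a quick illustration of the kill switch described in the summary, the sketch below shows how the new flag could be expressed in a TOML job config and read back with Python's standard tomllib. Only enable, folder, and enable_ft_dataloader_checkpoints come from this commit's diff; the [checkpoint] table name and the use of tomllib are assumptions made for this sketch, not torchtitan's actual config loader.

import tomllib  # standard-library TOML parser (Python 3.11+)

# Hypothetical job-config snippet. The [checkpoint] table name mirrors the
# Checkpoint config class changed below and is an assumption for this sketch.
JOB_CONFIG_TOML = """
[checkpoint]
enable = true
folder = "checkpoint"
# Kill switch: skip torchft dataloader checkpoints, e.g. when external
# checkpoint storage is unavailable or unreliable.
enable_ft_dataloader_checkpoints = false
"""

cfg = tomllib.loads(JOB_CONFIG_TOML)
assert cfg["checkpoint"]["enable_ft_dataloader_checkpoints"] is False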

File tree

2 files changed, +34 -1 lines changed


torchtitan/components/checkpoint.py

Lines changed: 12 additions & 1 deletion
@@ -190,8 +190,19 @@ def __init__(
         self.load_only = checkpoint_config.load_only
 
         self.ft_manager = (
-            ft_manager.manager if ft_manager and ft_manager.enabled else None
+            ft_manager.manager
+            if ft_manager
+            and ft_manager.enabled
+            and checkpoint_config.enable_ft_dataloader_checkpoints
+            else None
         )
+
+        if ft_manager and ft_manager.enabled and not self.ft_manager:
+            logger.warn(
+                "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
+                "This means replicas can retrain over the same data multiple times, which can result in overfitting."
+            )
+
         if self.ft_manager:
             optimizers.init_cache_state_dict()
 
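To see the new condition in one place rather than spread across diff lines, here is a condensed, standalone sketch of the gating logic added above. resolve_ft_manager is a hypothetical helper written for this page, not a function in torchtitan, and SimpleNamespace objects stand in for the real ft_manager and checkpoint config.

import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)


def resolve_ft_manager(ft_manager, checkpoint_config):
    """Return the torchft manager used for dataloader checkpoints, or None."""
    active = (
        ft_manager.manager
        if ft_manager
        and ft_manager.enabled
        and checkpoint_config.enable_ft_dataloader_checkpoints
        else None
    )
    # Fault tolerance is on but the kill switch is off: warn that replicas
    # may retrain over the same batches, as the patch above does.
    if ft_manager and ft_manager.enabled and active is None:
        logger.warning(
            "Fault tolerance is enabled but enable_ft_dataloader_checkpoints "
            "is False; replicas can retrain over the same data."
        )
    return active


ft = SimpleNamespace(enabled=True, manager=object())
off = SimpleNamespace(enable_ft_dataloader_checkpoints=False)
on = SimpleNamespace(enable_ft_dataloader_checkpoints=True)
assert resolve_ft_manager(ft, off) is None       # disabled: warning is logged
assert resolve_ft_manager(ft, on) is ft.manager  # enabled: torchft checkpointing stays active
assert resolve_ft_manager(None, on) is None      # no fault tolerance at all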

torchtitan/config/job_config.py

Lines changed: 22 additions & 0 deletions
@@ -430,6 +430,28 @@ class Checkpoint:
     enable: bool = False
     """Whether to enable checkpoint"""
 
+    enable_ft_dataloader_checkpoints: bool = True
+    """
+    Warning: disabling this lets fault tolerant replicas train
+    over the same data multiple times. Use it with caution and only if
+    training over the same data is acceptable.
+
+    Enables checkpointing of the dataloader index for fault tolerant training with torchft.
+
+    Fault tolerant training stores the dataloader index in the checkpoints, so that training can resume
+    without going over the same batch twice.
+
+    If enabled, the dataloader state is checkpointed. Otherwise, replicas
+    will train over the same data multiple times, which can result in
+    overfitting.
+
+    A failed replica will still recover other state, e.g. model
+    parameters, from the other replicas.
+
+    Note that if regular checkpointing is enabled, the dataloader state
+    is also checkpointed. But without fault tolerance, the entire training starts from scratch.
+    """
+
     folder: str = "checkpoint"
     """
     The folder to store the checkpoints.
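For completeness, here is a minimal standalone mirror of the fields touched in this file, showing the default value and how a run would opt out. The trimmed dataclass below is an illustration only, not the full torchtitan Checkpoint config.

from dataclasses import dataclass


@dataclass
class Checkpoint:
    # Trimmed to the fields visible in the diff above.
    enable: bool = False
    # New in this commit: defaults to True, so existing torchft runs keep
    # checkpointing the dataloader index unless explicitly disabled.
    enable_ft_dataloader_checkpoints: bool = True
    folder: str = "checkpoint"


# Opt out only when replaying the same batches after a failure is acceptable.
cfg = Checkpoint(enable=True, enable_ft_dataloader_checkpoints=False)
assert cfg.enable_ft_dataloader_checkpoints is False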
