diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index e2e643db6b..d4c5416aa2 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -190,8 +190,19 @@ def __init__(
         self.load_only = checkpoint_config.load_only
 
         self.ft_manager = (
-            ft_manager.manager if ft_manager and ft_manager.enabled else None
+            ft_manager.manager
+            if ft_manager
+            and ft_manager.enabled
+            and checkpoint_config.enable_ft_dataloader_checkpoints
+            else None
         )
+
+        if ft_manager and ft_manager.enabled and not self.ft_manager:
+            logger.warning(
+                "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
+                "This means replicas can retrain over the same data multiple times, which can result in overfitting."
+            )
+
         if self.ft_manager:
             optimizers.init_cache_state_dict()
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index a003859a3b..d7e2752aea 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -430,6 +430,28 @@ class Checkpoint:
     enable: bool = False
     """Whether to enable checkpoint"""
 
+    enable_ft_dataloader_checkpoints: bool = True
+    """
+    Warning: Disabling this allows fault tolerant replicas to train
+    over the same data multiple times. Use it with caution, and only
+    if retraining over the same data is acceptable.
+
+    Enables checkpointing of the dataloader index for fault tolerant training with torchft.
+
+    Fault tolerant training stores the data loader index in the checkpoints, so that training can resume
+    without going over the same batch twice.
+
+    If enabled, data loader state is checkpointed. Otherwise, replicas
+    will train over the same data multiple times, which can result in
+    overfitting.
+
+    A failed replica will still recover other state, e.g. model
+    parameters, from the other replicas.
+
+    Note: if regular checkpointing is enabled, the data loader state is also checkpointed.
+    But when not using fault tolerance, the entire training starts from scratch.
+    """
+
     folder: str = "checkpoint"
     """
     The folder to store the checkpoints.
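
For review purposes, here is a minimal standalone sketch of the gating behavior this change introduces. The `_FTManagerStub` and `_CheckpointConfigStub` classes below are hypothetical stand-ins, not the real torchtitan/torchft classes; only the condition itself mirrors the diff.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class _FTManagerStub:
    """Hypothetical stand-in for the torchft manager wrapper (illustration only)."""
    enabled: bool
    manager: object = field(default_factory=object)


@dataclass
class _CheckpointConfigStub:
    """Hypothetical stand-in for the Checkpoint config dataclass (illustration only)."""
    enable_ft_dataloader_checkpoints: bool = True


def resolve_ft_manager(
    ft_manager: Optional[_FTManagerStub],
    checkpoint_config: _CheckpointConfigStub,
) -> Optional[object]:
    # Mirrors the gating condition added in checkpoint.py: the torchft manager is
    # only kept (and dataloader state only checkpointed through it) when fault
    # tolerance is enabled AND enable_ft_dataloader_checkpoints is True.
    return (
        ft_manager.manager
        if ft_manager
        and ft_manager.enabled
        and checkpoint_config.enable_ft_dataloader_checkpoints
        else None
    )


# Flag disabled: fault tolerance still runs, but the dataloader index is not
# checkpointed, so replicas may retrain over the same batches (the warning case).
assert resolve_ft_manager(
    _FTManagerStub(enabled=True),
    _CheckpointConfigStub(enable_ft_dataloader_checkpoints=False),
) is None

# Flag enabled (the default): the manager is kept and dataloader state is checkpointed.
assert resolve_ft_manager(
    _FTManagerStub(enabled=True),
    _CheckpointConfigStub(enable_ft_dataloader_checkpoints=True),
) is not None
```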