From cfd29fcdbf264eeb3599afad3fb83b045cad60bd Mon Sep 17 00:00:00 2001
From: Sam Foreman
Date: Thu, 9 Oct 2025 10:43:42 -0500
Subject: [PATCH 1/3] fix imports in `components/checkpoint.py`

---
 torchtitan/components/checkpoint.py | 48 +++++++++--------------------
 1 file changed, 15 insertions(+), 33 deletions(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 1b25aa3f54..9762b99882 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -23,10 +23,12 @@
     HuggingFaceStorageReader,
     HuggingFaceStorageWriter,
 )
-from torch.distributed.checkpoint._consolidate_hf_safetensors import (
-    consolidate_safetensors_files_on_every_rank,
-)
-from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
+try:
+    from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
+except ImportError:
+    from torch.distributed.checkpoint.staging import BlockingAsyncStager as DefaultStager
+    def StagingOptions(a, b, c, d):
+        return True
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -190,7 +192,6 @@ def __init__(
         ft_manager: FTManager | None = None,
     ) -> None:
         self.enable = checkpoint_config.enable
-        self.load_only = checkpoint_config.load_only
 
         self.ft_manager = (
             ft_manager.manager if ft_manager and ft_manager.enabled else None
@@ -356,23 +357,14 @@ def dcp_save(
             state_dict = self.sd_adapter.to_hf(state_dict)
 
             fqn_to_index_mapping = self.sd_adapter.fqn_to_index_mapping
-            if fqn_to_index_mapping:
-                storage_writer = HuggingFaceStorageWriter(
-                    path=os.path.join(checkpoint_id, "sharded"),
-                    save_distributed=True,
-                    fqn_to_index_mapping=fqn_to_index_mapping,
-                    enable_consolidation=False,
-                )
-            else:
-                # the reason for only enabling consolidation if there is
-                # no mapping is because no mapping implies that we save all fqns
-                # to one file. This means we only need one rank to consolidate.
-                # Otherwise we should use consolidate_safetensors_files_on_every_rank
-                storage_writer = HuggingFaceStorageWriter(
-                    path=checkpoint_id,
-                    save_distributed=True,
-                    enable_consolidation=True,
-                )
+
+            storage_writer = HuggingFaceStorageWriter(
+                path=checkpoint_id,
+                save_distributed=True,
+                fqn_to_index_mapping=fqn_to_index_mapping,
+                enable_consolidation=True,
+                thread_count_consolidation=5,
+            )
         else:
             checkpoint_save_id = checkpoint_id
 
@@ -400,14 +392,6 @@ def dcp_save(
                 checkpoint_id=checkpoint_save_id,
             )
 
-        if to_hf and self.sd_adapter.fqn_to_index_mapping:
-            consolidate_safetensors_files_on_every_rank(
-                input_dir=os.path.join(checkpoint_id, "sharded"),
-                output_dir=checkpoint_id,
-                fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping,
-                num_threads=5,
-            )
-
         if enable_garbage_collection:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
 
@@ -495,7 +479,6 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             )
             self.save_future = result.upload_completion
             self.staging_future = result.staging_completion
-            self.staging = True
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             self.save_future = self.dcp_save(
@@ -617,7 +600,6 @@ def maybe_wait_for_staging(self) -> None:
         """
         if self.enable_staging and self.staging:
             self.staging_future.result()
-            self.staging = False
 
     def _find_load_step(self, folder: str = "") -> int:
         """Find the step to load the checkpoint for.
@@ -762,7 +744,7 @@ def _save_last_step(self, curr_step: int) -> None:
         )
 
     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
-        if not self.enable or self.load_only:
+        if not self.enable:
             return False
 
         if curr_step == 1 and self.enable_first_step_checkpoint:
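
Note on the guarded import in PATCH 1: `DefaultStager` and `StagingOptions` only exist in newer PyTorch releases, so the fallback keeps `components/checkpoint.py` importable by aliasing `BlockingAsyncStager` and stubbing `StagingOptions` out. A minimal standalone probe of which branch a given environment takes (only the import paths named in the hunk are assumed; the printed messages are illustrative):

    # Probe which staging API this torch build provides; mirrors the
    # fallback logic in the hunk above without importing torchtitan.
    try:
        from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions  # noqa: F401
        print("native DefaultStager/StagingOptions available")
    except ImportError:
        from torch.distributed.checkpoint.staging import BlockingAsyncStager  # noqa: F401
        print("older torch: falling back to BlockingAsyncStager; StagingOptions is stubbed")
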
From 443d5a7b2d6d1cac3a0cab5b85d40d8346202379 Mon Sep 17 00:00:00 2001
From: Sam Foreman
Date: Thu, 9 Oct 2025 10:44:01 -0500
Subject: [PATCH 2/3] fix `distributed/pipeline_parallel.py`

---
 torchtitan/distributed/pipeline_parallel.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py
index 3edb21ebfa..c3b9c2ceca 100644
--- a/torchtitan/distributed/pipeline_parallel.py
+++ b/torchtitan/distributed/pipeline_parallel.py
@@ -18,15 +18,19 @@
     get_schedule_class,
     PipelineScheduleMulti,
     PipelineScheduleSingle,
-    ScheduleDualPipeV,
     ScheduleZBVZeroBubble,
 )
 
-from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.config import JobConfig
 from torchtitan.tools.logging import logger
 
 
+try:
+    from torch.distributed.pipelining.schedules import ScheduleDualPipeV
+except ImportError:
+    ScheduleDualPipeV = ScheduleZBVZeroBubble
+
+
 __all__ = [
     "build_pipeline_schedule",
     "stage_ids_this_rank",
@@ -83,8 +87,7 @@ def build_pipeline_schedule(
     schedule = schedule_class(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
-        loss_fn=rescale_accumulated_loss(loss_fn, n_microbatches),
-        scale_grads=False,
+        loss_fn=loss_fn,
     )
     logger.info(
         f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "
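
Note on PATCH 2: aliasing `ScheduleDualPipeV` to `ScheduleZBVZeroBubble` keeps older torch builds importing cleanly, but a job config that requests DualPipeV will then silently run ZBVZeroBubble. The second hunk likewise passes `loss_fn` through unwrapped, presumably because older schedule constructors do not accept `scale_grads`. A sketch of a noisier fallback that logs the substitution (the `logging` setup is a stand-in for `torchtitan.tools.logging.logger`):

    # Same guarded import as the hunk above, but warn when the requested
    # schedule class is unavailable instead of aliasing it silently.
    import logging

    logger = logging.getLogger(__name__)  # stand-in for torchtitan's logger

    try:
        from torch.distributed.pipelining.schedules import ScheduleDualPipeV
    except ImportError:
        from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble

        logger.warning(
            "ScheduleDualPipeV is not available in this torch build; "
            "falling back to ScheduleZBVZeroBubble."
        )
        ScheduleDualPipeV = ScheduleZBVZeroBubble
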
From f9b2a839b3b5fd19f98e2d19faa3458fcd1117d7 Mon Sep 17 00:00:00 2001
From: Sam Foreman
Date: Thu, 9 Oct 2025 10:59:10 -0500
Subject: [PATCH 3/3] fix: Resolve conflicts in `components/checkpoint.py`

---
 torchtitan/components/checkpoint.py | 46 ++++++++++++++++++++++++-----
 1 file changed, 39 insertions(+), 7 deletions(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 9762b99882..2b2c9238b4 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -346,8 +346,10 @@ def dcp_save(
             Future: The future object if the checkpoint is async, otherwise None.
         """
 
+        ret: Future | None = None
+        storage_writer: HuggingFaceStorageWriter | None = None
         checkpoint_save_id: str | None = None
 
         if to_hf:
@@ -356,19 +358,31 @@ def dcp_save(
             ), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
             state_dict = self.sd_adapter.to_hf(state_dict)
 
+
             fqn_to_index_mapping = self.sd_adapter.fqn_to_index_mapping
+            if fqn_to_index_mapping:
+                storage_writer = HuggingFaceStorageWriter(
+                    path=os.path.join(checkpoint_id, "sharded"),
+                    save_distributed=True,
+                    fqn_to_index_mapping=fqn_to_index_mapping,
+                    enable_consolidation=False,
+                )
+            else:
+                # the reason for only enabling consolidation if there is
+                # no mapping is because no mapping implies that we save all fqns
+                # to one file. This means we only need one rank to consolidate.
+                # Otherwise we should use consolidate_safetensors_files_on_every_rank
+                storage_writer = HuggingFaceStorageWriter(
+                    path=checkpoint_id,
+                    save_distributed=True,
+                    enable_consolidation=True,
+                )
-            storage_writer = HuggingFaceStorageWriter(
-                path=checkpoint_id,
-                save_distributed=True,
-                fqn_to_index_mapping=fqn_to_index_mapping,
-                enable_consolidation=True,
-                thread_count_consolidation=5,
-            )
 
         else:
             checkpoint_save_id = checkpoint_id
 
+
         if async_mode == AsyncMode.ASYNC:
             ret = dcp.async_save(
                 state_dict,
@@ -392,9 +406,27 @@ def dcp_save(
                 checkpoint_id=checkpoint_save_id,
             )
 
+
+        if to_hf and self.sd_adapter.fqn_to_index_mapping:
+            try:
+                from torch.distributed.checkpoint._consolidate_hf_safetensors import (
+                    consolidate_safetensors_files_on_every_rank,
+                )
+            except Exception as e:
+                raise e
+
+            consolidate_safetensors_files_on_every_rank(
+                input_dir=os.path.join(checkpoint_id, "sharded"),
+                output_dir=checkpoint_id,
+                fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping,
+                num_threads=5,
+            )
+
+
         if enable_garbage_collection:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
 
+
         return ret
 
     def dcp_load(
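
Note on PATCH 3: the consolidation helper lives in the private module `torch.distributed.checkpoint._consolidate_hf_safetensors`, whose availability varies across torch versions, and the `except Exception as e: raise e` wrapper re-raises the original error unchanged. A hypothetical variant (the wrapper name and error message are illustrative; the call arguments mirror the hunk) that defers the import to call time and fails with a clearer message when the module is missing:

    import os


    def consolidate_hf_checkpoint(
        checkpoint_id: str, fqn_to_index_mapping: dict, num_threads: int = 5
    ) -> None:
        # Import at call time so importing checkpoint.py never fails on torch
        # builds that lack the private consolidation module.
        try:
            from torch.distributed.checkpoint._consolidate_hf_safetensors import (
                consolidate_safetensors_files_on_every_rank,
            )
        except ImportError as exc:
            raise RuntimeError(
                "HF safetensors consolidation requires a torch build that ships "
                "torch.distributed.checkpoint._consolidate_hf_safetensors"
            ) from exc

        consolidate_safetensors_files_on_every_rank(
            input_dir=os.path.join(checkpoint_id, "sharded"),
            output_dir=checkpoint_id,
            fqn_to_index_mapping=fqn_to_index_mapping,
            num_threads=num_threads,
        )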