diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index e2e643db6b..3be7aaf6e3 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -19,11 +19,16 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed.checkpoint import HuggingFaceStorageWriter
-from torch.distributed.checkpoint._consolidate_hf_safetensors import (
-    consolidate_safetensors_files_on_every_rank,
+from torch.distributed.checkpoint import (
+    HuggingFaceStorageReader,
+    HuggingFaceStorageWriter,
 )
-from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
+try:
+    from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
+except ImportError:
+    from torch.distributed.checkpoint.staging import BlockingAsyncStager as DefaultStager
+    def StagingOptions(a, b, c, d):
+        return True
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -187,7 +192,6 @@ def __init__(
         ft_manager: FTManager | None = None,
     ) -> None:
         self.enable = checkpoint_config.enable
-        self.load_only = checkpoint_config.load_only
 
         self.ft_manager = (
             ft_manager.manager if ft_manager and ft_manager.enabled else None
@@ -345,8 +349,10 @@ def dcp_save(
             Future: The future object if the checkpoint is async, otherwise None.
         """
+        ret: Future | None = None
+
         storage_writer: HuggingFaceStorageWriter | None = None
         checkpoint_save_id: str | None = None
         if to_hf:
@@ -355,6 +361,7 @@ def dcp_save(
             ), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
             state_dict = self.sd_adapter.to_hf(state_dict)
+
             fqn_to_index_mapping = self.sd_adapter.fqn_to_index_mapping
             if fqn_to_index_mapping:
                 storage_writer = HuggingFaceStorageWriter(
@@ -374,9 +381,11 @@ def dcp_save(
                     enable_consolidation=True,
                 )
+
         else:
             checkpoint_save_id = checkpoint_id
+
         if async_mode == AsyncMode.ASYNC:
             ret = dcp.async_save(
                 state_dict,
@@ -400,7 +409,15 @@ def dcp_save(
                 checkpoint_id=checkpoint_save_id,
             )
+
         if to_hf and self.sd_adapter.fqn_to_index_mapping:
+            try:
+                from torch.distributed.checkpoint._consolidate_hf_safetensors import (
+                    consolidate_safetensors_files_on_every_rank,
+                )
+            except Exception as e:
+                raise e
+
             consolidate_safetensors_files_on_every_rank(
                 input_dir=os.path.join(checkpoint_id, "sharded"),
                 output_dir=checkpoint_id,
@@ -408,9 +425,11 @@
                 num_threads=5,
             )
+
         if enable_garbage_collection:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
+
         return ret
 
     def dcp_load(
@@ -499,7 +518,6 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             )
             self.save_future = result.upload_completion
             self.staging_future = result.staging_completion
-            self.staging = True
         elif self.async_mode == AsyncMode.ASYNC:
             GarbageCollection.collect("GC collection invoked by checkpointer.")
             self.save_future = self.dcp_save(
@@ -630,7 +648,6 @@ def maybe_wait_for_staging(self) -> None:
         """
         if self.enable_staging and self.staging:
             self.staging_future.result()
-            self.staging = False
 
     def _find_load_step(self, folder: str = "") -> int:
         """Find the step to load the checkpoint for.
@@ -776,7 +793,7 @@ def _save_last_step(self, curr_step: int) -> None:
         )
 
     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
-        if not self.enable or self.load_only:
+        if not self.enable:
             return False
 
         if curr_step == 1 and self.enable_first_step_checkpoint:
diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py
index 3edb21ebfa..c3b9c2ceca 100644
--- a/torchtitan/distributed/pipeline_parallel.py
+++ b/torchtitan/distributed/pipeline_parallel.py
@@ -18,15 +18,19 @@
     get_schedule_class,
     PipelineScheduleMulti,
     PipelineScheduleSingle,
-    ScheduleDualPipeV,
     ScheduleZBVZeroBubble,
 )
-from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.config import JobConfig
 from torchtitan.tools.logging import logger
 
+try:
+    from torch.distributed.pipelining.schedules import ScheduleDualPipeV
+except ImportError:
+    ScheduleDualPipeV = ScheduleZBVZeroBubble
+
+
 __all__ = [
     "build_pipeline_schedule",
     "stage_ids_this_rank",
@@ -83,8 +87,7 @@ def build_pipeline_schedule(
         schedule = schedule_class(
             stages if looped_schedule else stages[0],
             n_microbatches=n_microbatches,
-            loss_fn=rescale_accumulated_loss(loss_fn, n_microbatches),
-            scale_grads=False,
+            loss_fn=loss_fn,
         )
         logger.info(
             f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "