33 changes: 25 additions & 8 deletions torchtitan/components/checkpoint.py
@@ -19,11 +19,16 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint import HuggingFaceStorageWriter
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
consolidate_safetensors_files_on_every_rank,
from torch.distributed.checkpoint import (
HuggingFaceStorageReader,
HuggingFaceStorageWriter,
)
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
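# DefaultStager/StagingOptions may be missing on older torch releases; fall
# back to BlockingAsyncStager with a stub options factory in that case.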
try:
    from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions
except ImportError:
    from torch.distributed.checkpoint.staging import BlockingAsyncStager as DefaultStager

    def StagingOptions(*args, **kwargs):
        return True
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
set_model_state_dict,
@@ -187,7 +192,6 @@ def __init__(
ft_manager: FTManager | None = None,
) -> None:
self.enable = checkpoint_config.enable
self.load_only = checkpoint_config.load_only

self.ft_manager = (
ft_manager.manager if ft_manager and ft_manager.enabled else None
@@ -345,8 +349,10 @@ def dcp_save(
Future: The future object if the checkpoint is async, otherwise None.
"""


ret: Future | None = None


storage_writer: HuggingFaceStorageWriter | None = None
checkpoint_save_id: str | None = None
if to_hf:
@@ -355,6 +361,7 @@
), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
state_dict = self.sd_adapter.to_hf(state_dict)


fqn_to_index_mapping = self.sd_adapter.fqn_to_index_mapping
if fqn_to_index_mapping:
storage_writer = HuggingFaceStorageWriter(
@@ -374,9 +381,11 @@
enable_consolidation=True,
)


else:
checkpoint_save_id = checkpoint_id


if async_mode == AsyncMode.ASYNC:
ret = dcp.async_save(
state_dict,
@@ -400,17 +409,27 @@
checkpoint_id=checkpoint_save_id,
)


if to_hf and self.sd_adapter.fqn_to_index_mapping:
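# The consolidation helper lives in a private torch module; import it lazily
# so it is only required when saving consolidated HF safetensors.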
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
    consolidate_safetensors_files_on_every_rank,
)

consolidate_safetensors_files_on_every_rank(
input_dir=os.path.join(checkpoint_id, "sharded"),
output_dir=checkpoint_id,
fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping,
num_threads=5,
)


if enable_garbage_collection:
GarbageCollection.collect("GC collection invoked by checkpointer.")


return ret

def dcp_load(
@@ -499,7 +518,6 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
)
self.save_future = result.upload_completion
self.staging_future = result.staging_completion
self.staging = True
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.save_future = self.dcp_save(
@@ -630,7 +648,6 @@ def maybe_wait_for_staging(self) -> None:
"""
if self.enable_staging and self.staging:
self.staging_future.result()
self.staging = False

def _find_load_step(self, folder: str = "") -> int:
"""Find the step to load the checkpoint for.
@@ -776,7 +793,7 @@ def _save_last_step(self, curr_step: int) -> None:
)

def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
if not self.enable or self.load_only:
if not self.enable:
return False

if curr_step == 1 and self.enable_first_step_checkpoint:
11 changes: 7 additions & 4 deletions torchtitan/distributed/pipeline_parallel.py
@@ -18,15 +18,19 @@
get_schedule_class,
PipelineScheduleMulti,
PipelineScheduleSingle,
ScheduleDualPipeV,
ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import rescale_accumulated_loss
from torchtitan.config import JobConfig
from torchtitan.tools.logging import logger


try:
from torch.distributed.pipelining.schedules import ScheduleDualPipeV
except ImportError:
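    # ScheduleDualPipeV may not be available in older torch releases; fall back
    # to ScheduleZBVZeroBubble so the rest of the module still imports.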
ScheduleDualPipeV = ScheduleZBVZeroBubble


__all__ = [
"build_pipeline_schedule",
"stage_ids_this_rank",
@@ -83,8 +87,7 @@ def build_pipeline_schedule(
schedule = schedule_class(
stages if looped_schedule else stages[0],
n_microbatches=n_microbatches,
loss_fn=rescale_accumulated_loss(loss_fn, n_microbatches),
scale_grads=False,
loss_fn=loss_fn,
)
logger.info(
f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "