diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py
index cfe7f0ce6..52084022b 100644
--- a/tests/train/test_trainer.py
+++ b/tests/train/test_trainer.py
@@ -21,7 +21,7 @@
     InternVL3P5Dense8BConfig,
     InternVL3P5MoE30BA3Config,
 )
-from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage
+from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage, LoadCheckpointConfig
 from xtuner.v1.datasets import FTDPTokenizeFnConfig
 from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
 from xtuner.v1.train.trainer import TrainerConfig
@@ -29,7 +29,7 @@
 from xtuner.v1.loss import CELossConfig
 from xtuner._testing import DeterministicDDPTestCase
 from unittest import TestCase
-
+from xtuner.v1.train.trainer import XTunerMeta, ExpInfo, ExpHistory, GitInfo
 
 from xtuner.v1.utils.device import get_device
 
@@ -319,9 +319,7 @@ def test_resume(self):
             debug=False,
             checkpoint_interval=2,
             checkpoint_maxkeep=2,
-            resume_cfg=ResumeConfig(
-                auto_resume=True,
-            ),
+            auto_resume=True,
         )
         assert resume_trainer1.cur_step == 6
         assert resume_trainer1.exp_dir == trainer.exp_dir
@@ -347,9 +345,7 @@ def test_resume(self):
             debug=False,
             checkpoint_interval=2,
             checkpoint_maxkeep=2,
-            resume_cfg=ResumeConfig(
-                auto_resume=True,
-            ),
+            auto_resume=True,
         )
         assert resume_trainer1_2.cur_step == 10
         assert resume_trainer1_2.exp_dir == trainer.exp_dir
@@ -376,8 +372,8 @@ def test_resume(self):
             debug=False,
             checkpoint_interval=5,
             checkpoint_maxkeep=2,
-            resume_cfg=ResumeConfig(
-                resume_from=resume_trainer1_2.meta.latest_exp.checkpoint_list[-2],
+            load_checkpoint_cfg=LoadCheckpointConfig(
+                checkpoint_path=resume_trainer1_2.meta.latest_exp.checkpoint_list[-2],
             ),
         )
         assert resume_trainer2.cur_step == 14
@@ -612,3 +608,112 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
     loaded = pickle.loads(dumped)
     assert len(loaded.get_hooks(HookStage.AFTER_TRAIN_STEP)) == 0  # object cannot be serialized
     assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1
+
+
+def test_resume_and_load_checkpoint_cfg(tmp_path: Path):
+    # 0. prepare environment
+    os.environ["LOCAL_RANK"] = "0"
+    os.environ["RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "29500"
+
+    alpaca_path = os.environ["ALPACA_PATH"]
+    tokenizer_path = os.environ["QWEN3_MOE_PATH"]
+
+    work_dir = tmp_path / "work_dir"
+    fake_hf_model_dir = tmp_path / "fake_hf_model"
+    fake_hf_model_dir.mkdir()
+    (fake_hf_model_dir / "config.json").write_text('{"model_type": "fake_model"}')
+    (fake_hf_model_dir / "model.safetensors").write_text("fake weights")
+
+    model_cfg = Qwen3MoE30BA3Config()
+    optim_cfg = AdamWConfig(lr=1e-4, weight_decay=0.01)
+
+    dataset_cfg = [
+        {
+            "dataset": DatasetConfig(name="alpaca", anno_path=alpaca_path, sample_ratio=1.0),
+            "tokenize_fn": FTDPTokenizeFnConfig(),
+        },
+    ]
+    dataloader_cfg = DataloaderConfig(dataset_config_list=dataset_cfg)
+    lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0.1, lr_min=1e-6)
+
+    Trainer.build_engine = Mock(return_value=FakeEngine())
+    Trainer._load_checkpoint = Mock()  # `_resume` was renamed to `_load_checkpoint` in this patch
+
+    # 1. create: auto_resume and load_checkpoint_cfg for the first run
+    auto_resume = True
+    load_checkpoint_cfg = LoadCheckpointConfig(
+        checkpoint_path=work_dir / "other_checkpoint",
+        load_optimizer_states=True,
+        load_optimizer_args=True,
+        load_dataset=False,
+        load_scheduler=False,
+    )
+
+    # 2. operate: build the trainer for the first run (no checkpoint in work_dir yet)
+    trainer = Trainer(
+        load_from=fake_hf_model_dir,
+        model_cfg=model_cfg,
+        optim_cfg=optim_cfg,
+        dataloader_cfg=dataloader_cfg,
+        lr_cfg=lr_cfg,
+        tokenizer_path=tokenizer_path,
+        global_batch_size=2,
+        total_step=10,
+        checkpoint_interval=5,
+        work_dir=work_dir,
+        auto_resume=auto_resume,
+        load_checkpoint_cfg=load_checkpoint_cfg,
+    )
+
+    # 3. check: auto_resume does not overwrite load_checkpoint_cfg on the first run
+    assert trainer._load_checkpoint_cfg.load_dataset is False
+    assert trainer._load_checkpoint_cfg.load_scheduler is False
+
+    del trainer
+
+    # 4. second run: resume training with auto_resume and load_checkpoint_cfg
+    exp_dir = work_dir / "fake_exp_dir"
+    latest_checkpoint = exp_dir / "step-5"
+    resume_meta = XTunerMeta(
+        exps=[
+            ExpInfo(
+                checkpoint_list=[str(latest_checkpoint)],
+                exp_dir=str(exp_dir),
+                history=[
+                    dict(
+                        begin=0,
+                        timestamp="20251126202933",
+                        git_info=dict(commit="dae707", staged="", unstaged=""),
+                        end=5,
+                    ),
+                ],
+            ),
+        ],
+    )
+    Trainer._init_xtuner_meta = Mock(return_value=resume_meta)
+    trainer2 = Trainer(
+        load_from=str(fake_hf_model_dir),
+        model_cfg=model_cfg,
+        optim_cfg=optim_cfg,
+        dataloader_cfg=dataloader_cfg,
+        lr_cfg=lr_cfg,
+        tokenizer_path=tokenizer_path,
+        global_batch_size=2,
+        total_step=10,
+        checkpoint_interval=5,
+        work_dir=work_dir,
+        auto_resume=auto_resume,
+        load_checkpoint_cfg=load_checkpoint_cfg,
+    )
+
+    # 5. check: auto_resume overrides load_checkpoint_cfg when resuming training
+    assert trainer2._load_checkpoint_cfg.checkpoint_path == latest_checkpoint
+    assert trainer2._load_checkpoint_cfg.load_dataset is True
+    assert trainer2._load_checkpoint_cfg.load_scheduler is True
+    assert trainer2._load_checkpoint_cfg.load_optimizer_states is True
+    assert trainer2._load_checkpoint_cfg.load_optimizer_args is True
+
+    dist.destroy_process_group()
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index c8da31fa1..eb53659d5 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -32,7 +32,7 @@
 from xtuner.v1.utils.device import get_device, get_torch_device_module
 from xtuner.v1.utils.env_check import get_rollout_engine_version
 
-from .trainer import ExpHistory, ExpInfo, GitInfo, XTunerMeta
+from .trainer import ExpHistory, ExpInfo, GitInfo, LoadCheckpointConfig, XTunerMeta
 
 
 # TODO: Move DEVICE to `xtuner.utils.device`
@@ -65,6 +65,8 @@ class RLTrainerConfig(BaseModel):
     log_dir: Path | str | None = None
     total_epochs: int
     resume_config: ResumeConfig | None = None
+    auto_resume: bool = False
+    load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig()
     strict_load: bool = True
     hf_interval: int | None = None
     hf_max_keep: int | None = None
@@ -157,6 +159,8 @@ class RLTrainer:
         enable_evaluate (bool): Whether to perform periodic evaluation during training.
         resume_config (ResumeConfig | None): Configuration for resuming training from a previous
             checkpoint. Defaults to None.
+        auto_resume (bool): Whether to automatically resume training. Defaults to False.
+        load_checkpoint_cfg (LoadCheckpointConfig): Configuration for loading checkpoints.
         strict_load (bool): Whether to strictly enforce checkpoint loading compatibility.
             Defaults to True.
         hf_interval (int | None): Interval (in epochs) for saving HuggingFace format
@@ -203,7 +207,9 @@ def __init__(
         work_dir: Path | str | None = None,
         log_dir: Path | str | None = None,
         total_epochs: int,
-        resume_config: ResumeConfig | None = None,
+        resume_config: ResumeConfig | None = None,  # TODO: Removed in version 1.1.0
+        auto_resume: bool = False,
+        load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig(),
         strict_load: bool = True,
         hf_interval: int | None = None,
         hf_max_keep: int | None = None,
@@ -257,7 +263,8 @@ def __init__(
         work_dir.mkdir(parents=True, exist_ok=True)
 
         self._work_dir = work_dir
-        self._meta = self._init_xtuner_meta(work_dir, resume_config is not None)
+        auto_resume = auto_resume or (resume_config is not None and resume_config.auto_resume)
+        self._meta = self._init_xtuner_meta(work_dir, auto_resume)
 
         if log_dir is None:
             log_dir = self.exp_dir
@@ -265,6 +272,11 @@
             log_dir = Path(log_dir)
 
         self.logger = self._init_logger(log_dir)
+
+        self.logger.warning(
+            "`resume_config` is deprecated, please use `auto_resume` and `load_checkpoint_cfg` instead"
+        )
+
         train_worker_cfg.log_dir = log_dir
         dataflow_config.worker_log_dir = log_dir
         rollout_config.worker_log_dir = log_dir
@@ -338,6 +350,8 @@ def from_config(cls, config: RLTrainerConfig) -> Self:
             log_dir=config.log_dir,
             total_epochs=config.total_epochs,
             resume_config=config.resume_config,
+            auto_resume=config.auto_resume,
+            load_checkpoint_cfg=config.load_checkpoint_cfg,
             strict_load=config.strict_load,
             hf_interval=config.hf_interval,
             hf_max_keep=config.hf_max_keep,
diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py
index 4aa124808..f80b0d4aa 100644
--- a/xtuner/v1/train/trainer.py
+++ b/xtuner/v1/train/trainer.py
@@ -277,6 +277,15 @@ def __setstate__(self, state):
         self.__dict__.update(valid_state)
 
 
+class LoadCheckpointConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    checkpoint_path: str | Path | None = None
+    load_optimizer_states: bool = True
+    load_optimizer_args: bool = True
+    load_dataset: bool = True
+    load_scheduler: bool = True
+
+
 class TrainerConfig(BaseModel):
     model_config = ConfigDict(
         title="Trainer config",
@@ -301,7 +310,9 @@ class TrainerConfig(BaseModel):
     sp_size: int = 1
     total_step: int | None = None
     total_epoch: int | None = None
-    resume_cfg: ResumeConfig | None = None
+    resume_cfg: ResumeConfig | None = None  # TODO: Removed in version 1.1.0
+    auto_resume: bool = False
+    load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig()
     strict_load: bool = True
     checkpoint_interval: int | None = -1
     checkpoint_maxkeep: int | None = -1
@@ -373,6 +384,8 @@ class Trainer:
         total_step (int | None): Total training steps.
         total_epoch (int | None): Number of training epochs.
         resume_cfg (ResumeConfig | None): Configuration for resuming training.
+        auto_resume (bool): Whether to automatically resume training. Defaults to False.
+        load_checkpoint_cfg (LoadCheckpointConfig): Configuration for loading checkpoints.
         strict_load (bool): Whether to strictly load model weights.
         checkpoint_interval (int | None): Interval for saving checkpoints.
         checkpoint_maxkeep (int | None): Maximum number of checkpoints to keep.
@@ -419,7 +432,9 @@ def __init__(
         sp_size: int = 1,
         total_step: int | None = None,
         total_epoch: int | None = None,
-        resume_cfg: ResumeConfig | None = ResumeConfig(),
+        resume_cfg: ResumeConfig | None = ResumeConfig(),  # TODO: Removed in version 1.1.0
+        auto_resume: bool = False,
+        load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig(),
         strict_load: bool = True,
         checkpoint_interval: int | None = -1,
         checkpoint_maxkeep: int | None = -1,
@@ -507,9 +522,14 @@ def __init__(
             resume_cfg = ResumeConfig()
 
         self._work_dir = self._resolve_work_dir(work_dir)
-        self._meta = self._init_xtuner_meta(self.work_dir, auto_resume=resume_cfg.auto_resume)
+        logger.warning("`resume_cfg` is deprecated, please use `auto_resume` and `load_checkpoint_cfg` instead")
+        self._auto_resume = auto_resume
+        self._auto_resume = self._resolve_deprecated_resume_cfg(
+            resume_cfg, self._auto_resume
+        )  # TODO: Removed in version 1.1.0
+        self._meta = self._init_xtuner_meta(self.work_dir, auto_resume=self._auto_resume)
         self._log_dir = self._resolve_log_dir(log_dir)
-        self._resume_cfg = self._resolve_resume_cfg(resume_cfg)
+        self._load_checkpoint_cfg = self._resolve_load_checkpoint_cfg(self._auto_resume, load_checkpoint_cfg)
         self.logger, log_dir = self._init_logger(self._log_dir)
 
         self._exp_tracker = self._init_tracker(
@@ -529,6 +549,7 @@
         self._resolve_config_conflicts(self.tokenizer, model_cfg, dataloader_cfg)
 
         if dataset_cfg is not None:  # TODO: Removed in version 1.1.0
+            logger.warning("`dataset_cfg` is deprecated, please use `dataloader_cfg.dataset_config_list` instead")
             # For backward compatibility, reserve the dataset_cfg interface, remove it later
             if dataloader_cfg.dataset_config_list is not None:
                 logger.warning("Outside dataset_cfg will override inner dataset_config_list")
@@ -557,7 +578,7 @@ def __init__(
             model_config=model_cfg,
             optim_config=optim_cfg,
             fsdp_config=fsdp_cfg,
-            resume_cfg=resume_cfg,
+            load_checkpoint_path=self._load_checkpoint_cfg.checkpoint_path,
             strict=strict_load,
             intra_layer_micro_batch=intra_layer_micro_batch,
         )
@@ -581,8 +602,8 @@ def __init__(
             self._checkpoint_interval = None
             self._snapshot_interval = None
 
-        if self._resume_cfg.resume_from is not None:
-            self._resume()
+        if self._load_checkpoint_cfg.checkpoint_path is not None:
+            self._load_checkpoint()
 
         self.hooks_config = self._setup_hooks(hooks_config=hooks_config)
 
@@ -615,6 +636,8 @@ def from_config(cls, config: TrainerConfig) -> Self:
             total_step=config.total_step,
             total_epoch=config.total_epoch,
             resume_cfg=config.resume_cfg,
+            auto_resume=config.auto_resume,
+            load_checkpoint_cfg=config.load_checkpoint_cfg,
             strict_load=config.strict_load,
             checkpoint_interval=config.checkpoint_interval,
             checkpoint_maxkeep=config.checkpoint_maxkeep,
@@ -887,7 +910,7 @@ def build_engine(
         model_config: TransformerConfig | VisionComposeConfigProtocol,
         optim_config: OptimConfig,
         fsdp_config: FSDPConfig,
-        resume_cfg: ResumeConfig,
+        load_checkpoint_path: str | Path | None,
         intra_layer_micro_batch: int = 1,
         strict: bool = True,
     ):
@@ -919,9 +942,9 @@ def build_engine(
             model_cfg=model_config,
             intra_layer_micro_batch=intra_layer_micro_batch,
         )
-        if model_path is not None and (model_config.dcp_ignore_frozen_params or resume_cfg.resume_from is None):
+        if model_path is not None and (model_config.dcp_ignore_frozen_params or load_checkpoint_path is None):
             engine.from_hf(hf_path=model_path, strict=strict)
-        elif resume_cfg.resume_from is None:
+        elif load_checkpoint_path is None:
             engine.init_model_weights()
 
         if model_path is not None:
@@ -989,11 +1012,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
 
         meta_path = self.work_dir / self._META_PATH
-        optimizer_path = (
-            checkpoint_path / self._SAVE_OPTIMIZER_DIR
-            if self._resume_cfg.load_optimizer_states or self._resume_cfg.load_optimizer_args
-            else None
-        )
+        optimizer_path = checkpoint_path / self._SAVE_OPTIMIZER_DIR
         model_path = checkpoint_path / self._SAVE_MODEL_DIR
         dataloader_path = checkpoint_path / self._SAVE_DATALOADER_DIR
         scheduler_path = checkpoint_path / self._SAVE_SCHEDULER_DIR
 
@@ -1505,16 +1524,29 @@ def _resolve_config_conflicts(
             )
             dataloader_cfg.pad_token_id = pad_token_id
 
-    def _resolve_resume_cfg(self, resume_cfg: ResumeConfig):
-        latest_checkpoint = self.meta.latest_exp.latest_checkpoint
-        if latest_checkpoint is not None and resume_cfg.auto_resume:
-            resume_cfg.resume_from = Path(latest_checkpoint)
-        return resume_cfg
-
-    def _resume(self):
-        resume_cfg = self._resume_cfg
+    def _resolve_deprecated_resume_cfg(self, resume_cfg: ResumeConfig, auto_resume: bool) -> bool:
+        if resume_cfg.auto_resume:
+            return True
+        return auto_resume
 
-        if (resume_from := resume_cfg.resume_from) is None:
+    def _resolve_load_checkpoint_cfg(
+        self, auto_resume: bool, load_checkpoint_cfg: LoadCheckpointConfig
+    ) -> LoadCheckpointConfig:
+        # auto_resume has higher priority: if a latest checkpoint exists, follow the auto_resume logic.
+        # In that case, override the load checkpoint path and load the optimizer states, optimizer args, dataset and scheduler.
+        latest_checkpoint = self.meta.latest_exp.latest_checkpoint
+        if latest_checkpoint is not None and auto_resume:
+            load_checkpoint_cfg.checkpoint_path = Path(latest_checkpoint)
+            load_checkpoint_cfg.load_optimizer_states = True
+            load_checkpoint_cfg.load_optimizer_args = True
+            load_checkpoint_cfg.load_dataset = True
+            load_checkpoint_cfg.load_scheduler = True
+        return load_checkpoint_cfg
+
+    def _load_checkpoint(self):
+        load_checkpoint_cfg: LoadCheckpointConfig = self._load_checkpoint_cfg
+
+        if (resume_from := load_checkpoint_cfg.checkpoint_path) is None:
             logger.info("No checkpoint to resume from.")
             return
 
@@ -1528,18 +1560,18 @@ def _resume(self):
 
         model_path = resume_from / self._SAVE_MODEL_DIR
         optimizer_path = (
             resume_from / self._SAVE_OPTIMIZER_DIR
-            if self._resume_cfg.load_optimizer_states or self._resume_cfg.load_optimizer_args
+            if load_checkpoint_cfg.load_optimizer_states or load_checkpoint_cfg.load_optimizer_args
             else None
         )
 
         self._engine.load_dcp(
             model_dir=model_path,
             optimizer_dir=optimizer_path,
-            load_states=self._resume_cfg.load_optimizer_states,
-            load_args=self._resume_cfg.load_optimizer_args,
+            load_states=load_checkpoint_cfg.load_optimizer_states,
+            load_args=load_checkpoint_cfg.load_optimizer_args,
         )
 
-        if resume_cfg.load_dataset:
+        if load_checkpoint_cfg.load_dataset:
             dataloader_path = resume_from / self._SAVE_DATALOADER_DIR
             self._resume_dataloader(dataloader_path)
@@ -1554,7 +1586,7 @@ def _resume(self):
         self._consumed_tokens = train_state["consumed_tokens"]
         self._train_time_offset = train_state["train_time_offset"]
 
-        if resume_cfg.load_scheduler:
+        if load_checkpoint_cfg.load_scheduler:
             scheduler_path = resume_from / self._SAVE_SCHEDULER_DIR
             if not scheduler_path.exists():
                 raise FileNotFoundError(f"Scheduler path {scheduler_path} does not exist.")
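For reviewers, a minimal usage sketch of the option this patch introduces. It only constructs the new `LoadCheckpointConfig`; the surrounding `Trainer(...)` arguments are omitted, and the checkpoint path below is a placeholder, not part of the change.

```python
from pathlib import Path

from xtuner.v1.train.trainer import LoadCheckpointConfig

# Load model and optimizer state from an explicit checkpoint, but start the
# dataloader and LR scheduler fresh (e.g. to continue training on a new dataset).
load_checkpoint_cfg = LoadCheckpointConfig(
    checkpoint_path=Path("work_dirs/old_exp/step-5"),  # placeholder path
    load_optimizer_states=True,
    load_optimizer_args=True,
    load_dataset=False,
    load_scheduler=False,
)

# Passed as Trainer(..., auto_resume=True, load_checkpoint_cfg=load_checkpoint_cfg):
# if the work_dir already records a checkpoint, auto-resume takes precedence,
# overwriting checkpoint_path with the latest checkpoint and re-enabling every
# load_* flag (see Trainer._resolve_load_checkpoint_cfg above).
```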