     InternVL3P5Dense8BConfig,
     InternVL3P5MoE30BA3Config,
 )
-from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage
+from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage, LoadCheckpointConfig
 from xtuner.v1.datasets import FTDPTokenizeFnConfig
 from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
 from xtuner.v1.train.trainer import TrainerConfig
 from xtuner.v1.engine.train_engine import LossLog, OtherLog
 from xtuner.v1.loss import CELossConfig
 from xtuner._testing import DeterministicDDPTestCase
 from unittest import TestCase
-
+from xtuner.v1.train.trainer import XTunerMeta, ExpInfo, ExpHistory, GitInfo
 from xtuner.v1.utils.device import get_device


@@ -319,9 +319,7 @@ def test_resume(self):
             debug=False,
             checkpoint_interval=2,
             checkpoint_maxkeep=2,
-            resume_cfg=ResumeConfig(
-                auto_resume=True,
-            ),
+            auto_resume=True,
         )
         assert resume_trainer1.cur_step == 6
         assert resume_trainer1.exp_dir == trainer.exp_dir
@@ -347,9 +345,7 @@ def test_resume(self):
             debug=False,
             checkpoint_interval=2,
             checkpoint_maxkeep=2,
-            resume_cfg=ResumeConfig(
-                auto_resume=True,
-            ),
+            auto_resume=True,
         )
         assert resume_trainer1_2.cur_step == 10
         assert resume_trainer1_2.exp_dir == trainer.exp_dir
@@ -376,8 +372,8 @@ def test_resume(self):
             debug=False,
             checkpoint_interval=5,
             checkpoint_maxkeep=2,
-            resume_cfg=ResumeConfig(
-                resume_from=resume_trainer1_2.meta.latest_exp.checkpoint_list[-2],
+            load_checkpoint_cfg=LoadCheckpointConfig(
+                checkpoint_path=resume_trainer1_2.meta.latest_exp.checkpoint_list[-2],
             ),
         )
         assert resume_trainer2.cur_step == 14
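
The three hunks above are the same API migration seen from the call site: the nested ResumeConfig is removed, auto_resume becomes a plain Trainer keyword, and loading an explicit checkpoint moves to LoadCheckpointConfig. A minimal before/after sketch, assuming only the keyword names visible in this diff; the checkpoint path is a placeholder:

    from xtuner.v1.train.trainer import LoadCheckpointConfig

    ckpt = "work_dirs/exp/step-5"  # hypothetical checkpoint directory

    # Before this change, resume options were nested in a ResumeConfig:
    #   Trainer(..., resume_cfg=ResumeConfig(auto_resume=True))
    #   Trainer(..., resume_cfg=ResumeConfig(resume_from=ckpt))

    # After: a plain flag for auto-resume, a dedicated config for explicit loads:
    #   Trainer(..., auto_resume=True)
    #   Trainer(..., load_checkpoint_cfg=LoadCheckpointConfig(checkpoint_path=ckpt))
    load_cfg = LoadCheckpointConfig(checkpoint_path=ckpt)
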
@@ -612,3 +608,112 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
     loaded = pickle.loads(dumped)
     assert len(loaded.get_hooks(HookStage.AFTER_TRAIN_STEP)) == 0  # <local> object cannot be serialized
     assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1
+
+
+def test_resume_and_load_checkpoint_cfg(tmp_path: Path):
+    # 0. prepare environment
+    os.environ["LOCAL_RANK"] = "0"
+    os.environ["RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "29500"
+
+    alpaca_path = os.environ["ALPACA_PATH"]
+    tokenizer_path = os.environ["QWEN3_MOE_PATH"]
+
+    work_dir = tmp_path / "work_dir"
+    fake_hf_model_dir = tmp_path / "fake_hf_model"
+    fake_hf_model_dir.mkdir()
+    (fake_hf_model_dir / "config.json").write_text('{"model_type": "fake_model"}')
+    (fake_hf_model_dir / "model.safetensors").write_text("fake weights")
+
+    model_cfg = Qwen3MoE30BA3Config()
+    optim_cfg = AdamWConfig(lr=1e-4, weight_decay=0.01)
+
+    dataset_cfg = [
+        {
+            "dataset": DatasetConfig(name="alpaca", anno_path=alpaca_path, sample_ratio=1.0),
+            "tokenize_fn": FTDPTokenizeFnConfig(),
+        },
+    ]
+    dataloader_cfg = DataloaderConfig(dataset_config_list=dataset_cfg)
+    lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0.1, lr_min=1e-6)
+
+    Trainer.build_engine = Mock(return_value=FakeEngine())
+    Trainer._resume = Mock()
+
+    # 1. create: first training run with both auto_resume and load_checkpoint_cfg
+    auto_resume = True
+    load_checkpoint_cfg = LoadCheckpointConfig(
+        checkpoint_path=work_dir / "other_checkpoint",
+        load_optimizer_states=True,
+        load_optimizer_args=True,
+        load_dataset=False,
+        load_scheduler=False,
+    )
+
+    # 2. operate
+    trainer = Trainer(
+        load_from=fake_hf_model_dir,
+        model_cfg=model_cfg,
+        optim_cfg=optim_cfg,
+        dataloader_cfg=dataloader_cfg,
+        lr_cfg=lr_cfg,
+        tokenizer_path=tokenizer_path,
+        global_batch_size=2,
+        total_step=10,
+        checkpoint_interval=5,
+        work_dir=work_dir,
+        auto_resume=auto_resume,
+        load_checkpoint_cfg=load_checkpoint_cfg,
+    )
+
+    # 3. check: on the first run, auto_resume does not override load_checkpoint_cfg
+    assert trainer._load_checkpoint_cfg.load_dataset is False
+    assert trainer._load_checkpoint_cfg.load_scheduler is False
+
+    del trainer
+
+    # 4. create again: resume training with both auto_resume and load_checkpoint_cfg
+    exp_dir = work_dir / "fake_exp_dir"
+    latest_checkpoint = exp_dir / "step-5"
+    resume_meta = XTunerMeta(
+        exps=[
+            ExpInfo(
+                checkpoint_list=[str(latest_checkpoint)],
+                exp_dir=str(exp_dir),
+                history=[
+                    dict(
+                        begin=0,
+                        timestamp="20251126202933",
+                        git_info=dict(commit="dae707", staged="", unstaged=""),
+                        end=5,
+                    ),
+                ],
+            ),
+        ],
+    )
+    Trainer._init_xtuner_meta = Mock(return_value=resume_meta)
+    trainer2 = Trainer(
+        load_from=str(fake_hf_model_dir),
+        model_cfg=model_cfg,
+        optim_cfg=optim_cfg,
+        dataloader_cfg=dataloader_cfg,
+        lr_cfg=lr_cfg,
+        tokenizer_path=tokenizer_path,
+        global_batch_size=2,
+        total_step=10,
+        checkpoint_interval=5,
+        work_dir=work_dir,
+        auto_resume=auto_resume,
+        load_checkpoint_cfg=load_checkpoint_cfg,
+    )
+
+    # 5. check: when resuming, auto_resume overrides load_checkpoint_cfg
+    assert trainer2._load_checkpoint_cfg.checkpoint_path == latest_checkpoint
+    assert trainer2._load_checkpoint_cfg.load_dataset is True
+    assert trainer2._load_checkpoint_cfg.load_scheduler is True
+    assert trainer2._load_checkpoint_cfg.load_optimizer_states is True
+    assert trainer2._load_checkpoint_cfg.load_optimizer_args is True
+
+    dist.destroy_process_group()
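
The new test pins down the precedence between the two mechanisms: on a fresh run the user's load_checkpoint_cfg is honored as-is (load_dataset and load_scheduler stay False), while on a genuine resume auto_resume wins and the full state is reloaded from the latest checkpoint. Below is a sketch of merge logic consistent with these assertions; it illustrates the behavior under test and is not xtuner's actual implementation (resolve_load_checkpoint_cfg is a hypothetical helper):

    from xtuner.v1.train.trainer import LoadCheckpointConfig

    # Hypothetical helper mirroring the asserted behavior.
    def resolve_load_checkpoint_cfg(meta, auto_resume, load_checkpoint_cfg):
        checkpoints = meta.latest_exp.checkpoint_list if meta.exps else []
        if auto_resume and checkpoints:
            # Resume run: the latest checkpoint overrides the user config,
            # and every piece of state is restored (all load_* flags True).
            return LoadCheckpointConfig(
                checkpoint_path=checkpoints[-1],
                load_optimizer_states=True,
                load_optimizer_args=True,
                load_dataset=True,
                load_scheduler=True,
            )
        # First run: the user-provided config is passed through untouched.
        return load_checkpoint_cfg
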