
Commit 4c16ff2

jayhenry authored and HAOCHENYE committed
[Feature] add LoadCheckpointConfig & refactor auto_resume (#1301)
* [Feature] add LoadCheckpointConfig & refactor auto_resume
1 parent dc9f027 commit 4c16ff2
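
At a glance, the resume API changes like this. The snippet below is a hedged sketch inferred only from the test updates further down; the remaining Trainer arguments and the exact defaults of LoadCheckpointConfig are not shown in this diff, and the checkpoint path is a placeholder.

# Sketch of the API change, inferred from tests/train/test_trainer.py below.
from pathlib import Path
from xtuner.v1.train.trainer import LoadCheckpointConfig, ResumeConfig

checkpoint_path = Path("work_dir/20251126202933/step-5")  # hypothetical path

# Before this commit: both auto-resume and explicit resume went through ResumeConfig.
old_auto = dict(resume_cfg=ResumeConfig(auto_resume=True))
old_explicit = dict(resume_cfg=ResumeConfig(resume_from=checkpoint_path))

# After this commit: auto-resume is a plain Trainer flag, and explicit checkpoint
# loading is a separate, more granular LoadCheckpointConfig.
new_auto = dict(auto_resume=True)
new_explicit = dict(
    load_checkpoint_cfg=LoadCheckpointConfig(
        checkpoint_path=checkpoint_path,
        load_optimizer_states=True,   # field names taken from the test below
        load_optimizer_args=True,
        load_dataset=False,
        load_scheduler=False,
    )
)
# These kwarg dicts would be splatted into Trainer(**...) alongside the usual
# model / optimizer / dataloader arguments.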

File tree

4 files changed: +272 -42 lines changed

tests/scripts/save_hf_test.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+import argparse
+import time
+import torch
+import torch.distributed as dist
+from xtuner.v1.model import get_model_config_from_hf
+from xtuner.v1.config import FSDPConfig
+
+# fallback: allow the script to run normally when memory_profiler is not installed
+try:
+    from memory_profiler import profile
+except Exception:
+    def profile(fn):  # no-op
+        return fn
+
+MB = 1024 ** 2
+
+def get_args():
+    p = argparse.ArgumentParser("Profile build/shard/save with @profile (RSS) and simple GPU stats")
+    p.add_argument("hf_path", type=str, help="HF model path")
+    p.add_argument("out", type=str, help="Output HF path")
+    p.add_argument("--ep", type=int, default=1, help="expert parallel size")
+    return p.parse_args()
+
+def set_device_for_rank():
+    if torch.cuda.is_available():
+        rank = dist.get_rank() if dist.is_initialized() else 0
+        torch.cuda.set_device(rank % torch.cuda.device_count())
+
+def gpu_mem(label):
+    if not torch.cuda.is_available():
+        print(f"[GPU] {label}: no CUDA")
+        return
+    torch.cuda.synchronize()
+    alloc = torch.cuda.memory_allocated() / MB
+    reserved = torch.cuda.memory_reserved() / MB
+    peak = torch.cuda.max_memory_allocated() / MB
+    print(f"[GPU] {label}: alloc={alloc:.2f}MB reserved={reserved:.2f}MB peak={peak:.2f}MB")
+
+def build_model(hf_path: str):
+    cfg = get_model_config_from_hf(hf_path)
+    model = cfg.build()
+    return model
+
+def shard_model(model, ep: int):
+    fsdp_cfg = FSDPConfig(ep_size=ep)
+    model.fully_shard(fsdp_config=fsdp_cfg)
+    return model
+
+@profile
+def save_model(model, out: str):
+    model.save_hf(out)
+
+def main():
+    args = get_args()
+
+    dist.init_process_group(backend="nccl")
+    set_device_for_rank()
+
+    t0 = time.perf_counter()
+    gpu_mem("init")
+
+    torch.cuda.reset_peak_memory_stats()
+    model = build_model(args.hf_path)
+    gpu_mem("after_build")
+
+    torch.cuda.reset_peak_memory_stats()
+    shard_model(model, args.ep)
+    gpu_mem("after_shard")
+
+    torch.cuda.reset_peak_memory_stats()
+    save_model(model, args.out)
+    gpu_mem("after_save")
+
+    print(f"[TIME] total={time.perf_counter()-t0:.3f}s")
+
+    dist.destroy_process_group()
+
+if __name__ == "__main__":
+    main()
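
The script above assumes the usual torch.distributed environment is already set up (typically by a launcher). Below is a hedged single-process launch sketch: the paths are placeholders and the environment variables mirror those set in tests/train/test_trainer.py further down; it is not part of this commit.

# Hypothetical single-GPU launch of the profiling script above, without torchrun.
import os
import subprocess

env = dict(
    os.environ,
    RANK="0",
    LOCAL_RANK="0",
    WORLD_SIZE="1",
    MASTER_ADDR="localhost",
    MASTER_PORT="29500",
)
subprocess.run(
    [
        "python", "tests/scripts/save_hf_test.py",
        "/path/to/hf_model",   # positional: HF model path
        "/path/to/output_hf",  # positional: output HF path
        "--ep", "1",           # expert parallel size
    ],
    env=env,
    check=True,
)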

tests/train/test_trainer.py

Lines changed: 115 additions & 10 deletions
@@ -21,15 +21,15 @@
     InternVL3P5Dense8BConfig,
     InternVL3P5MoE30BA3Config,
 )
-from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage
+from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage, LoadCheckpointConfig
 from xtuner.v1.datasets import FTDPTokenizeFnConfig
 from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
 from xtuner.v1.train.trainer import TrainerConfig
 from xtuner.v1.engine.train_engine import LossLog, OtherLog
 from xtuner.v1.loss import CELossConfig
 from xtuner._testing import DeterministicDDPTestCase
 from unittest import TestCase
-
+from xtuner.v1.train.trainer import XTunerMeta, ExpInfo, ExpHistory, GitInfo
 from xtuner.v1.utils.device import get_device
 
 
@@ -319,9 +319,7 @@ def test_resume(self):
             debug=False,
             checkpoint_interval=2,
             checkpoint_maxkeep=2,
-            resume_cfg=ResumeConfig(
-                auto_resume=True,
-            ),
+            auto_resume=True,
         )
         assert resume_trainer1.cur_step == 6
         assert resume_trainer1.exp_dir == trainer.exp_dir
@@ -347,9 +345,7 @@ def test_resume(self):
             debug=False,
             checkpoint_interval=2,
             checkpoint_maxkeep=2,
-            resume_cfg=ResumeConfig(
-                auto_resume=True,
-            ),
+            auto_resume=True,
         )
         assert resume_trainer1_2.cur_step == 10
         assert resume_trainer1_2.exp_dir == trainer.exp_dir
@@ -376,8 +372,8 @@ def test_resume(self):
             debug=False,
             checkpoint_interval=5,
             checkpoint_maxkeep=2,
-            resume_cfg=ResumeConfig(
-                resume_from=resume_trainer1_2.meta.latest_exp.checkpoint_list[-2],
+            load_checkpoint_cfg=LoadCheckpointConfig(
+                checkpoint_path=resume_trainer1_2.meta.latest_exp.checkpoint_list[-2],
             ),
         )
         assert resume_trainer2.cur_step == 14
@@ -612,3 +608,112 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
     loaded = pickle.loads(dumped)
     assert len(loaded.get_hooks(HookStage.AFTER_TRAIN_STEP)) == 0  # <local> object cannot be serialized
     assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1
+
+
+def test_resume_and_load_checkpoint_cfg(tmp_path: Path):
+    # 0. prepare environment
+    os.environ["LOCAL_RANK"] = "0"
+    os.environ["RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "29500"
+
+    alpaca_path = os.environ["ALPACA_PATH"]
+    tokenizer_path = os.environ["QWEN3_MOE_PATH"]
+
+    work_dir = tmp_path / "work_dir"
+    fake_hf_model_dir = tmp_path / "fake_hf_model"
+    fake_hf_model_dir.mkdir()
+    (fake_hf_model_dir / "config.json").write_text('{"model_type": "fake_model"}')
+    (fake_hf_model_dir / "model.safetensors").write_text("fake weights")
+
+    model_cfg = Qwen3MoE30BA3Config()
+    optim_cfg = AdamWConfig(lr=1e-4, weight_decay=0.01)
+
+    dataset_cfg = [
+        {
+            "dataset": DatasetConfig(name="alpaca", anno_path=alpaca_path, sample_ratio=1.0),
+            "tokenize_fn": FTDPTokenizeFnConfig(),
+        },
+    ]
+    dataloader_cfg = DataloaderConfig(dataset_config_list=dataset_cfg)
+    lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0.1, lr_min=1e-6)
+
+    Trainer.build_engine = Mock(return_value=FakeEngine())
+    Trainer._resume = Mock()
+
+    # 1. create: first training run with auto_resume and load_checkpoint_cfg
+    auto_resume = True
+    load_checkpoint_cfg = LoadCheckpointConfig(
+        checkpoint_path=work_dir / "other_checkpoint",
+        load_optimizer_states=True,
+        load_optimizer_args=True,
+        load_dataset=False,
+        load_scheduler=False,
+    )
+
+    # 2. operate
+    trainer = Trainer(
+        load_from=fake_hf_model_dir,
+        model_cfg=model_cfg,
+        optim_cfg=optim_cfg,
+        dataloader_cfg=dataloader_cfg,
+        lr_cfg=lr_cfg,
+        tokenizer_path=tokenizer_path,
+        global_batch_size=2,
+        total_step=10,
+        checkpoint_interval=5,
+        work_dir=work_dir,
+        auto_resume=auto_resume,
+        load_checkpoint_cfg=load_checkpoint_cfg,
+    )
+
+    # 3. check: auto_resume does not overwrite load_checkpoint_cfg the first time
+    assert trainer._load_checkpoint_cfg.load_dataset is False
+    assert trainer._load_checkpoint_cfg.load_scheduler is False
+
+    del trainer
+
+    # 4. 2nd create: resume training with auto_resume and load_checkpoint_cfg
+    exp_dir = work_dir / "fake_exp_dir"
+    latest_checkpoint = exp_dir / "step-5"
+    resume_meta = XTunerMeta(
+        exps=[
+            ExpInfo(
+                checkpoint_list=[str(latest_checkpoint)],
+                exp_dir=str(exp_dir),
+                history=[
+                    dict(
+                        begin=0,
+                        timestamp="20251126202933",
+                        git_info=dict(commit="dae707", staged="", unstaged=""),
+                        end=5,
+                    ),
+                ],
+            ),
+        ],
+    )
+    Trainer._init_xtuner_meta = Mock(return_value=resume_meta)
+    trainer2 = Trainer(
+        load_from=str(fake_hf_model_dir),
+        model_cfg=model_cfg,
+        optim_cfg=optim_cfg,
+        dataloader_cfg=dataloader_cfg,
+        lr_cfg=lr_cfg,
+        tokenizer_path=tokenizer_path,
+        global_batch_size=2,
+        total_step=10,
+        checkpoint_interval=5,
+        work_dir=work_dir,
+        auto_resume=auto_resume,
+        load_checkpoint_cfg=load_checkpoint_cfg,
+    )
+
+    # 5. check: auto_resume overrides load_checkpoint_cfg when resuming training
+    assert trainer2._load_checkpoint_cfg.checkpoint_path == latest_checkpoint
+    assert trainer2._load_checkpoint_cfg.load_dataset is True
+    assert trainer2._load_checkpoint_cfg.load_scheduler is True
+    assert trainer2._load_checkpoint_cfg.load_optimizer_states is True
+    assert trainer2._load_checkpoint_cfg.load_optimizer_args is True
+
+    dist.destroy_process_group()
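
The precedence the new test asserts can be summarized in a short sketch. This is not the actual Trainer implementation (which lives in xtuner/v1/train/trainer.py and is not part of the hunks shown here); it only mirrors the assertions above: on a fresh run the user-supplied LoadCheckpointConfig is respected, while a successful auto-resume re-points it at the latest checkpoint and reloads all states.

# Minimal sketch of the behaviour asserted by test_resume_and_load_checkpoint_cfg;
# the real resolution logic may differ in detail.
from pathlib import Path
from xtuner.v1.train.trainer import LoadCheckpointConfig


def resolve_load_checkpoint_cfg(
    auto_resume: bool,
    user_cfg: LoadCheckpointConfig,
    latest_checkpoint: Path | None,
) -> LoadCheckpointConfig:
    # First run (no previous checkpoint found): keep the user's config as-is,
    # including partially disabled loads such as load_dataset=False.
    if not auto_resume or latest_checkpoint is None:
        return user_cfg
    # Auto-resume found a checkpoint: it takes precedence and everything is reloaded.
    return LoadCheckpointConfig(
        checkpoint_path=latest_checkpoint,
        load_optimizer_states=True,
        load_optimizer_args=True,
        load_dataset=True,
        load_scheduler=True,
    )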

xtuner/v1/train/rl_trainer.py

Lines changed: 17 additions & 3 deletions
@@ -32,7 +32,7 @@
 from xtuner.v1.utils.device import get_device, get_torch_device_module
 from xtuner.v1.utils.env_check import get_rollout_engine_version
 
-from .trainer import ExpHistory, ExpInfo, GitInfo, XTunerMeta
+from .trainer import ExpHistory, ExpInfo, GitInfo, LoadCheckpointConfig, XTunerMeta
 
 
 # TODO: Move DEVICE to `xtuner.utils.device`
@@ -65,6 +65,8 @@ class RLTrainerConfig(BaseModel):
     log_dir: Path | str | None = None
     total_epochs: int
     resume_config: ResumeConfig | None = None
+    auto_resume: bool = False
+    load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig()
     strict_load: bool = True
     hf_interval: int | None = None
     hf_max_keep: int | None = None
@@ -157,6 +159,8 @@ class RLTrainer:
         enable_evaluate (bool): Whether to perform periodic evaluation during training.
         resume_config (ResumeConfig | None): Configuration for resuming training from
             a previous checkpoint. Defaults to None.
+        auto_resume (bool): Whether to automatically resume training. Defaults to False.
+        load_checkpoint_cfg (LoadCheckpointConfig): Configuration for loading checkpoints.
         strict_load (bool): Whether to strictly enforce checkpoint loading compatibility.
             Defaults to True.
         hf_interval (int | None): Interval (in epochs) for saving HuggingFace format
@@ -203,7 +207,9 @@ def __init__(
         work_dir: Path | str | None = None,
         log_dir: Path | str | None = None,
         total_epochs: int,
-        resume_config: ResumeConfig | None = None,
+        resume_config: ResumeConfig | None = None,  # TODO: remove in version 1.1.0
+        auto_resume: bool = False,
+        load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig(),
         strict_load: bool = True,
         hf_interval: int | None = None,
         hf_max_keep: int | None = None,
@@ -257,14 +263,20 @@ def __init__(
         work_dir.mkdir(parents=True, exist_ok=True)
 
         self._work_dir = work_dir
-        self._meta = self._init_xtuner_meta(work_dir, resume_config is not None)
+        auto_resume = auto_resume or (resume_config is not None and resume_config.auto_resume)
+        self._meta = self._init_xtuner_meta(work_dir, auto_resume)
 
         if log_dir is None:
             log_dir = self.exp_dir
         if isinstance(log_dir, str):
             log_dir = Path(log_dir)
 
         self.logger = self._init_logger(log_dir)
+
+        self.logger.warning(
+            "`resume_config` is deprecated, please use `auto_resume` and `load_checkpoint_cfg` instead"
+        )
+
         train_worker_cfg.log_dir = log_dir
         dataflow_config.worker_log_dir = log_dir
         rollout_config.worker_log_dir = log_dir
@@ -338,6 +350,8 @@ def from_config(cls, config: RLTrainerConfig) -> Self:
             log_dir=config.log_dir,
             total_epochs=config.total_epochs,
             resume_config=config.resume_config,
+            auto_resume=config.auto_resume,
+            load_checkpoint_cfg=config.load_checkpoint_cfg,
             strict_load=config.strict_load,
             hf_interval=config.hf_interval,
             hf_max_keep=config.hf_max_keep,
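
For RLTrainer the change is additive and backward compatible: resume_config is still accepted (with a deprecation warning) and folded into auto_resume, as the hunks above show. A hedged sketch of the equivalent new-style arguments; the remaining required arguments (worker configs, paths, total_epochs, and so on) are omitted and would be passed as usual.

# Hypothetical before/after for the RLTrainer resume arguments.
from xtuner.v1.train.trainer import LoadCheckpointConfig

# Old style (deprecated, still works): resume_config=ResumeConfig(auto_resume=True)
# New style:
resume_kwargs = dict(
    auto_resume=True,
    load_checkpoint_cfg=LoadCheckpointConfig(),
)
# RLTrainer(..., **resume_kwargs) or RLTrainerConfig(..., **resume_kwargs)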
