125 changes: 115 additions & 10 deletions tests/train/test_trainer.py
@@ -21,15 +21,15 @@
InternVL3P5Dense8BConfig,
InternVL3P5MoE30BA3Config,
)
-from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage
+from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage, LoadCheckpointConfig
from xtuner.v1.datasets import FTDPTokenizeFnConfig
from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
from xtuner.v1.train.trainer import TrainerConfig
from xtuner.v1.engine.train_engine import LossLog, OtherLog
from xtuner.v1.loss import CELossConfig
from xtuner._testing import DeterministicDDPTestCase
from unittest import TestCase

from xtuner.v1.train.trainer import XTunerMeta, ExpInfo, ExpHistory, GitInfo
from xtuner.v1.utils.device import get_device


@@ -319,9 +319,7 @@ def test_resume(self):
debug=False,
checkpoint_interval=2,
checkpoint_maxkeep=2,
-resume_cfg=ResumeConfig(
-    auto_resume=True,
-),
+auto_resume=True,
)
assert resume_trainer1.cur_step == 6
assert resume_trainer1.exp_dir == trainer.exp_dir
@@ -347,9 +345,7 @@ def test_resume(self):
debug=False,
checkpoint_interval=2,
checkpoint_maxkeep=2,
-resume_cfg=ResumeConfig(
-    auto_resume=True,
-),
+auto_resume=True,
)
assert resume_trainer1_2.cur_step == 10
assert resume_trainer1_2.exp_dir == trainer.exp_dir
@@ -376,8 +372,8 @@ def test_resume(self):
debug=False,
checkpoint_interval=5,
checkpoint_maxkeep=2,
-resume_cfg=ResumeConfig(
-    resume_from=resume_trainer1_2.meta.latest_exp.checkpoint_list[-2],
+load_checkpoint_cfg=LoadCheckpointConfig(
+    checkpoint_path=resume_trainer1_2.meta.latest_exp.checkpoint_list[-2],
),
)
assert resume_trainer2.cur_step == 14
@@ -612,3 +608,112 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
loaded = pickle.loads(dumped)
assert len(loaded.get_hooks(HookStage.AFTER_TRAIN_STEP)) == 0 # <local> object cannot be serialized
assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1


def test_resume_and_load_checkpoint_cfg(tmp_path: Path):
# 0. prepare environment
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

alpaca_path = os.environ["ALPACA_PATH"]
tokenizer_path = os.environ["QWEN3_MOE_PATH"]

work_dir = tmp_path / "work_dir"
fake_hf_model_dir = tmp_path / "fake_hf_model"
fake_hf_model_dir.mkdir()
(fake_hf_model_dir / "config.json").write_text('{"model_type": "fake_model"}')
(fake_hf_model_dir / "model.safetensors").write_text("fake weights")

model_cfg = Qwen3MoE30BA3Config()
optim_cfg = AdamWConfig(lr=1e-4, weight_decay=0.01)

dataset_cfg = [
{
"dataset": DatasetConfig(name="alpaca", anno_path=alpaca_path, sample_ratio=1.0),
"tokenize_fn": FTDPTokenizeFnConfig(),
},
]
dataloader_cfg = DataloaderConfig(dataset_config_list=dataset_cfg)
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0.1, lr_min=1e-6)

Trainer.build_engine = Mock(return_value=FakeEngine())
Trainer._resume = Mock()

# 1. create: first training run with auto_resume and load_checkpoint_cfg
auto_resume = True
load_checkpoint_cfg = LoadCheckpointConfig(
checkpoint_path=work_dir / "other_checkpoint",
load_optimizer_states=True,
load_optimizer_args=True,
load_dataset=False,
load_scheduler=False,
)

# 2. operate
trainer = Trainer(
load_from=fake_hf_model_dir,
model_cfg=model_cfg,
optim_cfg=optim_cfg,
dataloader_cfg=dataloader_cfg,
lr_cfg=lr_cfg,
tokenizer_path=tokenizer_path,
global_batch_size=2,
total_step=10,
checkpoint_interval=5,
work_dir=work_dir,
auto_resume=auto_resume,
load_checkpoint_cfg=load_checkpoint_cfg,
)

# 3. check: auto_resume does not override load_checkpoint_cfg on the first run
assert trainer._load_checkpoint_cfg.load_dataset is False
assert trainer._load_checkpoint_cfg.load_scheduler is False

del trainer

# 4. second create: resume training with auto_resume and load_checkpoint_cfg
exp_dir = work_dir / "fake_exp_dir"
latest_checkpoint = exp_dir / "step-5"
resume_meta = XTunerMeta(
exps=[
ExpInfo(
checkpoint_list=[str(latest_checkpoint)],
exp_dir=str(exp_dir),
history=[
dict(
begin=0,
timestamp="20251126202933",
git_info=dict(commit="dae707", staged="", unstaged=""),
end=5,
),
],
),
],
)
Trainer._init_xtuner_meta = Mock(return_value=resume_meta)
trainer2 = Trainer(
load_from=str(fake_hf_model_dir),
model_cfg=model_cfg,
optim_cfg=optim_cfg,
dataloader_cfg=dataloader_cfg,
lr_cfg=lr_cfg,
tokenizer_path=tokenizer_path,
global_batch_size=2,
total_step=10,
checkpoint_interval=5,
work_dir=work_dir,
auto_resume=auto_resume,
load_checkpoint_cfg=load_checkpoint_cfg,
)

# 5. check: auto_resume overrides load_checkpoint_cfg when resuming training
assert trainer2._load_checkpoint_cfg.checkpoint_path == latest_checkpoint
assert trainer2._load_checkpoint_cfg.load_dataset is True
assert trainer2._load_checkpoint_cfg.load_scheduler is True
assert trainer2._load_checkpoint_cfg.load_optimizer_states is True
assert trainer2._load_checkpoint_cfg.load_optimizer_args is True

dist.destroy_process_group()
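
The new test pins down the precedence between `auto_resume` and a user-supplied `load_checkpoint_cfg`: the user's config wins on a fresh run, and auto-resume takes over once a checkpoint exists. A minimal sketch of that rule, assuming a pydantic `LoadCheckpointConfig` whose field names come from the assertions above (the base class, the defaults, and the `resolve_load_cfg` helper are illustrative, not xtuner's actual implementation):

from pathlib import Path

from pydantic import BaseModel


class LoadCheckpointConfig(BaseModel):
    checkpoint_path: Path | str | None = None
    load_optimizer_states: bool = True
    load_optimizer_args: bool = True
    load_dataset: bool = True
    load_scheduler: bool = True


def resolve_load_cfg(
    user_cfg: LoadCheckpointConfig,
    auto_resume: bool,
    latest_checkpoint: str | None,
) -> LoadCheckpointConfig:
    # Fresh run: the user's config is used as-is (checks in step 3 above).
    # Resumed run: auto_resume forces a full-state load from the latest
    # checkpoint (checks in step 5 above).
    if auto_resume and latest_checkpoint is not None:
        return LoadCheckpointConfig(checkpoint_path=latest_checkpoint)
    return user_cfg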
20 changes: 17 additions & 3 deletions xtuner/v1/train/rl_trainer.py
@@ -32,7 +32,7 @@
from xtuner.v1.utils.device import get_device, get_torch_device_module
from xtuner.v1.utils.env_check import get_rollout_engine_version

-from .trainer import ExpHistory, ExpInfo, GitInfo, XTunerMeta
+from .trainer import ExpHistory, ExpInfo, GitInfo, LoadCheckpointConfig, XTunerMeta


# TODO: Move DEVICE to `xtuner.utils.device`
@@ -65,6 +65,8 @@ class RLTrainerConfig(BaseModel):
log_dir: Path | str | None = None
total_epochs: int
resume_config: ResumeConfig | None = None
+auto_resume: bool = False
+load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig()
strict_load: bool = True
hf_interval: int | None = None
hf_max_keep: int | None = None
@@ -157,6 +159,8 @@ class RLTrainer:
enable_evaluate (bool): Whether to perform periodic evaluation during training.
resume_config (ResumeConfig | None): Configuration for resuming training from
a previous checkpoint. Defaults to None.
+auto_resume (bool): Whether to automatically resume training from the latest checkpoint. Defaults to False.
+load_checkpoint_cfg (LoadCheckpointConfig): Configuration for which checkpoint states to load. Defaults to LoadCheckpointConfig().
strict_load (bool): Whether to strictly enforce checkpoint loading compatibility.
Defaults to True.
hf_interval (int | None): Interval (in epochs) for saving HuggingFace format
@@ -203,7 +207,9 @@ def __init__(
work_dir: Path | str | None = None,
log_dir: Path | str | None = None,
total_epochs: int,
-resume_config: ResumeConfig | None = None,
+resume_config: ResumeConfig | None = None,  # TODO: remove in version 1.1.0
+auto_resume: bool = False,
+load_checkpoint_cfg: LoadCheckpointConfig = LoadCheckpointConfig(),
strict_load: bool = True,
hf_interval: int | None = None,
hf_max_keep: int | None = None,
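
A design note on the defaults in this signature: `LoadCheckpointConfig = LoadCheckpointConfig()` behaves differently as a pydantic field on `RLTrainerConfig` than as a plain function default here. Pydantic deep-copies field defaults per instance, while Python evaluates a function default once and shares it across calls. A self-contained demonstration (class names are illustrative):

from pydantic import BaseModel


class Cfg(BaseModel):
    path: str | None = None


class Holder(BaseModel):
    cfg: Cfg = Cfg()  # pydantic deep-copies this default per instance


a, b = Holder(), Holder()
assert a.cfg is not b.cfg


def init(cfg: Cfg = Cfg()):  # evaluated once at def time, shared thereafter
    return cfg


assert init() is init()

This is harmless as long as `__init__` treats the shared default as read-only; mutating it would leak state across trainer instances.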
@@ -257,14 +263,20 @@ def __init__(
work_dir.mkdir(parents=True, exist_ok=True)

self._work_dir = work_dir
-self._meta = self._init_xtuner_meta(work_dir, resume_config is not None)
+auto_resume = auto_resume or (resume_config is not None and resume_config.auto_resume)
+self._meta = self._init_xtuner_meta(work_dir, auto_resume)

if log_dir is None:
log_dir = self.exp_dir
if isinstance(log_dir, str):
log_dir = Path(log_dir)

self.logger = self._init_logger(log_dir)

+if resume_config is not None:
+    self.logger.warning(
+        "`resume_config` is deprecated, please use `auto_resume` and `load_checkpoint_cfg` instead"
+    )

train_worker_cfg.log_dir = log_dir
dataflow_config.worker_log_dir = log_dir
rollout_config.worker_log_dir = log_dir
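
The fold-in above only carries `resume_config.auto_resume` forward. The old `ResumeConfig` also exposed a `resume_from` path (see the test changes in this PR), so a fuller back-compat shim might map that onto the new config as well. A hedged sketch, under the assumption that `resume_from` should become `checkpoint_path`; not necessarily what xtuner does:

def fold_resume_config(resume_config, auto_resume, load_checkpoint_cfg):
    # Deprecated path: honor both the auto_resume flag and, if present,
    # an explicit resume_from location from the old-style config.
    if resume_config is not None:
        auto_resume = auto_resume or resume_config.auto_resume
        resume_from = getattr(resume_config, "resume_from", None)
        if resume_from is not None:
            load_checkpoint_cfg = LoadCheckpointConfig(checkpoint_path=resume_from)
    return auto_resume, load_checkpoint_cfg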
@@ -338,6 +350,8 @@ def from_config(cls, config: RLTrainerConfig) -> Self:
log_dir=config.log_dir,
total_epochs=config.total_epochs,
resume_config=config.resume_config,
+auto_resume=config.auto_resume,
+load_checkpoint_cfg=config.load_checkpoint_cfg,
strict_load=config.strict_load,
hf_interval=config.hf_interval,
hf_max_keep=config.hf_max_keep,