Skip to content

Commit 803c5ef

Browse files
committed
[Fix] change bucket_size_in_gb from rollout_cfg to train worker cfg
1 parent db05eb6 commit 803c5ef

File tree

2 files changed

+2
-6
lines changed

2 files changed

+2
-6
lines changed

xtuner/v1/ray/config/worker.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,6 @@ class RolloutConfig(BaseModel):
170170
help="Whether to enable returning routed experts for the rollout worker.",
171171
),
172172
] = False
173-
update_weight_bucket_size_in_gb: Annotated[
174-
float,
175-
Parameter(group=infer_group, help="Bucket size in GB for updating weight."),
176-
] = 0.5 # 512MB
177173
launch_server_method: Annotated[
178174
Literal["ray", "multiprocessing"],
179175
Parameter(

xtuner/v1/rl/base/worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class WorkerConfig(BaseModel):
129129
ref_load_from: str | Path | None = None
130130
ref_model_fsdp_cfg: FSDPConfig | None = None
131131
log_dir: str | Path | None = None
132+
update_weight_bucket_size_in_gb: float = 0.5 # 512MB
132133

133134

134135
class WorkerInputItem(TypedDict):
@@ -578,7 +579,6 @@ def update_rollout_info(
578579
self.rollout_cfg_info["tp"] = tp
579580
self.rollout_cfg_info["ep"] = ep
580581
self.rollout_cfg_info["api_key"] = rollout_config.api_key
581-
self.rollout_cfg_info["update_weight_bucket_size_in_gb"] = rollout_config.update_weight_bucket_size_in_gb
582582
if os.environ.get("XTUNER_USE_SGLANG", "0") == "1":
583583
self.rollout_cfg_info["backend"] = "sglang"
584584
else:
@@ -609,7 +609,7 @@ def _update_weights_hf_generator(self):
609609
else:
610610
dtype = torch.bfloat16
611611

612-
bucket_size = int(self.rollout_cfg_info["update_weight_bucket_size_in_gb"] * 1024**3)
612+
bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3)
613613
same_gen = model._get_same_hf_param(
614614
model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size
615615
)

0 commit comments

Comments
 (0)