
Commit 11fb109

CyCle1024 authored and hhaAndroid committed
[OPT] Reduce update_weight peak memory in RL
1 parent dc9f027 commit 11fb109

5 files changed (+150 −56)

tests/ray/test_update_weight.py

Lines changed: 3 additions & 2 deletions
@@ -72,7 +72,8 @@ def init_config(self):
         if hasattr(model_cfg, 'balancing_loss_cfg'):
             model_cfg.balancing_loss_cfg = BalancingLossConfig()
         optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False)
-        fsdp_cfg: FSDPConfig = FSDPConfig()
+        fsdp_cfg: FSDPConfig = FSDPConfig(ep_size=4)
+        model_cfg.ep_size = fsdp_cfg.ep_size
         lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7)
         self.worker_cfg: WorkerConfig = WorkerConfig(
             model_cfg=model_cfg,
@@ -84,7 +85,7 @@ def init_config(self):
                 loss_type="vanilla",
             ),
             ignore_idx=-100,
-            use_kl_loss=True,
+            use_kl_loss=False,
             kl_loss_coef=0.001,
             kl_loss_type="low_var_kl",
             mode="eager"),

xtuner/v1/model/base.py

Lines changed: 94 additions & 49 deletions
@@ -481,7 +481,7 @@ def _get_fused_hf_param(
         dtype: torch.dtype,
         device="cpu",
         bucket_size=None,
-        return_full_key_per_rank: bool = False,
+        update_weights_for_rl: bool = False,
     ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
         if not params:
             return
@@ -506,63 +506,58 @@ def _get_hf_params(
             for load_spec, fsdp_unshared_tensor in zip(spec_list, fsdp_unshard_tensor_list):
                 hf_keys = load_spec.hf_keys
 
-                if load_spec.group is not None:
-                    all_hf_keys_list: list[None] | list[list[str]] = [None for _ in range(load_spec.group.size())]
-                    dist.all_gather_object(all_hf_keys_list, hf_keys, group=load_spec.group)
-                    all_hf_keys_list = cast(list[list[str]], all_hf_keys_list)
-                    all_hf_keys = list(chain(*all_hf_keys_list))
+                if update_weights_for_rl:
+                    hf_keys_list.append(hf_keys)
+                    saved_fused_tensor_list.append(fsdp_unshared_tensor)
                 else:
-                    all_hf_keys = hf_keys
-
-                current_rank = dist.get_rank()
-                fused_save_ranks = self._get_ranks_to_save_fused_tensor(len(all_hf_keys))
-                key_per_rank = len(all_hf_keys) / len(fused_save_ranks)
-                assert key_per_rank.is_integer(), (
-                    f"XTuner Internal Error, size of all_hf_keys: {len(all_hf_keys)}, "
-                    f"size of `fused_save_ranks` {len(fused_save_ranks)}"
-                )
+                    if load_spec.group is not None:
+                        all_hf_keys_list: list[None] | list[list[str]] = [None for _ in range(load_spec.group.size())]
+                        dist.all_gather_object(all_hf_keys_list, hf_keys, group=load_spec.group)
+                        all_hf_keys_list = cast(list[list[str]], all_hf_keys_list)
+                        all_hf_keys = list(chain(*all_hf_keys_list))
+                    else:
+                        all_hf_keys = hf_keys
+
+                    current_rank = dist.get_rank()
+                    fused_save_ranks = self._get_ranks_to_save_fused_tensor(len(all_hf_keys))
+                    key_per_rank = len(all_hf_keys) / len(fused_save_ranks)
+                    assert key_per_rank.is_integer(), (
+                        f"XTuner Internal Error, size of all_hf_keys: {len(all_hf_keys)}, "
+                        f"size of `fused_save_ranks` {len(fused_save_ranks)}"
+                    )
 
-                # 1. When return_full_key_per_rank is False, we intends to save hf models across ranks,
-                # each rank only saves part of hf keys and tensors
-                # 2. When return_full_key_per_rank is True, we intends to generate full tensors on each
-                # rank for ipc updating weights in RL training.
-                if not return_full_key_per_rank:
                     start = int(current_rank * key_per_rank)
                     end = int(start + key_per_rank)
-                else:
-                    start = 0
-                    end = len(all_hf_keys)
 
-                _hf_key_list = all_hf_keys[start:end]
+                    _hf_key_list = all_hf_keys[start:end]
 
-                if not _hf_key_list:
-                    continue
+                    if not _hf_key_list:
+                        continue
 
-                hf_keys_list.append(_hf_key_list)
+                    hf_keys_list.append(_hf_key_list)
 
-                assert load_spec.dim is not None
-                if load_spec.group is not None:
                     assert load_spec.dim is not None
-                    _gathered_tensor_list = [
-                        torch.zeros_like(fsdp_unshared_tensor) for _ in range(load_spec.group.size())
-                    ]
-                    dist.all_gather(_gathered_tensor_list, fsdp_unshared_tensor, group=load_spec.group)
-                    _gathered_tensor = torch.cat(_gathered_tensor_list, dim=load_spec.dim)
-                else:
-                    _gathered_tensor = fsdp_unshared_tensor
-
-                hf_tensor_size = _gathered_tensor.shape[load_spec.dim] / len(all_hf_keys)
-                _saved_fused_tensor = torch.index_select(
-                    _gathered_tensor,
-                    dim=load_spec.dim,
-                    index=torch.arange(
-                        int(start * hf_tensor_size),
-                        int(end * hf_tensor_size),
-                        dtype=torch.int64,
-                        device=_gathered_tensor.device,
-                    ),
-                )
-                saved_fused_tensor_list.append(_saved_fused_tensor)
+                    if load_spec.group is not None:
+                        assert load_spec.dim is not None
+                        _gathered_tensor_list = [
+                            torch.zeros_like(fsdp_unshared_tensor) for _ in range(load_spec.group.size())
+                        ]
+                        dist.all_gather(_gathered_tensor_list, fsdp_unshared_tensor, group=load_spec.group)
+                        _gathered_tensor = torch.cat(_gathered_tensor_list, dim=load_spec.dim)
+                    else:
+                        _gathered_tensor = fsdp_unshared_tensor
+                    hf_tensor_size = _gathered_tensor.shape[load_spec.dim] / len(all_hf_keys)
+                    _saved_fused_tensor = torch.index_select(
+                        _gathered_tensor,
+                        dim=load_spec.dim,
+                        index=torch.arange(
+                            int(start * hf_tensor_size),
+                            int(end * hf_tensor_size),
+                            dtype=torch.int64,
+                            device=_gathered_tensor.device,
+                        ),
+                    )
+                    saved_fused_tensor_list.append(_saved_fused_tensor)
 
                 # Split the fused tensor into hf tensors
                 hf_tensor_list: list[torch.Tensor] = []
@@ -1141,6 +1136,14 @@ def _fsdp_foreach_allgather(
 
         # Concatenate the tensors along the FSDP shard dim
        for tensors, size in zip(_fsdp_unsharded_tensor_list, origin_fsdp_size):
+            # Fast path: the chunks are already contiguous slices of one storage
+            fused_tensor = self.fuse_contiguous_chunks_without_alloc(tensors)
+            if fused_tensor is not None and fused_tensor.shape[self.FSDP_SHARD_DIM] == size:
+                fsdp_unsharded_tensor_list.append(fused_tensor)
+                continue
+            elif fused_tensor is not None:
+                # free memory ASAP
+                del fused_tensor
             tensor = torch.cat(tensors, dim=self.FSDP_SHARD_DIM)
             cat_tensor = torch.index_select(
                 tensor,
@@ -1157,6 +1160,48 @@ def _fsdp_foreach_allgather(
 
         return fsdp_unsharded_tensor_list
 
+    @staticmethod
+    def fuse_contiguous_chunks_without_alloc(tensors: list[torch.Tensor]) -> torch.Tensor | None:
+        """Fuse contiguous chunks without extra memory allocation.
+
+        Return None if not possible.
+        """
+        if not tensors:
+            return None
+        base = tensors[0]
+        storage = base.untyped_storage()
+        dtype = base.dtype
+        device = base.device
+        stride = base.stride()
+
+        inner_stride = stride[1:]
+        inner_elems = math.prod(base.shape[1:]) if base.dim() > 1 else 1
+
+        chunks = []
+        for t in tensors:
+            if (
+                t.untyped_storage().data_ptr() != storage.data_ptr()
+                or t.dtype != dtype
+                or t.device != device
+                or t.stride()[1:] != inner_stride
+            ):
+                return None
+            chunks.append((t.storage_offset(), t.shape[0], t))
+        chunks.sort(key=lambda x: x[0])
+
+        expected_offset = chunks[0][0]
+        total_rows = 0
+        for offset, rows, _ in chunks:
+            if offset != expected_offset:
+                return None
+            expected_offset += rows * inner_elems
+            total_rows += rows
+
+        size = (total_rows, *base.shape[1:])
+        flat = torch.empty(0, dtype=dtype, device=device)
+        flat.set_(storage, chunks[0][0], size, stride)
+        return flat
+
     def _maybe_compile_layers(self):
         if self.fsdp_config is not None:
             if self.fsdp_config.torch_compile:
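
Note: the following is a minimal, standalone sketch of the zero-copy trick behind fuse_contiguous_chunks_without_alloc, not the committed helper itself. When the chunks are back-to-back views of one untyped storage, Tensor.set_() can expose them as a single fused tensor without the extra allocation torch.cat would make. The names here are illustrative.

import torch

# Two row-chunks that are adjacent views of the same storage.
base = torch.arange(24, dtype=torch.float32).reshape(6, 4)
chunks = [base[:2], base[2:6]]

storage = base.untyped_storage()
rows = sum(c.shape[0] for c in chunks)

# Build a view over the shared storage instead of concatenating.
fused = torch.empty(0, dtype=base.dtype, device=base.device)
fused.set_(storage, chunks[0].storage_offset(), (rows, base.shape[1]), base.stride())

assert fused.data_ptr() == base.data_ptr()  # no new allocation
assert torch.equal(fused, base)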

xtuner/v1/ray/config/worker.py

Lines changed: 4 additions & 0 deletions
@@ -170,6 +170,10 @@ class RolloutConfig(BaseModel):
             help="Whether to enable returning routed experts for the rollout worker.",
         ),
     ] = False
+    update_weight_bucket_size_in_gb: Annotated[
+        float,
+        Parameter(group=infer_group, help="Bucket size in GB for updating weight."),
+    ] = 0.5  # 512MB
     launch_server_method: Annotated[
         Literal["ray", "multiprocessing"],
         Parameter(
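
The new update_weight_bucket_size_in_gb knob caps how many parameters are streamed to the rollout engine at once; the default 0.5 GiB is 0.5 * 1024**3 = 536,870,912 bytes, the "512MB" in the comment. Below is a hypothetical sketch of the bucketing idea over a plain (name, tensor) iterable; it is not the generator API the workers actually use.

import torch

def iter_param_buckets(named_params, bucket_size_in_gb=0.5):
    """Yield (names, tensors) groups whose total size stays under the byte budget."""
    budget = int(bucket_size_in_gb * 1024**3)
    names, tensors, used = [], [], 0
    for name, p in named_params:
        nbytes = p.numel() * p.element_size()
        if tensors and used + nbytes > budget:
            yield names, tensors
            names, tensors, used = [], [], 0
        names.append(name)
        tensors.append(p)
        used += nbytes
    if tensors:
        yield names, tensors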

xtuner/v1/rl/base/worker.py

Lines changed: 37 additions & 4 deletions
@@ -621,6 +621,7 @@ def update_rollout_info(
         self.rollout_cfg_info["tp"] = tp
         self.rollout_cfg_info["ep"] = ep
         self.rollout_cfg_info["api_key"] = rollout_config.api_key
+        self.rollout_cfg_info["update_weight_bucket_size_in_gb"] = rollout_config.update_weight_bucket_size_in_gb
         if os.environ.get("XTUNER_USE_SGLANG", "0") == "1":
             self.rollout_cfg_info["backend"] = "sglang"
         else:
@@ -651,19 +652,51 @@ def _update_weights_hf_generator(self):
         else:
             dtype = torch.bfloat16
 
-        same_gen = model._get_same_hf_param(model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE)
+        bucket_size = int(self.rollout_cfg_info["update_weight_bucket_size_in_gb"] * 1024**3)
+        same_gen = model._get_same_hf_param(
+            model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size
+        )
         fused_gen = model._get_fused_hf_param(
             model._group_param_by_load_spec(LoadEnum.FUSED),
             dtype=dtype,
             device=DEVICE,
-            return_full_key_per_rank=True,
+            bucket_size=bucket_size,
+            update_weights_for_rl=True,
         )
         shard_gen = model._get_shard_hf_param(
-            model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE
+            model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE, bucket_size=bucket_size
         )
-        for name_list, param_list in chain(same_gen, fused_gen, shard_gen):
+
+        for name_list, fused_param_list in fused_gen:
+            state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)}
+            if model.fsdp_config.ep_size > 1:
+                ep_mesh: DeviceMesh = model.ep_mesh
+                ep_size = ep_mesh.size()
+                ep_group = ep_mesh.get_group()
+                ep_rank = dist.get_rank(group=ep_group)
+                for src_rank in range(ep_size):
+                    broadcast_state_dict = dict()
+                    for key, tensor in state_dict.items():
+                        obj_to_broadcast = [key, tensor.to("meta")] if ep_rank == src_rank else [None, None]
+                        dist.broadcast_object_list(obj_to_broadcast, src=src_rank, group=ep_group)
+                        real_key, meta_tensor = obj_to_broadcast
+                        buffer = (
+                            state_dict[real_key]
+                            if ep_rank == src_rank
+                            else torch.empty_like(meta_tensor, device=DEVICE)
+                        )
+                        dist.broadcast(buffer, src=src_rank, group=ep_group)
+                        broadcast_state_dict[real_key] = buffer
+                    self.request_update_params(broadcast_state_dict, finished=False)
+                    del broadcast_state_dict, buffer
+            else:
+                self.request_update_params(state_dict, finished=False)
+            del state_dict, name_list, fused_param_list
+
+        for name_list, param_list in chain(same_gen, shard_gen):
             state_dict = {name: param.detach() for name, param in zip(name_list, param_list)}
             self.request_update_params(state_dict, finished=False)
+            del state_dict, name_list, param_list
 
         if self.rollout_cfg_info["backend"] == "pytorch":
             self.request_update_params({}, finished=True)
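
A hedged sketch of the expert-parallel branch above: each EP rank takes a turn as source, first shipping only the key plus a meta tensor via broadcast_object_list, then broadcasting the real tensor, so every rank materializes at most one remote bucket at a time instead of all-gathering full expert weights. This mirrors the committed loop but is simplified and resolves group ranks to global ranks explicitly; treat it as an illustration, not the worker implementation.

import torch
import torch.distributed as dist

def broadcast_expert_buckets(state_dict, ep_group, device):
    """Yield, per EP source rank, a dict of that rank's expert tensors."""
    ep_rank = dist.get_rank(group=ep_group)
    for group_src in range(dist.get_world_size(group=ep_group)):
        src = dist.get_global_rank(ep_group, group_src)  # collectives expect global ranks
        bucket = {}
        for key, tensor in state_dict.items():
            # Ship only (key, meta tensor) cheaply, then the actual payload.
            obj = [key, tensor.to("meta")] if ep_rank == group_src else [None, None]
            dist.broadcast_object_list(obj, src=src, group=ep_group)
            real_key, meta = obj
            buf = state_dict[real_key] if ep_rank == group_src else torch.empty_like(meta, device=device)
            dist.broadcast(buf, src=src, group=ep_group)
            bucket[real_key] = buf
        yield group_src, bucket  # caller pushes this bucket to rollout, then frees it
        del bucket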

xtuner/v1/train/rl_trainer.py

Lines changed: 12 additions & 1 deletion
@@ -298,7 +298,18 @@ def __init__(
             * total_epochs
         )
         bind_train_rollout(train_controller=self._train_controller, env_controller=self._rollout_env_controller)
-        ray.get(self._train_controller.offload.remote(target="all"))
+        # Update weights from the train workers if rollout_config.skip_load_weights is True
+        if rollout_config.skip_load_weights:
+            self.logger.info("Rollout workers skip loading weights; updating weights from train workers.")
+            ray.get(self._train_controller.offload.remote(target="optimizer"))
+            ray.get(self._rollout_env_controller.offload.remote())
+            ray.get(self._rollout_env_controller.onload_weights.remote())
+            ray.get(self._train_controller.update_weights.remote())
+            ray.get(self._train_controller.offload.remote(target="model"))
+            ray.get(self._rollout_env_controller.onload_kvcache.remote())
+            self.logger.info("Rollout workers have updated weights from train workers.")
+        else:
+            ray.get(self._train_controller.offload.remote(target="all"))
 
         self._train_worker_cfg = train_worker_cfg
