5 changes: 3 additions & 2 deletions tests/ray/test_update_weight.py
@@ -72,7 +72,8 @@ def init_config(self):
if hasattr(model_cfg, 'balancing_loss_cfg'):
model_cfg.balancing_loss_cfg = BalancingLossConfig()
optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False)
fsdp_cfg: FSDPConfig = FSDPConfig()
fsdp_cfg: FSDPConfig = FSDPConfig(ep_size=4)
model_cfg.ep_size = fsdp_cfg.ep_size
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7)
self.worker_cfg: WorkerConfig = WorkerConfig(
model_cfg=model_cfg,
@@ -84,7 +85,7 @@ def init_config(self):
loss_type="vanilla",
),
ignore_idx=-100,
use_kl_loss=True,
use_kl_loss=False,
kl_loss_coef=0.001,
kl_loss_type="low_var_kl",
mode="eager"),
154 changes: 100 additions & 54 deletions xtuner/v1/model/base.py
@@ -500,7 +500,7 @@ def _get_fused_hf_param(
dtype: torch.dtype,
device="cpu",
bucket_size=None,
return_full_key_per_rank: bool = False,
update_weights_for_rl: bool = False,
) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
if not params:
return
@@ -525,70 +525,65 @@ def _get_hf_params(
for load_spec, fsdp_unshared_tensor in zip(spec_list, fsdp_unshard_tensor_list):
hf_keys = load_spec.hf_keys

if load_spec.group is not None:
all_hf_keys_list: list[None] | list[list[str]] = [None for _ in range(load_spec.group.size())]
dist.all_gather_object(all_hf_keys_list, hf_keys, group=load_spec.group)
all_hf_keys_list = cast(list[list[str]], all_hf_keys_list)
all_hf_keys = list(chain(*all_hf_keys_list))
if update_weights_for_rl:
hf_keys_list.append(hf_keys)
saved_fused_tensor_list.append(fsdp_unshared_tensor)
else:
all_hf_keys = hf_keys
if load_spec.group is not None:
all_hf_keys_list: list[None] | list[list[str]] = [None for _ in range(load_spec.group.size())]
dist.all_gather_object(all_hf_keys_list, hf_keys, group=load_spec.group)
all_hf_keys_list = cast(list[list[str]], all_hf_keys_list)
all_hf_keys = list(chain(*all_hf_keys_list))
else:
all_hf_keys = hf_keys

current_rank = dist.get_rank()
current_rank = dist.get_rank()

expected_fused_save_ranks = self._get_ranks_to_save_fused_tensor(len(all_hf_keys))
hardcode_fused_save_ranks = list(
range(min((dist.get_world_size(), self.config.hf_save_cfg.max_save_rank)))
)
expected_fused_save_ranks = self._get_ranks_to_save_fused_tensor(len(all_hf_keys))
hardcode_fused_save_ranks = list(
range(min((dist.get_world_size(), self.config.hf_save_cfg.max_save_rank)))
)

key_per_rank = len(all_hf_keys) / len(hardcode_fused_save_ranks)
# assert key_per_rank.is_integer(), (
# f"XTuner Internal Error, size of all_hf_keys: {len(all_hf_keys)}, "
# f"size of `fused_save_ranks` {len(fused_save_ranks)}"
# )
if not key_per_rank.is_integer():
key_per_rank = len(all_hf_keys) / len(expected_fused_save_ranks)

key_per_rank = len(all_hf_keys) / len(hardcode_fused_save_ranks)
# assert key_per_rank.is_integer(), (
# f"XTuner Internal Error, size of all_hf_keys: {len(all_hf_keys)}, "
# f"size of `fused_save_ranks` {len(fused_save_ranks)}"
# )
if not key_per_rank.is_integer():
key_per_rank = len(all_hf_keys) / len(expected_fused_save_ranks)

# 1. When return_full_key_per_rank is False, we intends to save hf models across ranks,
# each rank only saves part of hf keys and tensors
# 2. When return_full_key_per_rank is True, we intends to generate full tensors on each
# rank for ipc updating weights in RL training.
if not return_full_key_per_rank:
start = int(current_rank * key_per_rank)
end = int(start + key_per_rank)
else:
start = 0
end = len(all_hf_keys)

_hf_key_list = all_hf_keys[start:end]
_hf_key_list = all_hf_keys[start:end]

if not _hf_key_list:
continue
if not _hf_key_list:
continue

hf_keys_list.append(_hf_key_list)
hf_keys_list.append(_hf_key_list)

assert load_spec.dim is not None
if load_spec.group is not None:
assert load_spec.dim is not None
_gathered_tensor_list = [
torch.zeros_like(fsdp_unshared_tensor) for _ in range(load_spec.group.size())
]
dist.all_gather(_gathered_tensor_list, fsdp_unshared_tensor, group=load_spec.group)
_gathered_tensor = torch.cat(_gathered_tensor_list, dim=load_spec.dim)
else:
_gathered_tensor = fsdp_unshared_tensor

hf_tensor_size = _gathered_tensor.shape[load_spec.dim] / len(all_hf_keys)
_saved_fused_tensor = torch.index_select(
_gathered_tensor,
dim=load_spec.dim,
index=torch.arange(
int(start * hf_tensor_size),
int(end * hf_tensor_size),
dtype=torch.int64,
device=_gathered_tensor.device,
),
)
saved_fused_tensor_list.append(_saved_fused_tensor)
if load_spec.group is not None:
assert load_spec.dim is not None
_gathered_tensor_list = [
torch.zeros_like(fsdp_unshared_tensor) for _ in range(load_spec.group.size())
]
dist.all_gather(_gathered_tensor_list, fsdp_unshared_tensor, group=load_spec.group)
_gathered_tensor = torch.cat(_gathered_tensor_list, dim=load_spec.dim)
else:
_gathered_tensor = fsdp_unshared_tensor
hf_tensor_size = _gathered_tensor.shape[load_spec.dim] / len(all_hf_keys)
_saved_fused_tensor = torch.index_select(
_gathered_tensor,
dim=load_spec.dim,
index=torch.arange(
int(start * hf_tensor_size),
int(end * hf_tensor_size),
dtype=torch.int64,
device=_gathered_tensor.device,
),
)
saved_fused_tensor_list.append(_saved_fused_tensor)

# Split the fused tensor into hf tensors
hf_tensor_list: list[torch.Tensor] = []
@@ -1177,6 +1172,18 @@ def _fsdp_foreach_allgather(

# Concatenate the tensors along the FSDP shard dim
for tensors, size in zip(_fsdp_unsharded_tensor_list, origin_fsdp_size):
# In the case of one big tensor in tensor_list, the partitions of that tensor are contiguous.
# Therefore the cat and index_select operations can be omitted, and
# _fuse_contiguous_chunks_without_alloc is used instead to reduce peak device memory.
# e.g. when a fused MoE weight exceeds the given bucket_size, len(tensor_list) would be 1
# and fused_tensor is not None, reducing peak device memory.
fused_tensor = self._fuse_contiguous_chunks_without_alloc(tensors)
if fused_tensor is not None and fused_tensor.shape[self.FSDP_SHARD_DIM] == size:
fsdp_unsharded_tensor_list.append(fused_tensor)
continue
elif fused_tensor is not None:
# free memory ASAP
del fused_tensor
tensor = torch.cat(tensors, dim=self.FSDP_SHARD_DIM)
cat_tensor = torch.index_select(
tensor,
@@ -1193,6 +1200,45 @@

return fsdp_unsharded_tensor_list

@staticmethod
def _fuse_contiguous_chunks_without_alloc(tensors: list[torch.Tensor]) -> torch.Tensor | None:
"""Fuse contiguous chunks without extra memory allocation.

Return None if not possible.
"""
if not tensors:
raise ValueError("tensors should not be empty")
base = tensors[0]
storage = base.untyped_storage()
dtype = base.dtype
device = base.device
stride = base.stride()

inner_stride = stride[1:]
inner_elems = math.prod(base.shape[1:]) if base.dim() > 1 else 1

chunks = []
for t in tensors:
# we should check both storage and stride to ensure contiguity
# regardless of the implementation of foreach_all_gather
if t.untyped_storage().data_ptr() != storage.data_ptr() or t.stride()[1:] != inner_stride:
return None
chunks.append((t.storage_offset(), t.shape[0], t))
chunks.sort(key=lambda x: x[0])

expected_offset = chunks[0][0]
total_rows = 0
for offset, rows, _ in chunks:
if offset != expected_offset:
return None
expected_offset += rows * inner_elems
total_rows += rows

size = (total_rows, *base.shape[1:])
flat = torch.empty(0, dtype=dtype, device=device)
flat.set_(storage, chunks[0][0], size, stride)
return flat

def _maybe_compile_layers(self):
if self.fsdp_config is not None:
if self.fsdp_config.torch_compile:
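The comment and helper added above rely on a simple property: when the chunks in tensors are contiguous views of one underlying buffer, they can be re-fused as a zero-copy view instead of going through cat + index_select. A minimal standalone sketch of that property (not part of this PR; torch.split stands in for the all-gather output):

import torch

# Three (2, 4) row-chunks of one (6, 4) buffer; torch.split returns views that
# share the buffer's untyped storage, mirroring contiguous all-gather chunks.
base = torch.arange(24, dtype=torch.float32).reshape(6, 4)
chunks = list(torch.split(base, 2, dim=0))

storage = base.untyped_storage()
assert all(c.untyped_storage().data_ptr() == storage.data_ptr() for c in chunks)

# Storage offsets advance by rows * inner_elems (2 * 4 = 8 elements per chunk),
# which is exactly the contiguity check performed by _fuse_contiguous_chunks_without_alloc.
print([c.storage_offset() for c in chunks])  # [0, 8, 16]

# Re-fuse as a zero-copy view instead of torch.cat + torch.index_select.
fused = torch.empty(0, dtype=base.dtype, device=base.device)
fused.set_(storage, chunks[0].storage_offset(), base.shape, base.stride())
assert torch.equal(fused, base)
assert fused.data_ptr() == base.data_ptr()  # no new allocation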
44 changes: 40 additions & 4 deletions xtuner/v1/rl/base/worker.py
@@ -129,6 +129,7 @@ class WorkerConfig(BaseModel):
ref_load_from: str | Path | None = None
ref_model_fsdp_cfg: FSDPConfig | None = None
log_dir: str | Path | None = None
update_weight_bucket_size_in_gb: float = 0.5 # 512MB


class WorkerInputItem(TypedDict):
@@ -608,19 +609,54 @@ def _update_weights_hf_generator(self):
else:
dtype = torch.bfloat16

same_gen = model._get_same_hf_param(model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE)
bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3)
same_gen = model._get_same_hf_param(
model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size
)
fused_gen = model._get_fused_hf_param(
model._group_param_by_load_spec(LoadEnum.FUSED),
dtype=dtype,
device=DEVICE,
return_full_key_per_rank=True,
bucket_size=bucket_size,
update_weights_for_rl=True,
)
shard_gen = model._get_shard_hf_param(
model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE
model._group_param_by_load_spec(LoadEnum.SHARD), dtype=dtype, device=DEVICE, bucket_size=bucket_size
)
for name_list, param_list in chain(same_gen, fused_gen, shard_gen):

for name_list, fused_param_list in fused_gen:
state_dict = {name: param.detach() for name, param in zip(name_list, fused_param_list)}
if model.fsdp_config.ep_size > 1:
# When ep_size > 1, the generator yields only this EP rank's part of the fused param within the ep_group.
# We could all-gather the parts to recover the full fused param, but that would increase peak memory usage.
# Instead, each EP rank in the ep_group broadcasts its part sequentially,
# and the weights are updated part by part to reduce memory usage.
ep_mesh: DeviceMesh = model.ep_mesh
ep_group = ep_mesh.get_group()
global_rank = dist.get_rank()
for src_global_rank in dist.get_process_group_ranks(ep_group):
broadcast_state_dict = dict()
for key, tensor in state_dict.items():
obj_to_broadcast = [key, tensor.to("meta")] if global_rank == src_global_rank else [None, None]
dist.broadcast_object_list(obj_to_broadcast, src=src_global_rank, group=ep_group)
real_key, meta_tensor = obj_to_broadcast
buffer = (
state_dict[real_key]
if global_rank == src_global_rank
else torch.empty_like(meta_tensor, device=DEVICE)
)
dist.broadcast(buffer, src=src_global_rank, group=ep_group)
broadcast_state_dict[real_key] = buffer
self.request_update_params(broadcast_state_dict, finished=False)
del broadcast_state_dict, buffer
else:
self.request_update_params(state_dict, finished=False)
del state_dict, name_list, fused_param_list

for name_list, param_list in chain(same_gen, shard_gen):
state_dict = {name: param.detach() for name, param in zip(name_list, param_list)}
self.request_update_params(state_dict, finished=False)
del state_dict, name_list, param_list

if self.rollout_cfg_info["backend"] == "pytorch":
self.request_update_params({}, finished=True)
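The EP branch of _update_weights_hf_generator above reduces to the following pattern. A minimal sketch under stated assumptions (broadcast_ep_slices, push_fn, and device are hypothetical names, not part of the PR; an initialized torch.distributed process group and state_dicts of equal length on every EP rank are assumed):

import torch
import torch.distributed as dist


def broadcast_ep_slices(state_dict, ep_group, push_fn, device):
    """Sequentially broadcast each EP rank's slice of the fused params.

    Instead of all-gathering (peak memory roughly ep_size x local slice), each
    rank in ep_group takes a turn as the broadcast source; the received slice
    is pushed to the rollout engine and freed before the next source runs.
    """
    my_rank = dist.get_rank()
    for src in dist.get_process_group_ranks(ep_group):
        received = {}
        for key, tensor in state_dict.items():
            # Ship the key and a meta tensor first so receivers can allocate a
            # buffer with the correct shape and dtype.
            obj = [key, tensor.to("meta")] if my_rank == src else [None, None]
            dist.broadcast_object_list(obj, src=src, group=ep_group)
            real_key, meta = obj
            buf = state_dict[real_key] if my_rank == src else torch.empty_like(meta, device=device)
            dist.broadcast(buf, src=src, group=ep_group)
            received[real_key] = buf
        push_fn(received)  # e.g. request_update_params(received, finished=False)
        del received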
13 changes: 12 additions & 1 deletion xtuner/v1/train/rl_trainer.py
@@ -310,7 +310,18 @@ def __init__(
* total_epochs
)
bind_train_rollout(train_controller=self._train_controller, env_controller=self._rollout_env_controller)
ray.get(self._train_controller.offload.remote(target="all"))
# update weights if rollout_config.skip_load_weights == True
if rollout_config.skip_load_weights:
self.logger.info("Rollout workers skip load weights, update weights from train workers.")
ray.get(self._train_controller.offload.remote(target="optimizer"))
ray.get(self._rollout_env_controller.offload.remote())
ray.get(self._rollout_env_controller.onload_weights.remote())
ray.get(self._train_controller.update_weights.remote())
ray.get(self._train_controller.offload.remote(target="model"))
ray.get(self._rollout_env_controller.onload_kvcache.remote())
self.logger.info("Rollout workers has updated weights from train workers.")
else:
ray.get(self._train_controller.offload.remote(target="all"))

self._train_worker_cfg = train_worker_cfg
