
Conversation

@CyCle1024 (Collaborator) commented:

Key Changes:

  1. Clean up the extra index_select and cat logic in xtuner/v1/model/base.py (_get_fused_hf_param / _get_hf_params) used during RL weight updates.
  2. Use EP sharding to reduce peak memory when transferring the fused large expert parameters during RL weight updates.
  3. Add a rollout parameter update_weight_bucket_size_in_gb to configure the packed tensor bucket size during RL weight updates (see the sketch after this list).
  4. test_update_weight now defaults to ep_size=4.
  5. When skip_load_weights == True in RolloutConfig, lmdeploy loads weights from the training workers instead, which saves time on large-scale clusters.
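For context, the bucketing that update_weight_bucket_size_in_gb controls can be pictured with a minimal sketch; the helper and engine API names below are hypothetical, not the actual xtuner implementation.

```python
# Hypothetical helper (not the actual xtuner code): pack (name, tensor) pairs
# into buckets no larger than `bucket_size_in_gb` so each transfer to the
# rollout workers has a bounded peak size.
def pack_into_buckets(named_tensors, bucket_size_in_gb: float):
    limit_bytes = int(bucket_size_in_gb * 1024**3)
    bucket, bucket_bytes = [], 0
    for name, tensor in named_tensors:
        nbytes = tensor.numel() * tensor.element_size()
        # Close the current bucket once adding this tensor would exceed the limit.
        if bucket and bucket_bytes + nbytes > limit_bytes:
            yield bucket
            bucket, bucket_bytes = [], 0
        bucket.append((name, tensor))
        bucket_bytes += nbytes
    if bucket:
        yield bucket

# Usage (hypothetical engine API): stream buckets instead of the full state dict.
# for bucket in pack_into_buckets(state_dict.items(), cfg.update_weight_bucket_size_in_gb):
#     rollout_engine.update_weights(dict(bucket))
```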

@CyCle1024 requested a review from hhaAndroid on November 28, 2025 10:10

# Concatenate the tensors along the FSDP shard dim
for tensors, size in zip(_fsdp_unsharded_tensor_list, origin_fsdp_size):
    # special case for partition of tensors are contiguous
Reviewer (Collaborator) comment:

The comment should describe why rather than how
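As an illustration of a "why"-style comment (a hedged example, assuming the special case exists to avoid an extra full-size copy), the line could read something like:

```python
# Why: when the FSDP shards of a parameter already sit back-to-back in the
# same untyped storage, exposing the full parameter as a view of that storage
# avoids materializing a second full-size tensor with torch.cat, reducing the
# peak memory for that parameter.
```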

Comment on lines 1182 to 1188
if (
    t.untyped_storage().data_ptr() != storage.data_ptr()
    or t.dtype != dtype
    or t.device != device
    or t.stride()[1:] != inner_stride
):
    return None
Reviewer (Collaborator) comment:

If we make this a private function, we should remove some of the unnecessary checks.
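One possible shape for such a private helper, sketched under the assumption that dtype and device are already validated by the caller, so only the storage identity, stride, and adjacency checks remain (names are hypothetical, not the xtuner implementation):

```python
from typing import Sequence

import torch


def _shards_share_contiguous_storage(shards: Sequence[torch.Tensor]) -> bool:
    """Return True if the shards can be exposed as one zero-copy view.

    Hypothetical sketch: assumes the caller already guarantees matching dtype
    and device, so only storage identity, inner strides, and adjacency of the
    storage offsets are checked here.
    """
    first = shards[0]
    storage_ptr = first.untyped_storage().data_ptr()
    inner_stride = first.stride()[1:]
    expected_offset = first.storage_offset()
    for t in shards:
        if (
            t.untyped_storage().data_ptr() != storage_ptr
            or t.stride()[1:] != inner_stride
            or t.storage_offset() != expected_offset
        ):
            return False
        expected_offset += t.numel()
    return True
```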

ep_mesh: DeviceMesh = model.ep_mesh
ep_size = ep_mesh.size()
ep_group = ep_mesh.get_group()
ep_rank = dist.get_rank(group=ep_group)
Reviewer (Collaborator) comment:

Add comments describing what happened here.
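For illustration, a hedged reading of what these handles are typically set up for: walking the EP ranks one shard at a time so the fused expert parameter is never fully materialized at once. The function and variable names below are hypothetical, not the code under review.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def iter_expert_shards(local_shard: torch.Tensor, ep_mesh: DeviceMesh):
    """Yield one EP shard at a time instead of all-gathering the fused tensor.

    Hypothetical sketch: peak memory stays at roughly two shards (the local
    one plus the one in flight) rather than ep_size shards.
    """
    ep_size = ep_mesh.size()
    ep_group = ep_mesh.get_group()
    ep_rank = dist.get_rank(group=ep_group)
    recv_buf = torch.empty_like(local_shard)
    for src in range(ep_size):
        shard = local_shard if src == ep_rank else recv_buf
        # broadcast expects a global rank, so translate the EP-group-local rank.
        dist.broadcast(shard, src=dist.get_global_rank(ep_group, src), group=ep_group)
        yield src, shard
```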

@hhaAndroid force-pushed the opt_update_weight_mem branch from 3bc7871 to 11fb109 on November 28, 2025 11:09
@CyCle1024 requested a review from HAOCHENYE on December 1, 2025 08:09
@CyCle1024 force-pushed the opt_update_weight_mem branch 3 times, most recently from 803c5ef to 05df05c on December 5, 2025 09:11
@CyCle1024 force-pushed the opt_update_weight_mem branch from 05df05c to 61e09a0 on December 5, 2025 09:12