[OPT] Reduce update_weight peak memory in RL #1306
Conversation
xtuner/v1/model/base.py
Outdated
```python
# Concatenate the tensors along the FSDP shard dim
for tensors, size in zip(_fsdp_unsharded_tensor_list, origin_fsdp_size):
    # special case for partition of tensors are contiguous
```
The comment should describe why rather than how
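To illustrate the "why" the reviewer is asking for: the special case exists because shards that are back-to-back views of one storage can be reassembled without a copy, which is what lowers peak memory. The sketch below is a hypothetical helper (`concat_fsdp_shards` is not the PR's actual function) assuming sharding along dim 0:

```python
import torch


def concat_fsdp_shards(shards, shard_dim=0):
    # Fast path rationale: when per-rank shards are consecutive views of a
    # single untyped storage, the full parameter can be rebuilt as a view
    # (as_strided) instead of a torch.cat copy, keeping peak memory near
    # 1x the parameter size instead of 2x.
    storage = shards[0].untyped_storage()
    offset = shards[0].storage_offset()
    zero_copy = shard_dim == 0  # the view trick only works for dim-0 shards
    expected = offset
    for s in shards:
        if (not s.is_contiguous()
                or s.untyped_storage().data_ptr() != storage.data_ptr()
                or s.storage_offset() != expected):
            zero_copy = False
            break
        expected += s.numel()
    if not zero_copy:
        # Slow path: allocates a fresh full-size buffer.
        return torch.cat(shards, dim=shard_dim)
    full_shape = list(shards[0].shape)
    full_shape[shard_dim] = sum(s.shape[shard_dim] for s in shards)
    # Contiguous strides for the full shape.
    stride = [1] * len(full_shape)
    for i in range(len(full_shape) - 2, -1, -1):
        stride[i] = stride[i + 1] * full_shape[i + 1]
    return shards[0].as_strided(full_shape, stride, offset)
```

A comment in the real code stating this 1x-vs-2x memory argument would answer the review.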
xtuner/v1/model/base.py
Outdated
```python
if (
    t.untyped_storage().data_ptr() != storage.data_ptr()
    or t.dtype != dtype
    or t.device != device
    or t.stride()[1:] != inner_stride
):
    return None
```
If we make it a private function, we should remove the unnecessary checks
xtuner/v1/rl/base/worker.py
Outdated
```python
ep_mesh: DeviceMesh = model.ep_mesh
ep_size = ep_mesh.size()
ep_group = ep_mesh.get_group()
ep_rank = dist.get_rank(group=ep_group)
```
Add comments describing what happens here.
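The subtlety worth a comment in that hunk is that `dist.get_rank(group=ep_group)` returns a group-local index, not a global rank. A pure-Python illustration, assuming a contiguous EP grouping (both helpers are hypothetical, not xtuner or torch API):

```python
def ep_groups(world_size, ep_size):
    # With a contiguous EP mesh, global ranks [g*ep_size, (g+1)*ep_size)
    # form one expert-parallel group.
    assert world_size % ep_size == 0
    return [list(range(g * ep_size, (g + 1) * ep_size))
            for g in range(world_size // ep_size)]


def ep_rank_of(global_rank, ep_size):
    # dist.get_rank(group=ep_group) returns this group-local index; code
    # that needs a global peer rank must map it back through the group's
    # rank list rather than use it directly.
    return global_rank % ep_size
```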
Key Changes:
- `xtuner/v1/model/base.py`: `_get_fused_hf_param` / `_get_hf_params` in RL update weights.
- Add `update_weight_bucket_size_in_gb` to configure the packed tensor size in RL update weights.
- `test_update_weight` defaults to `ep_size=4`.
- When `skip_load_weights == True` in `RolloutConfig`, lmdeploy loads weights from the training workers, which saves time on large-scale clusters.
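To show how a cap like `update_weight_bucket_size_in_gb` bounds peak memory: smaller buckets mean smaller transient gather/pack buffers at the cost of more communication rounds. A greedy packing sketch (not the PR's actual implementation), taking each tensor's byte size:

```python
def pack_buckets(tensor_nbytes, bucket_size_in_gb):
    # Greedily pack tensor indices into buckets whose total byte size stays
    # under the configured cap; each bucket becomes one packed transfer, so
    # the cap bounds the transient buffer (and thus peak memory) per round.
    cap = int(bucket_size_in_gb * (1 << 30))
    buckets, cur, cur_bytes = [], [], 0
    for i, nbytes in enumerate(tensor_nbytes):
        if cur and cur_bytes + nbytes > cap:
            buckets.append(cur)
            cur, cur_bytes = [], 0
        # A single tensor larger than the cap still gets its own bucket.
        cur.append(i)
        cur_bytes += nbytes
    if cur:
        buckets.append(cur)
    return buckets
```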