
Conversation


@smahdavi4 smahdavi4 commented Dec 7, 2025

What does this PR do?

Async GRPO does not need to offload the policy after checkpointing, since training and inference are not colocated. Currently, this offload leads to an exception after checkpointing because some tensors are not moved back to CUDA.
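For context, here is a rough sketch of the checkpoint step this change touches. The control flow and names such as should_checkpoint and checkpointer are illustrative assumptions rather than the actual code in nemo_rl/algorithms/grpo.py; policy.offload_after_refit() is the call removed by this PR.

```python
# Illustrative sketch only: should_checkpoint, checkpointer, and the control
# flow are assumptions, not the actual nemo_rl/algorithms/grpo.py code.
if should_checkpoint(step):
    checkpointer.save_checkpoint(policy, optimizer, step)
    checkpointer.finalize_checkpoint()
    # Removed by this PR:
    #   policy.offload_after_refit()
    # In colocated GRPO the policy is offloaded to CPU after a refit so the
    # inference engine can reclaim the GPUs. In async GRPO, training and
    # inference run on separate resources, so offloading here only leaves
    # optimizer/parameter tensors on CPU, and the next optimizer step fails
    # with a cuda/cpu device-mismatch error.
```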

Summary by CodeRabbit

  • Refactor
    • Streamlined resource handling during checkpoint finalization in GRPO async training by removing an unnecessary cleanup step. This change improves efficiency without affecting training behavior or output.


@smahdavi4 smahdavi4 requested a review from a team as a code owner December 7, 2025 06:14
@smahdavi4 smahdavi4 changed the title Remove optimizer offload for async grpo dtensor Remove policy offload for async grpo dtensor Dec 7, 2025

coderabbitai bot commented Dec 7, 2025

📝 Walkthrough


Removed a call to policy.offload_after_refit() that executed after finalizing a checkpoint during async GRPO training. This single-line removal eliminates the offload operation in that specific code path without altering control flow or initialization logic.

Changes

Cohort / File(s): GRPO async checkpoint handling (nemo_rl/algorithms/grpo.py)
Change Summary: Removed the policy.offload_after_refit() call following checkpoint finalization in async GRPO training.

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~2 minutes

Suggested labels

Performance

Suggested reviewers

  • parthchadha
  • yuki-97

Pre-merge checks and finishing touches

✅ Passed checks (4 passed)
  • Description Check: Passed. Check skipped since CodeRabbit's high-level summary is enabled.
  • Title Check: Passed. The title accurately reflects the main change: removing the policy offload for async GRPO in the dtensor training workflow.
  • Docstring Coverage: Passed. No functions found in the changed files to evaluate docstring coverage; skipping the docstring coverage check.
  • Test Results For Major Changes: Passed. Minor bug fix removing a problematic policy offload call in async GRPO; no major feature changes, breaking changes, or significant refactoring.

📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0947683 and e545aad.

📒 Files selected for processing (1)
  • nemo_rl/algorithms/grpo.py (0 hunks)
💤 Files with no reviewable changes (1)
  • nemo_rl/algorithms/grpo.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR


@parthchadha (Contributor) commented:

@smahdavi4 just curious if you can point to the exceptions you saw? I don't see them in any of my runs.

@smahdavi4 (Author) commented:

Here is the exception; it only happens after a checkpoint is saved:

  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/optim/adam.py", line 949, in adam
    func(
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/optim/adam.py", line 447, in _single_tensor_adam
    exp_avg.lerp_(grad, 1 - device_beta1)
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 201, in dispatch
    local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
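
This traceback is the usual symptom of stepping an optimizer whose state tensors were moved to CPU while the parameters and gradients stayed on CUDA. A minimal, self-contained sketch that reproduces the same RuntimeError with plain PyTorch (unrelated to nemo_rl, requires a CUDA device, just to illustrate the mechanism):

```python
# Minimal reproduction of the cuda/cpu device mismatch: Adam's exp_avg state
# ends up on CPU while the gradient is on CUDA, so exp_avg.lerp_(grad, ...)
# in _single_tensor_adam raises RuntimeError.
import torch

model = torch.nn.Linear(8, 8).cuda()
# foreach=False forces the single-tensor Adam path seen in the traceback above.
opt = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=False)

# One normal step so Adam materializes its exp_avg/exp_avg_sq state on CUDA.
model(torch.randn(4, 8, device="cuda")).sum().backward()
opt.step()
opt.zero_grad()

# Simulate an offload of optimizer state to CPU that is never moved back.
for state in opt.state.values():
    for k, v in state.items():
        if torch.is_tensor(v):
            state[k] = v.cpu()

# Next training step: parameters/gradients on CUDA, optimizer state on CPU.
model(torch.randn(4, 8, device="cuda")).sum().backward()
opt.step()  # RuntimeError: Expected all tensors to be on the same device ...
```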


@terrykong terrykong left a comment


@parthchadha to review
