[wip] support partial rollout #1284
No description provided.

base: main
Conversation
Pull Request Overview
This PR introduces partial rollout support with importance sampling (IS) for reinforcement learning training, enabling more efficient off-policy learning and handling of distribution mismatch between rollout and training policies.
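For context, importance sampling corrects for this mismatch by weighting each sampled token with the likelihood ratio between the training policy and the rollout policy; this is the textbook per-token form (the PR's exact formulation lives in `rollout_is.py`):

```math
w_t = \frac{\pi_{\theta}(a_t \mid s_t)}{\pi_{\text{rollout}}(a_t \mid s_t)}
```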
Key changes:
- Implements rollout importance sampling with configurable aggregation levels (token, sequence, geometric) and handling modes (truncate, mask); a sketch of these modes follows this list
- Adds state management for partial rollouts with version tracking and replay buffer support for interrupted/completed/expired samples
- Enhances rollout error handling with retry logic and proper state transitions (completed, interrupted, skipped, failed)
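A minimal sketch of how the aggregation levels and handling modes could be computed (the function names and the geometric-mean formulation are illustrative assumptions, not the PR's actual API in `rollout_is.py`):

```python
import torch


def aggregate_is_weights(
    logprobs_train: torch.Tensor,    # (seq_len,) log-probs under the training policy
    logprobs_rollout: torch.Tensor,  # (seq_len,) log-probs under the rollout policy
    level: str = "token",
) -> torch.Tensor:
    """Return one IS weight per token under the chosen aggregation level."""
    log_ratio = logprobs_train - logprobs_rollout
    if level == "token":
        return log_ratio.exp()                               # independent per-token ratios
    if level == "sequence":
        return log_ratio.sum().exp().expand_as(log_ratio)    # product of ratios, shared by all tokens
    if level == "geometric":
        return log_ratio.mean().exp().expand_as(log_ratio)   # geometric mean of ratios, shared
    raise ValueError(f"unknown aggregation level: {level}")


def handle_outliers(
    weights: torch.Tensor, max_weight: float = 2.0, mode: str = "truncate"
) -> torch.Tensor:
    """Two ways to tame extreme ratios: cap them ('truncate') or zero them out ('mask')."""
    if mode == "truncate":
        return weights.clamp(max=max_weight)
    if mode == "mask":
        return weights * (weights <= max_weight)
    raise ValueError(f"unknown handling mode: {mode}")
```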
Reviewed Changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| `xtuner/v1/rl/base/rollout_is.py` | New module implementing importance sampling with comprehensive metrics and mismatch detection |
| `xtuner/v1/rl/base/loss.py` | Adds rollout IS configuration and weights to loss context for training |
| `xtuner/v1/rl/oreal/loss.py` | Integrates IS weights into OREAL policy loss computation |
| `xtuner/v1/rl/grpo/loss.py` | Integrates IS weights into GRPO policy loss computation |
| `xtuner/v1/rl/loss_fn.py` | Adds new CISPO loss function (has bug: duplicate function name) |
| `xtuner/v1/rl/base/worker.py` | Adds IS weight computation and metrics logging in training worker |
| `xtuner/v1/train/rl_trainer.py` | Adds memory usage logging and updates validation function names |
| `xtuner/v1/ray/rollout/worker.py` | Implements retry logic and state-based error handling for rollout requests |
| `xtuner/v1/ray/rollout/controller.py` | Adds worker statistics tracking and improved failure recovery |
| `xtuner/v1/ray/rollout/sglang.py` | Supports partial rollout with num_return_tokens parameter |
| `xtuner/v1/ray/rollout/lmdeploy.py` | Adds FP8 support and num_return_tokens handling |
| `xtuner/v1/ray/dataflow/replay_buffer.py` | Refactored to support versioned states with interrupted/completed/expired sample management |
| `xtuner/v1/ray/dataflow/flow.py` | Updated to handle partial rollout with state-based sample routing |
| `xtuner/v1/ray/environment/single_turn_env.py` | Simplifies error handling with state-based validation |
| `xtuner/v1/ray/evaluator.py` | Removes retry logic in favor of upstream error handling |
| `xtuner/v1/ray/judger/compass_verifier_v2.py` | New judger implementation (has bugs: missing await, hardcoded IPs) |
| `xtuner/v1/data_proto/rl_data.py` | Adds state field and update method for partial rollout support |
| `xtuner/v1/data_proto/utils.py` | Adds tensor packing/unpacking utilities for IS computation |
| `xtuner/v1/ray/config/worker.py` | Adds FP8 and retry configuration options |
| `xtuner/v1/train/cli/rl.py` | Updates Ray initialization logic for local execution |
| `xtuner/v1/model/dense/qwen2.py` | Adds bos_token_id field (has bug: missing default value) |
| `xtuner/v1/utils/httpx_utils.py` | Removes unused rollout status helper function |
Comments suppressed due to low confidence (1)

`xtuner/v1/ray/dataflow/replay_buffer.py:390` — the variable `target_batch_size` is assigned but never used:

```python
target_batch_size = min(global_batch_size, len(self._completed_actions))
```
`xtuner/v1/ray/rollout/controller.py` (outdated)

```python
if self.workers_info[server_url].running_count % 100 == 0:
    log_msg = ""
    for _, info in self.workers_info.items():
        log_msg += f"rank {info.rank} worker info: {info}"
    self.logger.info(log_msg)
```
Copilot AI · Nov 21, 2025
The logging condition `if self.workers_info[server_url].running_count % 100 == 0:` fires on every 100th request per worker, which still produces a large volume of log output as `running_count` grows across many workers. Consider a longer interval or a time-based throttle to avoid excessive logging.
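One possible alternative (a sketch, not this PR's code) is to throttle by wall-clock time instead of request count:

```python
import time


class WorkerStatsLogger:
    """Log aggregated worker info at most once per `interval` seconds."""

    def __init__(self, logger, interval: float = 60.0):
        self.logger = logger
        self.interval = interval
        self._last_log = 0.0

    def maybe_log(self, workers_info: dict) -> None:
        now = time.monotonic()
        if now - self._last_log < self.interval:
            return  # too soon since the last report; skip
        self._last_log = now
        log_msg = "".join(
            f"rank {info.rank} worker info: {info}" for info in workers_info.values()
        )
        self.logger.info(log_msg)
```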
`xtuner/v1/rl/loss_fn.py`

```python
    return loss


@register_policy_loss("cispo")
def pg_loss_fn(
```
Copilot AI · Nov 21, 2025
The duplicate function name `pg_loss_fn` causes the first definition (lines 66-84) to be overwritten by the second (lines 87-105). The second function should have a different name, likely `cispo_loss_fn` to match the decorator.
Suggested change:

```diff
-def pg_loss_fn(
+def cispo_loss_fn(
```
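Why this matters even though both functions are decorated: `register_policy_loss` presumably stores the function object in a registry, so both entries survive, but the second module-level `def pg_loss_fn` rebinds the name, and any direct reference to `pg_loss_fn` silently resolves to the CISPO variant. A minimal reproduction (the registry code is illustrative, not xtuner's actual implementation):

```python
_REGISTRY = {}


def register_policy_loss(name):
    def deco(fn):
        _REGISTRY[name] = fn  # the registry keeps a reference to each function object
        return fn
    return deco


@register_policy_loss("vanilla")
def pg_loss_fn():
    return "vanilla"


@register_policy_loss("cispo")
def pg_loss_fn():  # noqa: F811 -- rebinds the module-level name
    return "cispo"


assert _REGISTRY["vanilla"]() == "vanilla"  # still reachable through the registry
assert pg_loss_fn() == "cispo"              # ...but the bare name now points at CISPO
```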
`xtuner/v1/model/dense/qwen2.py`

```diff
 config = cls(
-    hf_config=hf_config,
     vocab_size=hf_config.vocab_size,
```
Copilot AI · Nov 21, 2025
The removed line `hf_config=hf_config,` might break initialization if the parent class `TransformerConfig` expects this parameter. Verify that `hf_config` is not required by the parent class or is set elsewhere.
```python
if response.finish_reason == "failed":
    self.deactivate_worker_by_url(url)
    response.extra_info.pop("url", None)
response = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout * 2)
```
Copilot AI · Nov 21, 2025
The timeout multiplier in `timeout=self.config.rollout_timeout * 2` is a magic number. Consider making it configurable or adding a comment explaining why doubling the timeout is appropriate.
Suggested change:

```python
# The timeout is doubled to account for potential delays in worker response.
# Consider making the multiplier configurable in RolloutConfig if tuning is needed.
response = await asyncio.wait_for(
    response_ref,
    timeout=self.config.rollout_timeout * 2,  # Magic number: 2. See comment above.
)
```
```diff
 try:
     rollout_responses = await asyncio.wait_for(
-        asyncio.gather(*response_future), timeout=self.rollout_timeout
+        asyncio.gather(*response_future), timeout=self.rollout_timeout * 2
```
Copilot AI · Nov 21, 2025
The timeout multiplier in `timeout=self.rollout_timeout * 2` is a magic number. It should be made configurable or documented to explain why the timeout is doubled.
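Both timeout comments could be addressed with one configurable knob. A sketch assuming pydantic, with a hypothetical `rollout_timeout_multiplier` field (this is not the PR's actual config):

```python
from pydantic import BaseModel, Field


class RolloutConfig(BaseModel):  # illustrative subset, not the real RolloutConfig
    rollout_timeout: float = 300.0
    rollout_timeout_multiplier: float = Field(
        default=2.0,
        description="Extra headroom applied when awaiting worker responses.",
    )


# At the call sites, the magic number disappears:
# response = await asyncio.wait_for(
#     response_ref,
#     timeout=self.config.rollout_timeout * self.config.rollout_timeout_multiplier,
# )
```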
`xtuner/v1/ray/judger/compass_verifier_v2.py`

```python
hosts: list = [
    "10.103.12.31:12345",
    "10.103.12.31:12346",
    "10.103.12.31:12347",
    "10.103.12.31:12348",
    "10.103.12.31:12349",
    "10.103.12.31:12350",
    "10.103.12.31:12351",
    "10.103.12.31:12352",
]
```
Copilot AI · Nov 21, 2025
Hardcoded IP addresses and ports are exposed in the configuration. These should be moved to environment variables or a secure configuration file to avoid exposing internal infrastructure details in the codebase.
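A minimal sketch of reading the endpoints from an environment variable instead (the variable name `XTUNER_VERIFIER_HOSTS` is an assumption):

```python
import os

# e.g. export XTUNER_VERIFIER_HOSTS="10.103.12.31:12345,10.103.12.31:12346"
hosts: list = [
    h.strip()
    for h in os.environ.get("XTUNER_VERIFIER_HOSTS", "").split(",")
    if h.strip()
]
if not hosts:
    raise RuntimeError("XTUNER_VERIFIER_HOSTS is not set; no verifier endpoints available")
```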
```diff
@@ -0,0 +1,147 @@
+import re
```
Copilot AI · Nov 21, 2025
Import of `re` is not used.

Suggested change:

```diff
-import re
```
```python
import requests
import aiohttp
import asyncio
from typing import Any, Callable, List, Optional
```
Copilot AI · Nov 21, 2025
Imports of `Any`, `Callable`, and `Optional` are not used.

Suggested change:

```diff
-from typing import Any, Callable, List, Optional
+from typing import List
```
```python
from pydantic import BaseModel

from xtuner.v1.ray.judger.native import NativeJudger
```
Copilot AI · Nov 21, 2025
Import of `NativeJudger` is not used.

Suggested change:

```diff
-from xtuner.v1.ray.judger.native import NativeJudger
```
```python
@staticmethod
def from_str(state_str: str) -> "ReplayState":
    for state in ReplayState:
```
Copilot AI · Nov 21, 2025
This for-loop may attempt to iterate over a non-iterable instance of class type.
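For reference, if `ReplayState` is a standard `enum.Enum`, iterating over the class itself is valid Python, so this warning is likely a static-analysis false positive. A self-contained sketch of the presumed pattern (the member names are assumptions based on the states this PR mentions):

```python
from enum import Enum


class ReplayState(Enum):  # illustrative; mirrors the state names used in this PR
    COMPLETED = "completed"
    INTERRUPTED = "interrupted"
    EXPIRED = "expired"

    @staticmethod
    def from_str(state_str: str) -> "ReplayState":
        for state in ReplayState:  # Enum classes are iterable by design
            if state.value == state_str:
                return state
        raise ValueError(f"unknown replay state: {state_str}")


assert ReplayState.from_str("completed") is ReplayState.COMPLETED
```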