
Conversation

@YanhuiDua
Collaborator

No description provided.

@YanhuiDua YanhuiDua requested a review from Copilot November 21, 2025 12:16
Copilot finished reviewing on behalf of YanhuiDua November 21, 2025 12:19

Copilot AI left a comment


Pull Request Overview

This PR introduces partial rollout support with importance sampling (IS) for reinforcement learning training, enabling more efficient off-policy learning and handling of distribution mismatch between rollout and training policies.

Key changes:

  • Implements rollout importance sampling with configurable aggregation levels (token, sequence, geometric) and handling modes (truncate, mask)
  • Adds state management for partial rollouts with version tracking and replay buffer support for interrupted/completed/expired samples
  • Enhances rollout error handling with retry logic and proper state transitions (completed, interrupted, skipped, failed)
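
The aggregation levels and handling modes listed above can be sketched as follows. This is a hedged illustration of the general technique, not the PR's actual `rollout_is.py` code; the function name `compute_is_weights` and its signature are assumptions.

```python
import torch


def compute_is_weights(
    train_logprobs: torch.Tensor,    # log-probs under the training policy, shape (seq_len,)
    rollout_logprobs: torch.Tensor,  # log-probs under the rollout policy, shape (seq_len,)
    level: str = "token",            # "token" | "sequence" | "geometric"
    mode: str = "truncate",          # "truncate" | "mask"
    clip: float = 2.0,
) -> torch.Tensor:
    """Importance-sampling weights correcting rollout/training policy mismatch."""
    log_ratio = train_logprobs - rollout_logprobs
    if level == "token":
        weights = log_ratio.exp()                              # per-token ratio
    elif level == "sequence":
        weights = log_ratio.sum().exp().expand_as(log_ratio)   # product of token ratios
    elif level == "geometric":
        weights = log_ratio.mean().exp().expand_as(log_ratio)  # geometric mean of ratios
    else:
        raise ValueError(f"unknown aggregation level: {level!r}")
    if mode == "truncate":
        return weights.clamp(max=clip)  # cap large ratios, keep gradient signal
    # "mask": zero out tokens whose ratio exceeds the clip threshold entirely
    return torch.where(weights <= clip, weights, torch.zeros_like(weights))
```

Truncation preserves a (biased) learning signal for off-distribution tokens, while masking drops them; which is preferable depends on how severe the rollout/training mismatch is.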

Reviewed Changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 15 comments.

File Description
xtuner/v1/rl/base/rollout_is.py New module implementing importance sampling with comprehensive metrics and mismatch detection
xtuner/v1/rl/base/loss.py Adds rollout IS configuration and weights to loss context for training
xtuner/v1/rl/oreal/loss.py Integrates IS weights into OREAL policy loss computation
xtuner/v1/rl/grpo/loss.py Integrates IS weights into GRPO policy loss computation
xtuner/v1/rl/loss_fn.py Adds new CISPO loss function (has bug: duplicate function name)
xtuner/v1/rl/base/worker.py Adds IS weight computation and metrics logging in training worker
xtuner/v1/train/rl_trainer.py Adds memory usage logging and updates validation function names
xtuner/v1/ray/rollout/worker.py Implements retry logic and state-based error handling for rollout requests
xtuner/v1/ray/rollout/controller.py Adds worker statistics tracking and improved failure recovery
xtuner/v1/ray/rollout/sglang.py Supports partial rollout with num_return_tokens parameter
xtuner/v1/ray/rollout/lmdeploy.py Adds FP8 support and num_return_tokens handling
xtuner/v1/ray/dataflow/replay_buffer.py Refactored to support versioned states with interrupted/completed/expired sample management
xtuner/v1/ray/dataflow/flow.py Updated to handle partial rollout with state-based sample routing
xtuner/v1/ray/environment/single_turn_env.py Simplifies error handling with state-based validation
xtuner/v1/ray/evaluator.py Removes retry logic in favor of upstream error handling
xtuner/v1/ray/judger/compass_verifier_v2.py New judger implementation (has bugs: missing await, hardcoded IPs)
xtuner/v1/data_proto/rl_data.py Adds state field and update method for partial rollout support
xtuner/v1/data_proto/utils.py Adds tensor packing/unpacking utilities for IS computation
xtuner/v1/ray/config/worker.py Adds FP8 and retry configuration options
xtuner/v1/train/cli/rl.py Updates Ray initialization logic for local execution
xtuner/v1/model/dense/qwen2.py Adds bos_token_id field (has bug: missing default value)
xtuner/v1/utils/httpx_utils.py Removes unused rollout status helper function
Comments suppressed due to low confidence (1)

xtuner/v1/ray/dataflow/replay_buffer.py:390

  • Variable target_batch_size is not used.
        target_batch_size = min(global_batch_size, len(self._completed_actions))


Comment on lines 383 to 387
if self.workers_info[server_url].running_count % 100 == 0:
log_msg = ""
for _, info in self.workers_info.items():
log_msg += f"rank {info.rank} worker info: {info}"
self.logger.info(log_msg)

Copilot AI Nov 21, 2025


The condition if self.workers_info[server_url].running_count % 100 == 0: fires once per 100 requests on each worker, and each firing dumps the info of every worker, so log volume grows with both request rate and worker count. Consider a longer interval or a time-based throttle to avoid excessive log output.
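
A time-based throttle is one way to bound log volume independently of request rate. The sketch below is a hypothetical helper, not code from the PR; the class name and interface are assumptions.

```python
import time


class ThrottledLogger:
    """Emit at most one log line per `min_interval` seconds, regardless of call rate."""

    def __init__(self, logger, min_interval: float = 60.0):
        self.logger = logger
        self.min_interval = min_interval
        self._last_emit = float("-inf")  # so the first call always logs

    def info(self, msg: str) -> None:
        now = time.monotonic()
        if now - self._last_emit >= self.min_interval:
            self._last_emit = now
            self.logger.info(msg)
```

With this, the worker-stats dump could be called unconditionally on every request and still produce at most one line per minute.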

return loss

@register_policy_loss("cispo")
def pg_loss_fn(

Copilot AI Nov 21, 2025


The duplicate function name pg_loss_fn will cause the first definition (lines 66-84) to be overwritten by the second definition (lines 87-105). The second function should have a different name, likely cispo_loss_fn based on the decorator.

Suggested change
def pg_loss_fn(
def cispo_loss_fn(


config = cls(
hf_config=hf_config,
vocab_size=hf_config.vocab_size,

Copilot AI Nov 21, 2025


The removed line hf_config=hf_config, might break the initialization if the parent class TransformerConfig expects this parameter. Verify that hf_config is not required by the parent class or is set elsewhere.

if response.finish_reason == "failed":
self.deactivate_worker_by_url(url)
response.extra_info.pop("url", None)
response = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout * 2)

Copilot AI Nov 21, 2025


The timeout multiplier timeout=self.config.rollout_timeout * 2 is a magic number. Consider making this configurable or adding a comment explaining why doubling the timeout is appropriate.

Suggested change
response = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout * 2)
# The timeout is doubled to account for potential delays in worker response.
# Consider making the multiplier configurable in RolloutConfig if tuning is needed.
response = await asyncio.wait_for(
response_ref,
timeout=self.config.rollout_timeout * 2 # Magic number: 2. See comment above.
)
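
Making the multiplier a config field keeps the retry timeout tunable without touching call sites. The fragment below is a sketch only; the field names are assumptions and do not reflect the PR's actual RolloutConfig.

```python
from dataclasses import dataclass


@dataclass
class RolloutRetryConfig:
    """Hypothetical config fragment for retry timeouts (names are assumptions)."""

    rollout_timeout: float = 300.0
    # Multiplier applied when re-awaiting a response after a worker failure,
    # to absorb extra scheduling and migration delay on the retry path.
    retry_timeout_multiplier: float = 2.0

    @property
    def retry_timeout(self) -> float:
        return self.rollout_timeout * self.retry_timeout_multiplier
```

The call site then becomes `asyncio.wait_for(response_ref, timeout=config.retry_timeout)`, with the rationale documented once on the field instead of at every usage.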

try:
rollout_responses = await asyncio.wait_for(
asyncio.gather(*response_future), timeout=self.rollout_timeout
asyncio.gather(*response_future), timeout=self.rollout_timeout * 2

Copilot AI Nov 21, 2025


The timeout multiplier timeout=self.rollout_timeout * 2 is a magic number. This should be made configurable or documented to explain why the timeout is doubled.

Comment on lines +130 to +139
hosts: list = [
"10.103.12.31:12345",
"10.103.12.31:12346",
"10.103.12.31:12347",
"10.103.12.31:12348",
"10.103.12.31:12349",
"10.103.12.31:12350",
"10.103.12.31:12351",
"10.103.12.31:12352",
]

Copilot AI Nov 21, 2025


Hardcoded IP addresses and ports are exposed in the configuration. These should be moved to environment variables or a secure configuration file to avoid exposing internal infrastructure details in the codebase.
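
One conventional fix is to read the host list from an environment variable and fail loudly when it is absent. This is a sketch of that pattern; the variable name COMPASS_VERIFIER_HOSTS and the function are hypothetical, not part of the PR.

```python
import os


def load_judger_hosts(env_var: str = "COMPASS_VERIFIER_HOSTS") -> list:
    """Read 'host:port' entries from a comma-separated environment variable."""
    raw = os.environ.get(env_var, "")
    hosts = [h.strip() for h in raw.split(",") if h.strip()]
    if not hosts:
        # Refuse to fall back to hardcoded IPs: misconfiguration should fail fast.
        raise RuntimeError(f"{env_var} is not set; no judger hosts available")
    return hosts
```

Deployment then sets e.g. `COMPASS_VERIFIER_HOSTS=10.0.0.1:12345,10.0.0.2:12345` outside the codebase, keeping infrastructure details out of version control.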

@@ -0,0 +1,147 @@
import re

Copilot AI Nov 21, 2025


Import of 're' is not used.

Suggested change
import re

import requests
import aiohttp
import asyncio
from typing import Any, Callable, List, Optional

Copilot AI Nov 21, 2025


Import of 'Any' is not used.
Import of 'Callable' is not used.
Import of 'Optional' is not used.

Suggested change
from typing import Any, Callable, List, Optional
from typing import List


from pydantic import BaseModel

from xtuner.v1.ray.judger.native import NativeJudger

Copilot AI Nov 21, 2025


Import of 'NativeJudger' is not used.

Suggested change
from xtuner.v1.ray.judger.native import NativeJudger


@staticmethod
def from_str(state_str: str) -> "ReplayState":
for state in ReplayState:

Copilot AI Nov 21, 2025


This for-loop may attempt to iterate over a non-iterable instance of class type.
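
If ReplayState is an Enum subclass, this warning is likely a false positive: the Enum metaclass defines `__iter__`, so iterating over the class is valid Python. A minimal sketch of such a from_str (the member names below are illustrative assumptions, not necessarily the PR's actual states):

```python
from enum import Enum


class ReplayState(Enum):
    """Hypothetical subset of replay states, for illustration only."""

    COMPLETED = "completed"
    INTERRUPTED = "interrupted"
    EXPIRED = "expired"

    @staticmethod
    def from_str(state_str: str) -> "ReplayState":
        # Iterating over an Enum class is supported: EnumMeta implements __iter__.
        for state in ReplayState:
            if state.value == state_str:
                return state
        raise ValueError(f"unknown replay state: {state_str!r}")
```

For string-valued enums, `ReplayState(state_str)` achieves the same lookup in one call and raises ValueError on unknown input.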

@YanhuiDua YanhuiDua force-pushed the heapq_merge_is_optim branch from 64bfaa9 to 1750dba Compare November 25, 2025 03:16
@YanhuiDua YanhuiDua force-pushed the heapq_merge_is_optim branch from d4939ec to 0b4e057 Compare November 27, 2025 11:57