33 commits
c7efa17
mv infer engine error log to infer_engine_error.log
YanhuiDua Nov 5, 2025
75af333
fix log
YanhuiDua Nov 5, 2025
62094f2
feat: support rollout importance sampling helper from verl
RangiLyu Oct 30, 2025
83db4bc
rebase main
RangiLyu Nov 12, 2025
7ed4b42
fix lint
RangiLyu Nov 18, 2025
cb5f810
revert some change
RangiLyu Nov 18, 2025
503c055
fix lint
RangiLyu Nov 18, 2025
9f1853a
[Feat] support step-off-1 partial rollout
YanhuiDua Nov 19, 2025
e99ec62
Merge commit 'refs/pull/1197/head' of https://github.com/InternLM/xtu…
YanhuiDua Nov 19, 2025
516f105
[Feat] support step-off-1 partial rollout
YanhuiDua Nov 19, 2025
bda42f1
Merge branch 'heapq' of https://github.com/YanhuiDua/xtuner into heapq
YanhuiDua Nov 19, 2025
df837bc
add is loss_fn and compass judger
YanhuiDua Nov 19, 2025
715b339
fix when len of group dataitem is 0
YanhuiDua Nov 19, 2025
f250b2f
Merge branch 'heapq' of https://github.com/YanhuiDua/xtuner into heapq
YanhuiDua Nov 19, 2025
2b22447
refactor replaybuffer storage and fix filter
YanhuiDua Nov 19, 2025
d2b97ec
fix abort and failed order and add controller log info
YanhuiDua Nov 20, 2025
2433900
refactor rollout worker to handle exception
YanhuiDua Nov 20, 2025
2b2e339
fix
YanhuiDua Nov 21, 2025
6cd32f9
restart rollout worker, add memory info, and stop all dataflow task w…
YanhuiDua Nov 21, 2025
f40479c
check entropy
YanhuiDua Nov 21, 2025
4578f4e
Revert Merge commit 'refs/pull/1197/head' of https://github.com/Inter…
YanhuiDua Nov 21, 2025
bafd4b1
fix max_tokens and add more log info
YanhuiDua Nov 24, 2025
524bd62
fix relaunch server error
YanhuiDua Nov 24, 2025
ea52a9e
fix log level
YanhuiDua Nov 24, 2025
1750dba
fix abort state
YanhuiDua Nov 25, 2025
b59d7cb
fix update url error
YanhuiDua Nov 25, 2025
3e16b3b
Merge branch 'main' of https://github.com/YanhuiDua/xtuner into heapq…
YanhuiDua Nov 25, 2025
7eb7a65
Merge branch 'main' of https://github.com/InternLM/xtuner into heapq_…
YanhuiDua Nov 25, 2025
dd04037
Merge branch 'main' of https://github.com/InternLM/xtuner into heapq_…
YanhuiDua Nov 25, 2025
cdbd3fe
fix state for response_length to max_tokens
YanhuiDua Nov 25, 2025
0b4e057
fix pause
YanhuiDua Nov 27, 2025
90a6778
add more log info
YanhuiDua Nov 28, 2025
2616618
change send data concurrency
YanhuiDua Dec 3, 2025
220 changes: 198 additions & 22 deletions xtuner/v1/data_proto/rl_data.py
@@ -1,9 +1,14 @@
from typing import Any, Dict, List, Optional, TypedDict
import copy
from typing import Any, Dict, List, Literal, Optional, TypedDict

from cyclopts import Parameter
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated

from xtuner.v1.utils.logger import get_logger


logger = get_logger()

# ====================================
# ====== DataFlow (data flow) ==========
@@ -26,7 +31,7 @@ class RLUIDItem(BaseModel):
root_id: int = -1
action_id: int = -1
observation_id: int = -1
version: int = -1
version: int = 0


class RLDatasetItem(BaseModel):
@@ -67,11 +72,137 @@ class RLRolloutResponseItem(BaseModel):

model_config = ConfigDict(extra="forbid")
response: Optional[str] = None
async_response: Optional[List[Any]] = None
response_ids: Optional[List[int]] = None
num_return_tokens: Optional[int] = None
finish_reason: Optional[str] = None # "stop", "length", "abort", "failed", "skipped"
async_response_ids: Optional[List[Any]] = None
logprobs: Optional[List[float]] = None
async_logprobs: Optional[List[Any]] = None
num_return_tokens: int = 0
finish_reason: Optional[str] = None # "stop", "length", "abort", "failed", "skipped"
extra_info: Dict[str, Any] = dict()
state: Literal["init", "completed", "interrupted", "skipped", "failed"] = "init"

# def update(self, other: "RLRolloutResponseItem") -> None:
# """Updates another RLRolloutResponseItem into this one for partial
# rollout."""
# if not isinstance(other, RLRolloutResponseItem):
# raise TypeError("Can only update with another RLRolloutResponseItem instance.")

# logger.info("call update RLRolloutResponseItem function")
# init_response_ids_len = 0
# if self.response_ids is not None:
# init_response_ids_len = len(self.response_ids)
# if other.response_ids is not None:
# self.response_ids.extend(other.response_ids)
# else:
# self.response_ids = self.response_ids
# else:
# self.response_ids = other.response_ids

# init_logprobs_len = 0
# if self.logprobs is not None:
# init_logprobs_len = len(self.logprobs)
# if other.logprobs is not None:
# self.logprobs.extend(other.logprobs)
# else:
# self.logprobs = self.logprobs
# else:
# self.logprobs = other.logprobs

# init_response_len = 0
# if self.response is not None:
# init_response_len = len(self.response)
# if other.response is not None and len(other.response) > 0:
# self.response += other.response
# else:
# self.response = self.response
# else:
# self.response = other.response

# logger.info(
# f"Updated response_ids from {init_response_ids_len} to {len(self.response_ids)}, logprobs from {init_logprobs_len} to {len(self.logprobs)}. response from {init_response_len} to {len(self.response)}."
# )
# self.num_return_tokens = len(self.response_ids)
# self.finish_reason = other.finish_reason
# self.extra_info.update(other.extra_info)
# self.state = other.state

def update(self, other: "RLRolloutResponseItem") -> None:
"""Merges another RLRolloutResponseItem into this one for partial
rollout, extending the concatenated fields and recording per-segment
history in the async_* fields."""
if not isinstance(other, RLRolloutResponseItem):
raise TypeError("Can only update with another RLRolloutResponseItem instance.")

# The async_* accumulators default to None; normalize them to lists so the
# bookkeeping below can append to and iterate over them safely.
self.async_response_ids = self.async_response_ids or []
self.async_logprobs = self.async_logprobs or []
self.async_response = self.async_response or []

if self.response_ids is not None:
init_response_ids = copy.deepcopy(self.response_ids)
other_response_ids = copy.deepcopy(other.response_ids)
init_async_response_ids = copy.deepcopy(self.async_response_ids)
if other.response_ids is not None:
self.async_response_ids.append(other_response_ids.copy())
self.response_ids.extend(other_response_ids.copy())
logger.debug(
f"update response_ids from {init_response_ids} with {other_response_ids} to {self.response_ids}, async_response_ids from {init_async_response_ids} to {self.async_response_ids}."
)
else:
if other.response_ids is not None:
other_response_ids = copy.deepcopy(other.response_ids)
self.response_ids = other_response_ids
self.async_response_ids = [other_response_ids.copy()]
else:
self.async_response_ids = []

if self.logprobs is not None:
other_logprobs = copy.deepcopy(other.logprobs)
if other.logprobs is not None:
self.async_logprobs.append(other_logprobs.copy())
self.logprobs.extend(other_logprobs.copy())
else:
if other.logprobs is not None:
other_logprobs = copy.deepcopy(other.logprobs)
self.async_logprobs = [other_logprobs.copy()]
self.logprobs = other_logprobs
else:
self.async_logprobs = []

if self.response is not None:
init_response = copy.deepcopy(self.response)
other_response = copy.deepcopy(other.response)
if other.response is not None:
self.response += other_response
self.async_response.append(other_response)
logger.debug(
f"update response from {repr(init_response)} with {repr(other_response)} to {repr(self.response)}, async_response: {self.async_response}."
)
else:
if other.response is not None:
self.response = other.response
self.async_response = [other.response]
else:
self.async_response = []

response_ids_lens = []
for response_ids in self.async_response_ids:
response_ids_lens.append(len(response_ids))
logprobs_lens = []
for logprobs in self.async_logprobs:
logprobs_lens.append(len(logprobs))
response_lens = []
for response in self.async_response:
response_lens.append(len(response))
logger.debug(
f"update response_ids lengths: {response_ids_lens}, logprobs lengths: {logprobs_lens}, response lengths: {response_lens}."
)

if self.response_ids is not None:
assert sum(response_ids_lens) == len(self.response_ids), "response_ids length mismatch after update."
if self.logprobs is not None:
assert sum(logprobs_lens) == len(self.logprobs), "logprobs length mismatch after update."
if self.response is not None:
assert sum(response_lens) == len(self.response), "response length mismatch after update."
self.num_return_tokens = sum(response_ids_lens)
self.finish_reason = other.finish_reason
self.extra_info.update(other.extra_info)
self.state = other.state
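A minimal sketch of how `update` accumulates partial-rollout segments, using the classes defined in this file (the field values below are made up; it assumes a fresh `RLRolloutResponseItem` accumulator whose first `update` call seeds the `async_*` histories, as the code above does):

```python
# Illustration only: merge two generated segments into one accumulator.
acc = RLRolloutResponseItem()  # fresh item, as stored before the first segment arrives

seg1 = RLRolloutResponseItem(
    response="Hello", response_ids=[1, 2], logprobs=[-0.1, -0.2],
    finish_reason="abort", state="interrupted",  # rollout paused mid-sample
)
seg2 = RLRolloutResponseItem(
    response=" world", response_ids=[3], logprobs=[-0.3],
    finish_reason="stop", state="completed",  # rollout finished on resume
)

acc.update(seg1)
acc.update(seg2)

assert acc.response == "Hello world"            # concatenated across segments
assert acc.response_ids == [1, 2, 3]
assert acc.async_response_ids == [[1, 2], [3]]  # per-segment history
assert acc.num_return_tokens == 3
assert acc.state == "completed"                 # tracks the latest segment
```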


class RLJudgerResponseItem(BaseModel):
@@ -124,6 +255,7 @@ class RLExtraDataItem(BaseModel):

model_config = ConfigDict(extra="forbid")
retry_times: int = 0
state: str = ""
extra_info: Dict[str, Any] = dict()


@@ -147,28 +279,67 @@ class RLDataFlowItem(BaseModel):
extra_info: RLExtraDataItem = RLExtraDataItem()


def check_dataflow_item(group_data_items):
if not group_data_items or len(group_data_items) == 0:
return False
def check_valid_dataflow_item(group_data_items: List[RLDataFlowItem]) -> bool:
"""Validates a group of RLDataFlowItem objects based on their state and
data integrity.

# If any item is in the abort state, skip the check; the sample will be rolled out again next time
is_abort = any(item.env.rollout.finish_reason == "abort" for item in group_data_items)
is_skipped = any(item.env.rollout.finish_reason == "skipped" for item in group_data_items)
if is_abort or is_skipped:
return True
The following checks are applied to every item in the group:
1. State: a rollout state of 'skipped' or 'failed' invalidates the group.
2. Data integrity: unless the item was 'interrupted', at least one of `response` or `response_ids` must be non-empty.
3. Consistency: when both `response_ids` and `logprobs` are present, their lengths must match.

no_failures = all(item.env.rollout.finish_reason != "failed" for item in group_data_items)
if not no_failures:
return False
Args:
group_data_items: A list of RLDataFlowItem to be checked.

no_judger_failures = all(item.env.judger.extra_info.get("state", "") != "failed" for item in group_data_items)
if not no_judger_failures:
return False
Returns:
bool: True if every item in the group passes all checks, False otherwise.
"""
for item in group_data_items:
rollout_info = item.env.rollout
response_valid = rollout_info.response is not None and len(rollout_info.response) > 0
ids_valid = rollout_info.response_ids is not None and len(rollout_info.response_ids) > 0
logprobs_valid = rollout_info.logprobs is not None and len(rollout_info.logprobs) > 0
if item.env.rollout.state in ["skipped", "failed"]:
logger.info(f"Invalid dataflow item found: rollout state is {item.env.rollout.state}. UID: {item.uid}")
return False
if not response_valid and not ids_valid and item.env.rollout.state != "interrupted":
logger.info(
f"Invalid dataflow item found: no response or response_ids. UID:{item.uid.action_id} with rollout response {item.env.rollout}"
)
return False
if ids_valid and logprobs_valid and len(rollout_info.logprobs) != len(rollout_info.response_ids): # type: ignore[arg-type]
logger.info(f"Invalid dataflow item found: logprobs and response_ids length mismatch. UID: {item.uid}")
return False
return True
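A short sketch of the validator's behavior (it assumes `RLDataFlowItem()` is default-constructible with a populated `env.rollout`, as the docstring example further below also assumes):

```python
# Illustration only: one failed item invalidates the whole group.
ok = RLDataFlowItem()
ok.env.rollout.response_ids = [1, 2, 3]
ok.env.rollout.logprobs = [-0.1, -0.2, -0.3]
ok.env.rollout.state = "completed"

bad = RLDataFlowItem()
bad.env.rollout.state = "failed"

assert check_valid_dataflow_item([ok]) is True
assert check_valid_dataflow_item([ok, bad]) is False  # one failed item rejects the group
```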


def update_rollout_item(group_data_items, target_value):
"""Update a list of RLDataFlowItem objects by merging another
RLRolloutResponseItem into each item's env.rollout attribute.

all_responses_valid = all(item.env.rollout.response for item in group_data_items)
all_ids_valid = all(item.env.rollout.response_ids for item in group_data_items)
Args:
group_data_items (List[RLDataFlowItem]): List of data items to update.
target_value (List[RLRolloutResponseItem]): Rollout response items, one per data item, merged element-wise into each item's env.rollout.

return all_responses_valid or all_ids_valid
Returns:
List[RLDataFlowItem]: The updated list of data items.

Example:
>>> # Suppose you want to update the rollout response for each item
>>> items = [RLDataFlowItem(), RLDataFlowItem()]
>>> rollout_response = RLRolloutResponseItem(response="new response", response_ids=[1,2,3])
>>> update_rollout_item(items, rollout_response)
# Now each item's env.rollout has been updated with the new response and response_ids
"""

for idx, item in enumerate(group_data_items):
item.env.rollout.update(target_value[idx])

return group_data_items


def update_dataflow_item(group_data_items, target_key, target_value):
@@ -200,7 +371,12 @@ def update_dataflow_item(group_data_items, target_key, target_value):
parent_obj = group_data_items[i]
for key in keys[:-1]:
parent_obj = getattr(parent_obj, key)
setattr(parent_obj, keys[-1], target_value[i])

if keys[-1] == "rollout":
existing_rollout_item = getattr(parent_obj, keys[-1])
existing_rollout_item.update(target_value[i])
else:
setattr(parent_obj, keys[-1], target_value[i])

return group_data_items
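A sketch of the new merge-versus-replace behavior (it assumes dotted-path keys such as "env.rollout", consistent with the `keys[:-1]` traversal above; values are illustrative):

```python
# Illustration only: "rollout" values are merged via update(), others replaced.
items = [RLDataFlowItem()]
seg = [RLRolloutResponseItem(response_ids=[7, 8], finish_reason="abort", state="interrupted")]

update_dataflow_item(items, "env.rollout", seg)             # merged (partial rollout)
update_dataflow_item(items, "extra_info.retry_times", [1])  # plain setattr replacement

assert items[0].env.rollout.response_ids == [7, 8]
assert items[0].env.rollout.async_response_ids == [[7, 8]]
assert items[0].extra_info.retry_times == 1
```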

2 changes: 1 addition & 1 deletion xtuner/v1/model/dense/qwen2.py
@@ -30,6 +30,7 @@ def to_hf_key_list(self, key: str) -> list[str]:

class Qwen2DenseConfig(TransformerConfig):
use_sliding_window: bool = False
bos_token_id: int

def build(self) -> Qwen2Dense:
return Qwen2Dense(self)
@@ -44,7 +45,6 @@ def from_hf(cls, hf_path: str | Path) -> Self:
assert isinstance(hf_config, HFConfig)

config = cls(
hf_config=hf_config,
vocab_size=hf_config.vocab_size,
max_position_embeddings=hf_config.max_position_embeddings,
pad_token_id=getattr(hf_config, "pad_token_id"),
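Since `bos_token_id` is now a required (non-defaulted) field on `Qwen2DenseConfig`, any constructor path has to supply it explicitly. A hedged sketch of what that could look like in a `from_hf`-style call (not the PR's actual code; `hf_config` stands for a loaded HF config and other required fields are omitted):

```python
# Sketch only, under the assumptions stated above.
config = Qwen2DenseConfig(
    vocab_size=hf_config.vocab_size,
    max_position_embeddings=hf_config.max_position_embeddings,
    pad_token_id=getattr(hf_config, "pad_token_id"),
    bos_token_id=getattr(hf_config, "bos_token_id"),  # assumption: sourced from the HF config
)
```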
7 changes: 7 additions & 0 deletions xtuner/v1/ray/config/worker.py
@@ -212,6 +212,13 @@ class RolloutConfig(BaseModel):
help="Maximum number of retries per rollout worker before deactivation.",
),
] = None
max_retry_per_sample: Annotated[
int,
Parameter(
group=infer_group,
help="Maximum number of retries per sample before marking it as failed.",
),
] = 1
worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"

def model_post_init(self, __context: Any) -> None:
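A hypothetical sketch of how a dataflow loop could consume `max_retry_per_sample` (`rollout_once` and `sample` are stand-ins for illustration, not APIs from this PR):

```python
# Illustration only: retry a sample up to cfg.max_retry_per_sample times.
def run_with_retries(cfg: RolloutConfig, sample):
    item = None
    for _ in range(cfg.max_retry_per_sample + 1):
        item = rollout_once(sample)  # stand-in for a single rollout attempt
        if item.state not in ("failed", "skipped"):
            return item
    item.state = "failed"  # retries exhausted: mark the sample as failed
    return item
```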