Skip to content

Commit 1750dba

Browse files
committed
fix abort state
1 parent ea52a9e commit 1750dba

File tree

10 files changed

+162
-51
lines changed

10 files changed

+162
-51
lines changed

xtuner/v1/data_proto/rl_data.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from cyclopts import Parameter
44
from pydantic import BaseModel, ConfigDict, Field
55
from typing_extensions import Annotated
6+
from xtuner.v1.utils.logger import get_logger
67

8+
logger = get_logger()
79

810
# ====================================
911
# ====== DataFlow 数据流 ==============
@@ -84,10 +86,12 @@ def update(self, other: "RLRolloutResponseItem") -> None:
8486
self.response_ids.extend(other.response_ids)
8587
else:
8688
self.response_ids = other.response_ids
89+
8790
if self.logprobs is not None and other.logprobs:
8891
self.logprobs.extend(other.logprobs)
8992
else:
9093
self.logprobs = other.logprobs
94+
9195
if self.response is not None and other.response:
9296
self.response += other.response
9397
else:
@@ -197,10 +201,13 @@ def check_valid_dataflow_item(group_data_items: List[RLDataFlowItem]) -> bool:
197201
ids_valid = bool(rollout_info.response_ids)
198202
logprobs_valid = bool(rollout_info.logprobs)
199203
if item.env.rollout.state in ["skipped", "failed"]:
204+
logger.info(f"Invalid dataflow item found: rollout state is {item.env.rollout.state}. UID: {item.uid}")
200205
return False
201-
if not response_valid and not ids_valid:
206+
if not response_valid and not ids_valid and item.env.rollout.state != "interrupted":
207+
logger.info(f"Invalid dataflow item found: no response or response_ids. UID:{item.data.uid} with rollout response {item.env.rollout}")
202208
return False
203209
if ids_valid and logprobs_valid and len(rollout_info.logprobs) != len(rollout_info.response_ids): # type: ignore[arg-type]
210+
logger.info(f"Invalid dataflow item found: logprobs and response_ids length mismatch. UID: {item.uid}")
204211
return False
205212
return True
206213

xtuner/v1/ray/dataflow/flow.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def __init__(
122122
self.finished_samples_count = 0
123123
self.failed_samples_count = 0
124124
self.skipped_sample_count = 0
125+
self.filtered_sample_count = 0
125126
self.sample_from_expired_storage = False
126127
self.logger = get_logger(log_dir=self.config.worker_log_dir, tag="DataFlow")
127128
self.target_batch_size = self.config.global_batch_size
@@ -161,6 +162,7 @@ def _reset_internal_states_on_step(
161162
self.finished_samples_count = 0
162163
self.failed_samples_count = 0
163164
self.skipped_sample_count = 0
165+
self.filtered_sample_count = 0
164166
self.logger.info(
165167
f"global_batch_size: {global_batch_size}, sample_params: {sample_params}, extra_params: {extra_params}"
166168
)
@@ -224,6 +226,8 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
224226
group_data_items = await self.replay_buffer.post_processor.remote(group_data_items) # type: ignore[attr-defined]
225227
if len(group_data_items) > 0:
226228
await self.replay_buffer.add.remote(group_data_items) # type: ignore[attr-defined]
229+
else:
230+
self.filtered_sample_count += 1
227231
self.logger.debug(f"Worker task completed successfully for {action_id}.")
228232
elif group_state == "interrupted":
229233
await self.replay_buffer.add.remote(group_data_items) # type: ignore[attr-defined]
@@ -289,17 +293,17 @@ async def concurrent_task_runner(self):
289293
cleanup_start_time = time.monotonic()
290294
cleanup_timeout = 10 * 60 # 10 minutes in seconds
291295
while len(waiting_tasks) > 0:
292-
elapsed_time = time.monotonic() - cleanup_start_time
293-
if elapsed_time > cleanup_timeout:
294-
self.logger.warning(
295-
f"Cleanup timeout of {cleanup_timeout}s reached. "
296-
f"Forcefully cancelling {len(waiting_tasks)} remaining tasks."
297-
)
298-
for task in waiting_tasks:
299-
task.cancel()
300-
# Wait for cancellations to complete
301-
await asyncio.gather(*waiting_tasks, return_exceptions=True)
302-
break # Exit the cleanup loop
296+
# elapsed_time = time.monotonic() - cleanup_start_time
297+
# if elapsed_time > cleanup_timeout:
298+
# self.logger.warning(
299+
# f"Cleanup timeout of {cleanup_timeout}s reached. "
300+
# f"Forcefully cancelling {len(waiting_tasks)} remaining tasks."
301+
# )
302+
# for task in waiting_tasks:
303+
# task.cancel()
304+
# # Wait for cancellations to complete
305+
# await asyncio.gather(*waiting_tasks, return_exceptions=True)
306+
# break # Exit the cleanup loop
303307
done_tasks, pending_tasks = await asyncio.wait(
304308
waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED
305309
)
@@ -309,7 +313,8 @@ async def concurrent_task_runner(self):
309313
self.logger.info("All worker tasks have completed after pausing env controller.")
310314

311315
self.logging_replaybuffer_state()
312-
316+
self.logger.info(ray.get(self.env_controller.get_rollout_stats.remote()))
317+
313318
async def pause(self, timeout: float = 60.0):
314319
"""Asynchronously sends abort requests to all rollout workers."""
315320
rollout_info = ray.get(self.env_controller.get_rollout_info.remote()) # type: ignore[attr-defined]
@@ -378,6 +383,7 @@ def logging_replaybuffer_state(self, logging_msg: Optional[str] = None):
378383
status = self.get_replaybuffer_status()
379384
logging_msg = logging_msg if logging_msg else ""
380385
logging_msg += f"ReplayBuffer Status: {status}"
386+
logging_msg += f", Filtered samples count: {self.filtered_sample_count}"
381387
self.logger.info(logging_msg)
382388

383389
def get_replaybuffer_status(self):

xtuner/v1/ray/dataflow/replay_buffer.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,9 @@ def add(self, grouped_dataitem: List[RLDataFlowItem], partial_rollout_step: int
336336

337337
# 1. 跟prompt相关的action_id记录
338338
if root_id in self._root2actions:
339+
# TODO: version 更新需要 根据是否update_weights来判断,需要考虑到非共卡的情况
339340
replay_meta.version += 1
340-
self.logger.debug(f"Existing root_id: {root_id} found. Incrementing version to {replay_meta.version}.")
341+
self.logger.info(f"Existing root_id: {root_id} found. Incrementing version to {replay_meta.version}.")
341342
self._root2actions[root_id].append(action_id)
342343
else:
343344
self._root2actions[root_id] = [action_id]
@@ -346,14 +347,14 @@ def add(self, grouped_dataitem: List[RLDataFlowItem], partial_rollout_step: int
346347
# 2. 根据rollout状态加到finished, abort, abort_over_version队列中;Partial rollout is handled based on whether finish_reason is "abort".
347348
if replay_meta.state == ReplayState.INTERRUPTED and replay_meta.version < partial_rollout_step:
348349
self._interrupted_actions[replay_meta.version].append(action_id)
349-
self.logger.debug(
350+
self.logger.info(
350351
f"Add aborted sample with root_id: {root_id}, action_id: {action_id} to _interrupted_actions."
351352
)
352353
elif replay_meta.state == ReplayState.INTERRUPTED and replay_meta.version >= partial_rollout_step:
353354
self._expired_actions.append(action_id)
354355
replay_meta.version = 0
355356
replay_meta.state = ReplayState.EXPIRED
356-
self.logger.debug(
357+
self.logger.info(
357358
f"Action_id: {action_id} has exceeded partial_rollout_step {partial_rollout_step}. Add this sample with root_id: {root_id} to _expired_actions list."
358359
)
359360
elif replay_meta.state == ReplayState.COMPLETED:
@@ -386,8 +387,9 @@ def get(self, global_batch_size: int) -> Tuple[List[List[RLDataFlowItem]], List[
386387
"""
387388
samples = []
388389
multimodal_train_infos = []
389-
target_batch_size = min(global_batch_size, len(self._completed_actions))
390-
for _ in range(global_batch_size):
390+
target_batch_size = min(global_batch_size, self.get_completed_samples())
391+
self.logger.info(f"Retrieving {target_batch_size} completed samples from the replay buffer.")
392+
for _ in range(target_batch_size):
391393
action_id = self._pop_highest_version_action(self._completed_actions)
392394
replay_meta = self._actions[action_id] # type: ignore[index]
393395
group_samples = mapping_replaymeta_to_dataitem(replay_meta)
@@ -563,16 +565,17 @@ def sample_from_interrupted_storage(self, tokenizer) -> List[RLDataFlowItem]:
563565
sample.uid.version = replay_meta.version
564566
sample.extra_info.state = str(ReplayState.INIT)
565567
if sample.env.rollout.response_ids and sample.data.input_ids:
568+
# TODO: response_ids 累加
566569
if "train_prompt_ids" in sample.data.extra_info:
567570
sample.data.input_ids = (
568571
sample.data.extra_info["train_prompt_ids"] + sample.env.rollout.response_ids
569572
)
570573
else:
571574
sample.data.input_ids.extend(sample.env.rollout.response_ids)
572-
# elif sample.env.rollout.response:
573-
# sample.data.input_ids.extend(tokenizer.encode(sample.env.rollout.response, add_special_tokens=False))
575+
elif sample.env.rollout.response:
576+
sample.data.input_ids.extend(tokenizer.encode(sample.env.rollout.response, add_special_tokens=False))
574577
self.logger.info(
575-
f"Sampling interrupted action_id: {action_id} from replay buffer, remain interrupted samples: {len(self._interrupted_actions)}"
578+
f"Sampling interrupted action_id: {action_id} from replay buffer, remain interrupted samples: {self.get_interrupted_samples()}"
576579
)
577580
return group_samples
578581

xtuner/v1/ray/environment/base_env.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,11 @@ def check_active_workers(self, block=True):
212212
block (bool): Whether to block until the operation completes.
213213
"""
214214
return self._call_rollout_func("check_active_workers", block)
215+
216+
def get_rollout_stats(self, block=True):
217+
"""Gets statistics from the rollout workers.
218+
219+
Forwards the request to ``_call_rollout_func`` under the name
``"get_rollout_stats"`` and returns whatever that helper returns.

Args:
220+
block (bool): Whether to block until the operation completes.
Passed through unchanged to ``_call_rollout_func``.

Returns:
The result of ``_call_rollout_func("get_rollout_stats", block)``.
NOTE(review): presumably the formatted per-worker stats string
produced by the rollout controller when ``block`` is true --
confirm against ``_call_rollout_func``.
221+
"""
222+
return self._call_rollout_func("get_rollout_stats", block)

xtuner/v1/ray/environment/single_turn_env.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ async def generate(
8787
sample.data.extra_info["action_id"] = sample.uid.action_id
8888
if sample.env.rollout.num_return_tokens > 0:
8989
sample.data.extra_info["num_return_tokens"] = sample.env.rollout.num_return_tokens
90-
self.logger.debug(
90+
self.logger.info(
9191
f"Set num_return_tokens: {sample.env.rollout.num_return_tokens} for sample {sample.uid}."
9292
)
9393
fut = self.rollout_controller.rollout.remote(
@@ -136,7 +136,7 @@ async def run(
136136
if self.judger_controller and continue_judger:
137137
try:
138138
judger_responses: List[RLJudgerResponseItem] = await asyncio.wait_for(
139-
self.judger_controller.run.remote(group_data_items), timeout=self.judger_timeout
139+
self.judger_controller.run.remote(group_data_items), timeout=self.judger_timeout * 2
140140
)
141141
except asyncio.TimeoutError:
142142
self.logger.error("Get judger controller response timeout and return the failed response.")

xtuner/v1/ray/rollout/controller.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -390,26 +390,41 @@ async def rollout(
390390
format=format,
391391
extra_info=extra_info,
392392
)
393-
if self.workers_info[server_url].running_count % 100 == 0:
394-
log_msg = ""
395-
for _, info in self.workers_info.items():
396-
log_msg += f"rank {info.rank} worker info: {info}"
397-
self.logger.info(log_msg)
393+
# if self.workers_info[server_url].running_count % 100 == 0:
394+
# log_msg = ""
395+
# for _, info in self.workers_info.items():
396+
# log_msg += f"rank {info.rank} worker info: {info}"
397+
# self.logger.info(log_msg)
398398
try:
399399
response = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout * 2)
400-
self.workers_info[server_url].running_count -= 1
401400
self.workers_info[server_url].success_count += 1
402401
if response.state == "failed" or response.state == "skipped":
403402
self.logger.info(f"Rollout worker {worker} returned state {response.state}. Deactivating worker.")
403+
self.workers_info[server_url].skipped_count += 1
404404
self.deactivate_worker_by_url(server_url)
405405
return response
406406
except asyncio.TimeoutError:
407-
self.workers_info[server_url].running_count -= 1
408407
self.workers_info[server_url].failure_count += 1
409408
# self.deactivate_worker_by_url(server_url) # do not deactivate on timeout, only on skipped state
410409
self.logger.error(f"Get response from rollout worker {worker} timeout and return skip this sample.")
410+
self.deactivate_worker_by_url(server_url)
411411
return RLRolloutResponseItem(state="skipped")
412412

413+
def get_rollout_stats(self) -> str:
    """Build a human-readable summary of every rollout worker's counters.

    Walks ``self.workers_info`` and emits one line per worker containing its
    URL, rank, active flag and the running / success / failure / skipped
    counters.

    Returns:
        str: A newline-joined report that starts with the header
        ``"Rollout Worker Stats:"`` followed by one line per worker. When no
        workers are registered, only the header is returned.
    """
    log_parts = ["Rollout Worker Stats:"]
    for url, info in self.workers_info.items():
        # rollout() bumps ``skipped_count`` when a worker returns a
        # failed/skipped response, so report it next to the other counters.
        # Fall back to 0 for worker-info objects that do not carry the field.
        skipped_count = getattr(info, "skipped_count", 0)
        log_parts.append(
            f" - URL: {url} | Rank: {info.rank} | Active: {info.is_active} | "
            f"Running: {info.running_count} | Success: {info.success_count} | "
            f"Failures: {info.failure_count} | Skipped: {skipped_count}"
        )
    return "\n".join(log_parts)
427+
413428
def start_api_server(self, host: str = "0.0.0.0", port: int = 8000):
414429
"""Starts the API server to expose the rollout functionality."""
415430
app = FastAPI()

xtuner/v1/ray/rollout/lmdeploy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ async def _create_request(
127127
if "num_return_tokens" in extra_info:
128128
max_return_tokens = sample_params["max_tokens"] - extra_info["num_return_tokens"]
129129
sample_params["max_tokens"] = max_return_tokens
130-
self.logger.debug(
130+
self.logger.info(
131131
f"Set max_tokens to {max_return_tokens} based on num_return_tokens {extra_info['num_return_tokens']}"
132132
)
133133

xtuner/v1/ray/rollout/sglang.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ async def _create_request(
7878
payload["messages"] = prompt
7979
payload.update(sglang_sample_params)
8080
# note: chat completions 接口需要传入 max_tokens 和 min_tokens 参数
81-
if "num_return_tokens" in extra_params:
82-
max_return_tokens = sglang_sample_params["max_new_tokens"] - extra_params["num_return_tokens"]
81+
if "num_return_tokens" in extra_info:
82+
max_return_tokens = sglang_sample_params["max_new_tokens"] - extra_info["num_return_tokens"]
8383
payload["max_tokens"] = max_return_tokens
8484
self.logger.info(
85-
f"Set max_tokens to {max_return_tokens} based on num_return_tokens {extra_params['num_return_tokens']}"
85+
f"Set max_tokens to {max_return_tokens} based on num_return_tokens {extra_info['num_return_tokens']}"
8686
)
8787
else:
8888
payload["max_tokens"] = sglang_sample_params["max_new_tokens"]

xtuner/v1/ray/rollout/worker.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -333,16 +333,24 @@ async def rollout_task(
333333
endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}"
334334

335335
while True:
336-
http_result = await self._create_request(
337-
endpoint_url,
338-
openai_prompts,
339-
input_ids,
340-
openai_tools,
341-
tool_choice,
342-
sample_params=sample_params,
343-
extra_params=extra_params,
344-
extra_info=extra_info,
345-
)
336+
if extra_info.get("num_return_tokens", None) is not None and (sample_params["max_tokens"] - extra_info["num_return_tokens"]) == 0:
337+
return RLRolloutResponseItem(
338+
response="",
339+
response_ids=[],
340+
num_return_tokens=0,
341+
finish_reason="length",
342+
)
343+
else:
344+
http_result = await self._create_request(
345+
endpoint_url,
346+
openai_prompts,
347+
input_ids,
348+
openai_tools,
349+
tool_choice,
350+
sample_params=sample_params,
351+
extra_params=extra_params,
352+
extra_info=extra_info,
353+
)
346354
# Case 1: Request was successful
347355
if http_result.response is not None: # 推理完成:completed状态:finish_reason为abort/stop/length, 退出
348356
response = await self._handle_non_stream_response(

0 commit comments

Comments
 (0)