
Commit 0b4e057: fix pause
1 parent cdbd3fe commit 0b4e057

File tree: 7 files changed (+82, -48 lines)

xtuner/v1/data_proto/rl_data.py

Lines changed: 21 additions & 10 deletions

@@ -82,18 +82,29 @@ def update(self, other: "RLRolloutResponseItem") -> None:
         if not isinstance(other, RLRolloutResponseItem):
             raise TypeError("Can only update with another RLRolloutResponseItem instance.")

-        if self.response_ids is not None and other.response_ids:
-            self.response_ids.extend(other.response_ids)
+        if self.response_ids is not None:
+            init_response_len = len(self.response_ids)
+            if other.response_ids is not None:
+                self.response_ids.extend(other.response_ids)
+                logger.info(f"Updated response_ids from {init_response_len} to {len(self.response_ids)}")
+            else:
+                self.response_ids = self.response_ids
         else:
             self.response_ids = other.response_ids

-        if self.logprobs is not None and other.logprobs:
-            self.logprobs.extend(other.logprobs)
+        if self.logprobs is not None:
+            if other.logprobs is not None:
+                self.logprobs.extend(other.logprobs)
+            else:
+                self.logprobs = self.logprobs
         else:
             self.logprobs = other.logprobs

-        if self.response is not None and other.response:
-            self.response += other.response
+        if self.response is not None:
+            if other.response is not None:
+                self.response + other.response
+            else:
+                self.response = self.response
         else:
             self.response = other.response
         self.num_return_tokens += other.num_return_tokens

@@ -197,14 +208,14 @@ def check_valid_dataflow_item(group_data_items: List[RLDataFlowItem]) -> bool:
     """
     for item in group_data_items:
         rollout_info = item.env.rollout
-        response_valid = bool(rollout_info.response)
-        ids_valid = bool(rollout_info.response_ids)
-        logprobs_valid = bool(rollout_info.logprobs)
+        response_valid = True if rollout_info.response is not None and len(rollout_info.response) > 0 else False
+        ids_valid = True if rollout_info.response_ids is not None and len(rollout_info.response_ids) > 0 else False
+        logprobs_valid = True if rollout_info.logprobs is not None and len(rollout_info.logprobs) > 0 else False
         if item.env.rollout.state in ["skipped", "failed"]:
             logger.info(f"Invalid dataflow item found: rollout state is {item.env.rollout.state}. UID: {item.uid}")
             return False
         if not response_valid and not ids_valid and item.env.rollout.state != "interrupted":
-            logger.info(f"Invalid dataflow item found: no response or response_ids. UID:{item.data.uid} with rollout response {item.env.rollout}")
+            logger.info(f"Invalid dataflow item found: no response or response_ids. UID:{item.uid.action_id} with rollout response {item.env.rollout}")
            return False
        if ids_valid and logprobs_valid and len(rollout_info.logprobs) != len(rollout_info.response_ids):  # type: ignore[arg-type]
            logger.info(f"Invalid dataflow item found: logprobs and response_ids length mismatch. UID: {item.uid}")
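
For context, the intent of the reworked update() above is to fold a continuation chunk generated after a pause back into the chunk that was interrupted: token ids and logprobs are extended in place, the decoded text is concatenated, and num_return_tokens accumulates. A minimal stand-in sketch of that merge, using a simplified PartialChunk dataclass rather than the real RLRolloutResponseItem:

# Stand-in sketch only: a simplified dataclass illustrating the merge semantics,
# not the actual RLRolloutResponseItem.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PartialChunk:
    response: Optional[str] = None
    response_ids: Optional[List[int]] = None
    logprobs: Optional[List[float]] = None
    num_return_tokens: int = 0

    def update(self, other: "PartialChunk") -> None:
        # Extend token ids and logprobs in place when both sides are present,
        # otherwise take whichever side is not None.
        if self.response_ids is not None and other.response_ids is not None:
            self.response_ids.extend(other.response_ids)
        elif self.response_ids is None:
            self.response_ids = other.response_ids
        if self.logprobs is not None and other.logprobs is not None:
            self.logprobs.extend(other.logprobs)
        elif self.logprobs is None:
            self.logprobs = other.logprobs
        # Concatenate the decoded text the same way (in-place +=).
        if self.response is not None and other.response is not None:
            self.response += other.response
        elif self.response is None:
            self.response = other.response
        self.num_return_tokens += other.num_return_tokens


first = PartialChunk(response="Hel", response_ids=[1, 2], logprobs=[-0.1, -0.2], num_return_tokens=2)
first.update(PartialChunk(response="lo", response_ids=[3], logprobs=[-0.3], num_return_tokens=1))
assert first.response == "Hello" and first.response_ids == [1, 2, 3] and first.num_return_tokens == 3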

xtuner/v1/ray/dataflow/flow.py

Lines changed: 18 additions & 8 deletions

@@ -221,7 +221,7 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte

             # Step 3: Determine the sample's state and act accordingly.
             group_state = determine_group_state(group_data_items)
-            self.logger.debug(f"Determined replay state for {action_id}: {group_state}")
+            self.logger.info(f"Determined replay state for {action_id}: {group_state}")
             if group_state == "completed":
                 group_data_items = await self.replay_buffer.post_processor.remote(group_data_items)  # type: ignore[attr-defined]
                 if len(group_data_items) > 0:

@@ -260,6 +260,7 @@ async def concurrent_task_runner(self):
            before completing.
         """
         waiting_tasks = set()
+        start_time = time.monotonic()
         with tqdm(total=self.target_batch_size, desc="rollout_controller for training samples") as pbar:
             update_step = max(1, int(self.target_batch_size * 0.01))
             next_update_threshold = update_step

@@ -286,8 +287,12 @@ async def concurrent_task_runner(self):
                     pbar.n = self.finished_samples_count
                     pbar.refresh()

+        elapsed_time = time.monotonic() - start_time
+        self.logger.info(f"Sample collection finished. Time taken: {elapsed_time:.2f} seconds.")
+
         # NOTE: Directly send pause requests to rollout workers because calling `rollout_controller.pause()`
         # would be queued behind many worker tasks, causing a significant delay.
+        start_time = time.monotonic()
         if self.enable_partial_rollout:
             await self.pause()
             cleanup_start_time = time.monotonic()

@@ -311,21 +316,26 @@ async def concurrent_task_runner(self):
             await self.pause()
             waiting_tasks = pending_tasks
             self.logger.info("All worker tasks have completed after pausing env controller.")
-
+        elapsed_time = time.monotonic() - start_time
+        self.logger.info(f"Pause generation. Time taken: {elapsed_time:.2f} seconds.")
         self.logging_replaybuffer_state()
         self.logger.info(ray.get(self.env_controller.get_rollout_stats.remote()))

     async def pause(self, timeout: float = 60.0):
         """Asynchronously sends abort requests to all rollout workers."""
-        rollout_info = ray.get(self.env_controller.get_rollout_info.remote())  # type: ignore[attr-defined]
-        self.worker_url_list = list(rollout_info["server_url_dict"].values())
-
+        self.logger.info("Sending abort requests to all rollout workers.")
+        # rollout_info = ray.get(self.env_controller.get_rollout_info.remote())  # type: ignore[attr-defined]
+        # self.worker_url_list = list(rollout_info["server_url_dict"].values())
+        self.logger.info("get self.worker_url_list from env_controller: ", self.worker_url_list)
         if not self.worker_url_list:
             self.logger.info("No active rollout workers to pause.")
             return

         async with httpx.AsyncClient() as client:
-            tasks = [self._send_abort_request(client, url, timeout=timeout) for url in self.worker_url_list]
+            tasks = []
+            for url in self.worker_url_list:
+                self.logger.info(f"Sending abort request to worker at {url}")
+                tasks.append(self._send_abort_request(client, url, timeout=timeout))
             results = await asyncio.gather(*tasks)

             failed_workers = [url for url, success in results if not success]

@@ -337,7 +347,7 @@ async def pause(self, timeout: float = 60.0):
                 f"Failed: {len(failed_workers)}. Failed workers: {failed_workers}"
             )
         else:
-            self.logger.debug(f"All {succeeded_count} abort requests sent successfully.")
+            self.logger.info(f"All {succeeded_count} abort requests sent successfully.")

     async def run(
         self,

@@ -397,7 +407,7 @@ async def _send_abort_request(self, client, url, timeout):
         try:
             response = await client.post(worker_url, json={"abort_all": True}, timeout=timeout)
             response.raise_for_status()
-            self.logger.debug(f"Successfully sent abort request to {url}")
+            self.logger.info(f"Successfully sent abort request to {url}")
             return url, True
         except Exception as e:
             self.logger.error(f"Failed to send abort request to {url}: {e}")
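
The pause() path above fans abort requests out to every rollout worker concurrently, so one slow or dead worker cannot block the rest. A standalone sketch of that pattern, assuming a hypothetical /abort_request route (the real worker endpoint is assembled inside _send_abort_request and is not shown in this diff):

# Sketch under assumptions: "/abort_request" is a placeholder route, not the
# actual worker endpoint used by _send_abort_request.
import asyncio
from typing import List, Tuple

import httpx


async def send_abort(client: httpx.AsyncClient, url: str, timeout: float) -> Tuple[str, bool]:
    try:
        resp = await client.post(f"{url}/abort_request", json={"abort_all": True}, timeout=timeout)
        resp.raise_for_status()
        return url, True
    except Exception:
        return url, False


async def pause_all(worker_urls: List[str], timeout: float = 60.0) -> List[str]:
    """Send one abort request per worker concurrently; return the workers that failed."""
    if not worker_urls:
        return []
    async with httpx.AsyncClient() as client:
        results = await asyncio.gather(*(send_abort(client, u, timeout) for u in worker_urls))
    return [url for url, ok in results if not ok]


# Example: asyncio.run(pause_all(["http://10.0.0.1:8000", "http://10.0.0.2:8000"]))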

xtuner/v1/ray/dataflow/replay_buffer.py

Lines changed: 19 additions & 18 deletions

@@ -117,7 +117,7 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re
     group_state = determine_group_state(grouped_dataitem)

     replay_state = ReplayState.from_str(group_state)
-    logger.debug(f"determined group_state: {group_state}, replay_state: {replay_state}")
+    logger.info(f"determined group_state: {group_state}, replay_state: {replay_state}, version: {version}")
     replay_meta = ReplayMeta(
         env=env_str,
         root_id=root_id,

@@ -337,18 +337,19 @@ def add(self, grouped_dataitem: List[RLDataFlowItem], partial_rollout_step: int
         # 1. Record the action_ids associated with this prompt.
         if root_id in self._root2actions:
             # TODO: the version update should be decided by whether update_weights was called; the non-colocated case also needs to be considered.
-            replay_meta.version += 1
-            self.logger.info(f"Existing root_id: {root_id} found. Incrementing version to {replay_meta.version}.")
+            replay_meta.version += 1 if partial_rollout_step > 0 else 0
+            self.logger.info(f"Existing root_id: {root_id} with action_id {action_id} found. Incrementing version to {replay_meta.version}.")
             self._root2actions[root_id].append(action_id)
         else:
             self._root2actions[root_id] = [action_id]
+
         self._actions[action_id] = replay_meta

         # 2. Route the sample to the finished, abort, or abort_over_version queue according to its rollout state; partial rollout is handled based on whether finish_reason is "abort".
-        if replay_meta.state == ReplayState.INTERRUPTED and replay_meta.version < partial_rollout_step:
+        if replay_meta.state == ReplayState.INTERRUPTED and (replay_meta.version < partial_rollout_step or partial_rollout_step == 0):
             self._interrupted_actions[replay_meta.version].append(action_id)
             self.logger.info(
-                f"Add aborted sample with root_id: {root_id}, action_id: {action_id} to _interrupted_actions."
+                f"Add aborted sample with action_id: {action_id} version: {replay_meta.version} to _interrupted_actions."
             )
         elif replay_meta.state == ReplayState.INTERRUPTED and replay_meta.version >= partial_rollout_step:
             self._expired_actions.append(action_id)

@@ -359,7 +360,7 @@ def add(self, grouped_dataitem: List[RLDataFlowItem], partial_rollout_step: int
             )
         elif replay_meta.state == ReplayState.COMPLETED:
             self._completed_actions[replay_meta.version].append(action_id)
-            self.logger.debug(f"Add sample with root_id: {root_id}, action_id: {action_id} to finished_actions.")
+            self.logger.info(f"Add sample with root_id: {root_id}, action_id: {action_id} to finished_actions.")
         elif replay_meta.state == ReplayState.FAILED:
             assert False, "Currently, failed samples are not supported in the replay buffer."

@@ -541,7 +542,7 @@ def sample_from_expired_storage(self) -> List[RLDataFlowItem]:

         # update env for expired samples
         for sample in group_samples:
-            sample.data.input_ids = sample.data.input_ids[: sample.data.num_tokens]
+            # sample.data.input_ids = sample.data.input_ids[: sample.data.num_tokens]
             sample.env = RLEnvDataItem()
             sample.uid.version = 0
             sample.extra_info.state = str(ReplayState.INIT)

@@ -560,20 +561,20 @@ def sample_from_interrupted_storage(self, tokenizer) -> List[RLDataFlowItem]:
         # update env for interrupted samples
         for sample in group_samples:
             assert sample.data.input_ids and sample.data.num_tokens, "input_ids or num_tokens is empty!"
-            sample.data.input_ids = sample.data.input_ids[: sample.data.num_tokens]
+            # sample.data.input_ids = sample.data.input_ids[: sample.data.num_tokens]
             sample.uid.action_id = int(uuid4().int)
             sample.uid.version = replay_meta.version
             sample.extra_info.state = str(ReplayState.INIT)
-            if sample.env.rollout.response_ids and sample.data.input_ids:
-                # TODO: accumulate response_ids
-                if "train_prompt_ids" in sample.data.extra_info:
-                    sample.data.input_ids = (
-                        sample.data.extra_info["train_prompt_ids"] + sample.env.rollout.response_ids
-                    )
-                else:
-                    sample.data.input_ids.extend(sample.env.rollout.response_ids)
-            elif sample.env.rollout.response:
-                sample.data.input_ids.extend(tokenizer.encode(sample.env.rollout.response, add_special_tokens=False))
+            # if sample.env.rollout.response_ids and sample.data.input_ids:
+            #     # TODO: accumulate response_ids
+            #     if "train_prompt_ids" in sample.data.extra_info:
+            #         sample.data.input_ids = (
+            #             sample.data.extra_info["train_prompt_ids"] + sample.env.rollout.response_ids
+            #         )
+            #     else:
+            #         sample.data.input_ids.extend(sample.env.rollout.response_ids)
+            # elif sample.env.rollout.response:
+            #     sample.data.input_ids.extend(tokenizer.encode(sample.env.rollout.response, add_special_tokens=False))
             self.logger.info(
                 f"Sampling interrupted action_id: {action_id} from replay buffer, remain interrupted samples: {self.get_interrupted_samples()}"
             )
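
Read together, the changes to add() above amount to a routing rule: an interrupted sample stays replayable while its version is below partial_rollout_step, a partial_rollout_step of 0 now also keeps it replayable, and anything past the limit is treated as expired and restarted from the prompt. A simplified, illustrative decision function (not the real method) capturing that rule:

# Illustrative only: mirrors the branch conditions in ReplayBuffer.add(), not the
# method itself (queues, logging, and version bookkeeping are omitted).
from enum import Enum


class State(Enum):
    COMPLETED = "completed"
    INTERRUPTED = "interrupted"


def route_sample(state: State, version: int, partial_rollout_step: int) -> str:
    if state == State.INTERRUPTED and (version < partial_rollout_step or partial_rollout_step == 0):
        return "interrupted"  # still resumable on a later rollout round
    if state == State.INTERRUPTED:
        return "expired"  # exceeded the partial-rollout budget; restart from the prompt
    return "completed"


assert route_sample(State.INTERRUPTED, 0, 4) == "interrupted"
assert route_sample(State.INTERRUPTED, 4, 4) == "expired"
assert route_sample(State.INTERRUPTED, 2, 0) == "interrupted"  # step 0 keeps samples resumable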

xtuner/v1/ray/environment/single_turn_env.py

Lines changed: 10 additions & 2 deletions

@@ -87,8 +87,16 @@ async def generate(
             sample.data.extra_info["action_id"] = sample.uid.action_id
             if sample.env.rollout.num_return_tokens > 0:
                 sample.data.extra_info["num_return_tokens"] = sample.env.rollout.num_return_tokens
-                self.logger.info(
-                    f"Set num_return_tokens: {sample.env.rollout.num_return_tokens} for sample {sample.uid}."
+                sample.data.extra_info["response_ids"] = sample.env.rollout.response_ids
+                sample.data.extra_info["response"] = sample.env.rollout.response
+                sample.data.extra_info["logprobs"] = sample.env.rollout.logprobs
+                assert len(sample.env.rollout.response_ids) == len(sample.env.rollout.logprobs), (
+                    f"num_return_tokens {sample.env.rollout.num_return_tokens} mismatch "
+                    f"len of response_ids {len(sample.env.rollout.response_ids)} and "
+                    f"len of logprobs {len(sample.env.rollout.logprobs)} for sample {sample.uid}."
+                )
+                self.logger.debug(
+                    f"Set num_return_tokens: {sample.env.rollout.num_return_tokens} and len of response_ids {len(sample.env.rollout.response_ids)} for sample {sample.uid}."
                 )
             fut = self.rollout_controller.rollout.remote(
                 prompt=sample.data.messages,
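
The generate() change above carries a partial rollout's outputs forward in extra_info so the resumed request can pick up where the interrupted one stopped, after checking that response_ids and logprobs line up. A reduced sketch with plain dicts standing in for the real sample objects:

# Reduced sketch: plain dicts stand in for the sample / rollout objects.
from typing import Any, Dict


def carry_partial_rollout(rollout: Dict[str, Any], extra_info: Dict[str, Any]) -> None:
    if rollout["num_return_tokens"] > 0:
        extra_info["num_return_tokens"] = rollout["num_return_tokens"]
        extra_info["response_ids"] = rollout["response_ids"]
        extra_info["response"] = rollout["response"]
        extra_info["logprobs"] = rollout["logprobs"]
        # Every generated token should carry exactly one logprob.
        assert len(rollout["response_ids"]) == len(rollout["logprobs"])


info: Dict[str, Any] = {}
carry_partial_rollout(
    {"num_return_tokens": 2, "response_ids": [5, 7], "response": "ab", "logprobs": [-0.1, -0.4]},
    info,
)
assert info["response_ids"] == [5, 7]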

xtuner/v1/ray/rollout/controller.py

Lines changed: 0 additions & 5 deletions

@@ -390,11 +390,6 @@ async def rollout(
             format=format,
             extra_info=extra_info,
         )
-        # if self.workers_info[server_url].running_count % 100 == 0:
-        #     log_msg = ""
-        #     for _, info in self.workers_info.items():
-        #         log_msg += f"rank {info.rank} worker info: {info}"
-        #     self.logger.info(log_msg)
         try:
             response = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout * 2)
             self.workers_info[server_url].success_count += 1

xtuner/v1/ray/rollout/lmdeploy.py

Lines changed: 4 additions & 3 deletions

@@ -127,11 +127,12 @@ async def _create_request(
         if "num_return_tokens" in extra_info:
             max_return_tokens = sample_params["max_tokens"] - extra_info["num_return_tokens"]
             sample_params["max_tokens"] = max_return_tokens
+            init_input_len = len(input_ids) if input_ids else 0
+            payload["input_ids"] += extra_info["response_ids"]
             self.logger.info(
-                f"Set max_tokens to {max_return_tokens} based on num_return_tokens {extra_info['num_return_tokens']}"
+                f"Set max_tokens to {max_return_tokens} based on num_return_tokens {extra_info['num_return_tokens']}, init input_len: {init_input_len} and payload input len {len(payload['input_ids'])}."
             )
-
-        if self.enable_return_routed_experts:
+        if self.enable_return_routed_experts:
             extra_params["return_routed_experts"] = True

         lmdeploy_sample_params = self._transform_sample_params(sample_params, extra_params)
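
The _create_request change above is the consumer side of the carried state: the remaining token budget is the original max_tokens minus what was already generated, and the earlier response_ids are appended to the prompt so decoding continues from where it was cut off. A simplified sketch of that payload construction (the payload shape here is illustrative; the keys follow the diff):

# Simplified sketch: payload/sample_params shapes are illustrative, not the exact
# LMDeploy request format.
from typing import Any, Dict, List


def build_resume_payload(prompt_ids: List[int], max_tokens: int, extra_info: Dict[str, Any]) -> Dict[str, Any]:
    payload: Dict[str, Any] = {"input_ids": list(prompt_ids)}
    sample_params = {"max_tokens": max_tokens}
    if "num_return_tokens" in extra_info:
        # Remaining budget = original budget minus tokens already generated.
        sample_params["max_tokens"] = max_tokens - extra_info["num_return_tokens"]
        # Continue from the previous partial output instead of regenerating it.
        payload["input_ids"] += extra_info["response_ids"]
    payload.update(sample_params)
    return payload


resumed = build_resume_payload([1, 2, 3], max_tokens=8, extra_info={"num_return_tokens": 2, "response_ids": [7, 9]})
assert resumed["input_ids"] == [1, 2, 3, 7, 9] and resumed["max_tokens"] == 6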

xtuner/v1/ray/rollout/worker.py

Lines changed: 10 additions & 2 deletions

@@ -1,4 +1,3 @@
-import asyncio
 import copy
 import json
 import multiprocessing

@@ -333,7 +332,13 @@ async def rollout_task(
         endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}"

         while True:
-            if extra_info.get("num_return_tokens", None) is not None and (sample_params["max_tokens"] - extra_info["num_return_tokens"]) == 0:
+            if (
+                extra_info.get("num_return_tokens", None) is not None
+                and (sample_params["max_tokens"] - extra_info["num_return_tokens"]) == 0
+            ):
+                self.logger.info(
+                    f"rollout request {uid} reached max tokens {sample_params['max_tokens']}, returning length finish_reason"
+                )
                 return RLRolloutResponseItem(
                     response="",
                     response_ids=[],

@@ -484,6 +489,9 @@ async def _handle_non_stream_response(self, uid, sample_params, extra_params, re
             routed_experts = ray.put(routed_experts)
             extra_info = {"routed_experts": routed_experts}

+        if finish_reason != "abort" and len(last_token_ids) == 0:
+            self.logger.error(f"rollout request {uid} returned zero tokens with finish_reason {finish_reason}")
+
         rollout_response = RLRolloutResponseItem(
             response=response["text"],
             response_ids=last_token_ids if len(last_token_ids) > 0 else None,
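
The guard added to rollout_task above short-circuits when earlier partial rounds have already consumed the whole max_tokens budget, so no request is sent at all. A reduced sketch of that check, with a stub type standing in for RLRolloutResponseItem:

# Reduced sketch: StubRolloutResponse stands in for RLRolloutResponseItem, and the
# "length" finish reason follows the log message in the diff.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class StubRolloutResponse:
    response: str
    response_ids: List[int]
    finish_reason: str


def maybe_finish_by_length(sample_params: Dict[str, Any], extra_info: Dict[str, Any]) -> Optional[StubRolloutResponse]:
    num_returned = extra_info.get("num_return_tokens", None)
    if num_returned is not None and sample_params["max_tokens"] - num_returned == 0:
        # Budget exhausted by earlier partial rounds: nothing left to generate.
        return StubRolloutResponse(response="", response_ids=[], finish_reason="length")
    return None  # otherwise proceed with the normal rollout request


assert maybe_finish_by_length({"max_tokens": 8}, {"num_return_tokens": 8}).finish_reason == "length"
assert maybe_finish_by_length({"max_tokens": 8}, {"num_return_tokens": 3}) is None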
