@@ -117,7 +117,7 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re
117117 group_state = determine_group_state (grouped_dataitem )
118118
119119 replay_state = ReplayState .from_str (group_state )
120- logger .debug (f"determined group_state: { group_state } , replay_state: { replay_state } " )
120+ logger .info (f"determined group_state: { group_state } , replay_state: { replay_state } , version: { version } " )
121121 replay_meta = ReplayMeta (
122122 env = env_str ,
123123 root_id = root_id ,
@@ -337,18 +337,19 @@ def add(self, grouped_dataitem: List[RLDataFlowItem], partial_rollout_step: int
337337 # 1. 跟prompt相关的action_id记录
338338 if root_id in self ._root2actions :
339339 # TODO: version 更新需要 根据是否update_weights来判断,需要考虑到非共卡的情况
340- replay_meta .version += 1
341- self .logger .info (f"Existing root_id: { root_id } found. Incrementing version to { replay_meta .version } ." )
340+ replay_meta .version += 1 if partial_rollout_step > 0 else 0
341+ self .logger .info (f"Existing root_id: { root_id } with action_id { action_id } found. Incrementing version to { replay_meta .version } ." )
342342 self ._root2actions [root_id ].append (action_id )
343343 else :
344344 self ._root2actions [root_id ] = [action_id ]
345+
345346 self ._actions [action_id ] = replay_meta
346347
347348 # 2. 根据rollout状态加到finished, abort, abort_over_version队列中;Partial rollout is handled based on whether finish_reason is "abort".
348- if replay_meta .state == ReplayState .INTERRUPTED and replay_meta .version < partial_rollout_step :
349+ if replay_meta .state == ReplayState .INTERRUPTED and ( replay_meta .version < partial_rollout_step or partial_rollout_step == 0 ) :
349350 self ._interrupted_actions [replay_meta .version ].append (action_id )
350351 self .logger .info (
351- f"Add aborted sample with root_id : { root_id } , action_id : { action_id } to _interrupted_actions."
352+ f"Add aborted sample with action_id : { action_id } version : { replay_meta . version } to _interrupted_actions."
352353 )
353354 elif replay_meta .state == ReplayState .INTERRUPTED and replay_meta .version >= partial_rollout_step :
354355 self ._expired_actions .append (action_id )
@@ -359,7 +360,7 @@ def add(self, grouped_dataitem: List[RLDataFlowItem], partial_rollout_step: int
359360 )
360361 elif replay_meta .state == ReplayState .COMPLETED :
361362 self ._completed_actions [replay_meta .version ].append (action_id )
362- self .logger .debug (f"Add sample with root_id: { root_id } , action_id: { action_id } to finished_actions." )
363+ self .logger .info (f"Add sample with root_id: { root_id } , action_id: { action_id } to finished_actions." )
363364 elif replay_meta .state == ReplayState .FAILED :
364365 assert False , "Currently, failed samples are not supported in the replay buffer."
365366
@@ -541,7 +542,7 @@ def sample_from_expired_storage(self) -> List[RLDataFlowItem]:
541542
542543 # update env for expired samples
543544 for sample in group_samples :
544- sample .data .input_ids = sample .data .input_ids [: sample .data .num_tokens ]
545+ # sample.data.input_ids = sample.data.input_ids[: sample.data.num_tokens]
545546 sample .env = RLEnvDataItem ()
546547 sample .uid .version = 0
547548 sample .extra_info .state = str (ReplayState .INIT )
@@ -560,20 +561,20 @@ def sample_from_interrupted_storage(self, tokenizer) -> List[RLDataFlowItem]:
560561 # update env for interrupted samples
561562 for sample in group_samples :
562563 assert sample .data .input_ids and sample .data .num_tokens , "input_ids or num_tokens is empty!"
563- sample .data .input_ids = sample .data .input_ids [: sample .data .num_tokens ]
564+ # sample.data.input_ids = sample.data.input_ids[: sample.data.num_tokens]
564565 sample .uid .action_id = int (uuid4 ().int )
565566 sample .uid .version = replay_meta .version
566567 sample .extra_info .state = str (ReplayState .INIT )
567- if sample .env .rollout .response_ids and sample .data .input_ids :
568- # TODO: response_ids 累加
569- if "train_prompt_ids" in sample .data .extra_info :
570- sample .data .input_ids = (
571- sample .data .extra_info ["train_prompt_ids" ] + sample .env .rollout .response_ids
572- )
573- else :
574- sample .data .input_ids .extend (sample .env .rollout .response_ids )
575- elif sample .env .rollout .response :
576- sample .data .input_ids .extend (tokenizer .encode (sample .env .rollout .response , add_special_tokens = False ))
568+ # if sample.env.rollout.response_ids and sample.data.input_ids:
569+ # # TODO: response_ids 累加
570+ # if "train_prompt_ids" in sample.data.extra_info:
571+ # sample.data.input_ids = (
572+ # sample.data.extra_info["train_prompt_ids"] + sample.env.rollout.response_ids
573+ # )
574+ # else:
575+ # sample.data.input_ids.extend(sample.env.rollout.response_ids)
576+ # elif sample.env.rollout.response:
577+ # sample.data.input_ids.extend(tokenizer.encode(sample.env.rollout.response, add_special_tokens=False))
577578 self .logger .info (
578579 f"Sampling interrupted action_id: { action_id } from replay buffer, remain interrupted samples: { self .get_interrupted_samples ()} "
579580 )
0 commit comments