@@ -122,6 +122,7 @@ def __init__(
         self.finished_samples_count = 0
         self.failed_samples_count = 0
         self.skipped_sample_count = 0
+        self.filtered_sample_count = 0
         self.sample_from_expired_storage = False
         self.logger = get_logger(log_dir=self.config.worker_log_dir, tag="DataFlow")
         self.target_batch_size = self.config.global_batch_size
@@ -161,6 +162,7 @@ def _reset_internal_states_on_step(
         self.finished_samples_count = 0
         self.failed_samples_count = 0
         self.skipped_sample_count = 0
+        self.filtered_sample_count = 0
         self.logger.info(
             f"global_batch_size: {global_batch_size}, sample_params: {sample_params}, extra_params: {extra_params}"
         )
@@ -224,6 +226,8 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
                 group_data_items = await self.replay_buffer.post_processor.remote(group_data_items)  # type: ignore[attr-defined]
                 if len(group_data_items) > 0:
                     await self.replay_buffer.add.remote(group_data_items)  # type: ignore[attr-defined]
+                else:
+                    self.filtered_sample_count += 1
                 self.logger.debug(f"Worker task completed successfully for {action_id}.")
             elif group_state == "interrupted":
                 await self.replay_buffer.add.remote(group_data_items)  # type: ignore[attr-defined]
@@ -289,17 +293,17 @@ async def concurrent_task_runner(self):
         cleanup_start_time = time.monotonic()
         cleanup_timeout = 10 * 60  # 10 minutes in seconds
         while len(waiting_tasks) > 0:
-            elapsed_time = time.monotonic() - cleanup_start_time
-            if elapsed_time > cleanup_timeout:
-                self.logger.warning(
-                    f"Cleanup timeout of {cleanup_timeout}s reached. "
-                    f"Forcefully cancelling {len(waiting_tasks)} remaining tasks."
-                )
-                for task in waiting_tasks:
-                    task.cancel()
-                # Wait for cancellations to complete
-                await asyncio.gather(*waiting_tasks, return_exceptions=True)
-                break  # Exit the cleanup loop
+            # elapsed_time = time.monotonic() - cleanup_start_time
+            # if elapsed_time > cleanup_timeout:
+            #     self.logger.warning(
+            #         f"Cleanup timeout of {cleanup_timeout}s reached. "
+            #         f"Forcefully cancelling {len(waiting_tasks)} remaining tasks."
+            #     )
+            #     for task in waiting_tasks:
+            #         task.cancel()
+            #     # Wait for cancellations to complete
+            #     await asyncio.gather(*waiting_tasks, return_exceptions=True)
+            #     break  # Exit the cleanup loop
             done_tasks, pending_tasks = await asyncio.wait(
                 waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED
             )
@@ -309,7 +313,8 @@ async def concurrent_task_runner(self):
         self.logger.info("All worker tasks have completed after pausing env controller.")
 
         self.logging_replaybuffer_state()
-
+        self.logger.info(ray.get(self.env_controller.get_rollout_stats.remote()))
+
     async def pause(self, timeout: float = 60.0):
         """Asynchronously sends abort requests to all rollout workers."""
         rollout_info = ray.get(self.env_controller.get_rollout_info.remote())  # type: ignore[attr-defined]
@@ -378,6 +383,7 @@ def logging_replaybuffer_state(self, logging_msg: Optional[str] = None):
         status = self.get_replaybuffer_status()
         logging_msg = logging_msg if logging_msg else ""
         logging_msg += f"ReplayBuffer Status: {status}"
+        logging_msg += f", Filtered samples count: {self.filtered_sample_count}"
         self.logger.info(logging_msg)
 
     def get_replaybuffer_status(self):