From 587ff0bf0372688a41846b82b05ef2713e4ce86d Mon Sep 17 00:00:00 2001 From: Haoyu Gao Date: Tue, 7 Oct 2025 20:21:37 -0700 Subject: [PATCH] Add `GroupQueueManager` for trajectory item grouping. PiperOrigin-RevId: 816501290 --- .../QueueManager/group_queue_manager_test.py | 139 +++++++++++++++ .../QueueManager/group_queue_manager.py | 165 ++++++++++++++++++ .../agentic/agents/agent_types.py | 48 +++-- 3 files changed, 341 insertions(+), 11 deletions(-) create mode 100644 tests/rl/experimental/agentic/QueueManager/group_queue_manager_test.py create mode 100644 tunix/rl/experimental/agentic/QueueManager/group_queue_manager.py diff --git a/tests/rl/experimental/agentic/QueueManager/group_queue_manager_test.py b/tests/rl/experimental/agentic/QueueManager/group_queue_manager_test.py new file mode 100644 index 00000000..16e463bd --- /dev/null +++ b/tests/rl/experimental/agentic/QueueManager/group_queue_manager_test.py @@ -0,0 +1,139 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tests for group_queue_manager."""

import asyncio

from absl.testing import absltest
from tunix.rl.experimental.agentic.QueueManager import group_queue_manager


def _create_item(
    group_id: str, episode_id: int, pair_index: int = 0
) -> group_queue_manager.TrajectoryItem:
  """Builds a minimal TrajectoryItem carrying only grouping identifiers."""
  return group_queue_manager.TrajectoryItem(
      pair_index=pair_index,
      group_id=group_id,
      episode_id=episode_id,
      start_step=0,
      traj=None,
  )


class GroupQueueManagerTest(absltest.TestCase):

  def test_put_and_get_simple_batch(self):
    """A full bucket becomes one ready group and is returned as a batch."""

    async def scenario():
      manager = group_queue_manager.GroupQueueManager(group_size=2)
      first = _create_item("g1", 1)
      second = _create_item("g1", 1)

      # One item: the bucket is open but not yet ready.
      await manager.put(first)
      self.assertEqual(manager._open_bucket_count(), 1)
      self.assertEmpty(manager._ready_groups)

      # Second item completes the bucket and promotes it to a ready group.
      await manager.put(second)
      self.assertEqual(manager._open_bucket_count(), 0)
      self.assertLen(manager._ready_groups, 1)

      batch = await manager.get_batch(2)
      self.assertLen(batch, 2)
      self.assertCountEqual([first, second], batch)

    asyncio.run(scenario())

  def test_get_batch_waits_for_items(self):
    """get_batch blocks until a producer completes a group."""

    async def scenario():
      manager = group_queue_manager.GroupQueueManager(group_size=2)
      first = _create_item("g1", 1)
      second = _create_item("g1", 1)

      async def producer():
        await asyncio.sleep(0.01)
        await manager.put(first)
        await asyncio.sleep(0.01)
        await manager.put(second)

      producer_task = asyncio.create_task(producer())
      batch = await manager.get_batch(2)

      self.assertLen(batch, 2)
      await producer_task

    asyncio.run(scenario())

  def test_batching_with_leftovers(self):
    """A ready group split across two get_batch calls is buffered correctly."""

    async def scenario():
      manager = group_queue_manager.GroupQueueManager(group_size=3)
      items = [_create_item("g1", 1, i) for i in range(3)]
      for item in items:
        await manager.put(item)

      self.assertEmpty(manager._batch_buf)
      # First call takes two of the three items; the third is buffered.
      first_batch = await manager.get_batch(2)
      self.assertLen(first_batch, 2)
      self.assertCountEqual(items[:2], first_batch)
      self.assertLen(manager._batch_buf, 1)
      self.assertEqual(manager._batch_buf[0], items[2])

      # Second call drains the leftover from the buffer.
      second_batch = await manager.get_batch(1)
      self.assertLen(second_batch, 1)
      self.assertEqual(second_batch[0], items[2])
      self.assertEmpty(manager._batch_buf)

    asyncio.run(scenario())

  def test_max_open_buckets(self):
    """put blocks on a new group once max_open_buckets is reached."""

    async def scenario():
      manager = group_queue_manager.GroupQueueManager(
          group_size=2, max_open_buckets=1
      )
      item_g1 = _create_item("g1", 1)
      item_g2 = _create_item("g2", 1)

      await manager.put(item_g1)

      # g2 would open a second bucket, exceeding the limit of 1 — must block.
      put_task = asyncio.create_task(manager.put(item_g2))

      await asyncio.sleep(0.01)
      self.assertFalse(put_task.done())

      # Completing g1 frees its bucket slot, unblocking the pending put.
      await manager.put(_create_item("g1", 1))
      await asyncio.sleep(0.01)

      await put_task
      self.assertEqual(manager._open_bucket_count(), 1)
      self.assertIn(("g2", 1), manager._buckets)

    asyncio.run(scenario())

  def test_put_exception(self):
    """An injected exception is raised by both put and get_batch."""

    async def scenario():
      manager = group_queue_manager.GroupQueueManager(group_size=2)
      injected = ValueError("Test Exception")
      manager.put_exception(injected)

      with self.assertRaises(ValueError):
        await manager.put(_create_item("g1", 1))

      with self.assertRaises(ValueError):
        await manager.get_batch(1)

    asyncio.run(scenario())


if __name__ == "__main__":
  absltest.main()
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Manages queues of trajectory items, grouping them by group_id and episode_id.""" + +from __future__ import annotations +import asyncio +import collections +from collections.abc import Hashable +import dataclasses +from typing import Deque, Dict, List, Optional, Tuple +from tunix.rl.experimental.agentic.agents import agent_types + +Trajectory = agent_types.Trajectory +TrajectoryItem = agent_types.TrajectoryItem +field = dataclasses.field +dataclass = dataclasses.dataclass + + +class GroupQueueManager: + """Manages queues of trajectory items, grouping them by group_id and episode_id. + + This class collects `TrajectoryItem` instances into buckets based on their + `(group_id, episode_id)`. Once a bucket reaches `group_size`, it becomes a + "ready group" and can be retrieved in batches. It also handles managing the + number of open buckets and provides mechanisms for clearing and handling + exceptions. 
+ """ + + def __init__( + self, + *, + group_size: int, + max_open_buckets: Optional[int] = None, + ): + self.group_size = group_size + self.max_open_buckets = max_open_buckets or 0 + self._buckets: Dict[Tuple[Hashable, int], List[TrajectoryItem]] = {} + self._ready_groups: Deque[List[TrajectoryItem]] = collections.deque() + self._clearing = False + self._exc: Optional[BaseException] = None + self._lock = asyncio.Lock() + self._capacity = asyncio.Condition(self._lock) + self._have_ready = asyncio.Event() + self._batch_buf: List[TrajectoryItem] = [] + self._notify_all_task: Optional[asyncio.Task[None]] = None + + def put_exception(self, exc: BaseException): + self._exc = exc + self._have_ready.set() + + async def _notify_all(): + async with self._capacity: + self._capacity.notify_all() + + self._notify_all_task = asyncio.create_task(_notify_all()) + + async def prepare_clear(self): + self._clearing = True + self._have_ready.set() + async with self._capacity: + self._capacity.notify_all() + + async def clear(self): + async with self._lock: + self._buckets.clear() + self._ready_groups.clear() + self._batch_buf.clear() + self._exc = None + self._clearing = False + self._have_ready = asyncio.Event() + + async def put(self, item: TrajectoryItem): + """Adds an item, grouping by `(group_id, episode_id)`. + + Items are grouped in buckets. When a bucket reaches `self.group_size`, it's + moved to `_ready_groups`. Waits if `max_open_buckets` is exceeded. + + Args: + item: The TrajectoryItem to add. + + Raises: + BaseException: If an exception has been set via `put_exception`. 
+ """ + if self._clearing: + return + if self._exc: + raise self._exc + key = (item.group_id, item.episode_id) + async with self._capacity: + new_bucket = key not in self._buckets + while ( + (not self._clearing) + and (self.max_open_buckets > 0) + and new_bucket + and (self._open_bucket_count() >= self.max_open_buckets) + ): + await self._capacity.wait() + if self._clearing: + return + if self._exc: + raise self._exc + bucket = self._buckets.setdefault(key, []) + bucket.append(item) + if len(bucket) == self.group_size: + self._ready_groups.append(bucket.copy()) + del self._buckets[key] + self._capacity.notify_all() + self._have_ready.set() + + async def _get_one_ready_group(self) -> List[TrajectoryItem]: + while True: + if self._exc: + raise self._exc + if self._clearing: + return [] + if self._ready_groups: + return self._ready_groups.popleft() + await self._have_ready.wait() + self._have_ready.clear() + + async def get_batch(self, batch_size: int) -> List[TrajectoryItem]: + """Retrieves a batch of TrajectoryItems, waiting until enough are ready. + + Items are taken from `_batch_buf` and then from `_ready_groups`. Excess + items from groups are buffered in `_batch_buf`. + + Args: + batch_size: The desired number of TrajectoryItems. + + Returns: + A list of `TrajectoryItem` instances, up to `batch_size`. 
+ """ + out = [] + if self._batch_buf: + take = min(batch_size, len(self._batch_buf)) + out.extend(self._batch_buf[:take]) + self._batch_buf = self._batch_buf[take:] + if len(out) == batch_size: + return out + while len(out) < batch_size: + group = await self._get_one_ready_group() + if not group: + break + room = batch_size - len(out) + if len(group) <= room: + out.extend(group) + else: + out.extend(group[:room]) + self._batch_buf.extend(group[room:]) + return out + + def _open_bucket_count(self) -> int: + return len(self._buckets) diff --git a/tunix/rl/experimental/agentic/agents/agent_types.py b/tunix/rl/experimental/agentic/agents/agent_types.py index 68e50e10..26666c17 100644 --- a/tunix/rl/experimental/agentic/agents/agent_types.py +++ b/tunix/rl/experimental/agentic/agents/agent_types.py @@ -5,8 +5,9 @@ and complete episode trajectories. """ +from collections.abc import Hashable import dataclasses -from typing import Any, Optional +from typing import Any, Dict, Optional field = dataclasses.field dataclass = dataclasses.dataclass @@ -79,16 +80,41 @@ class Trajectory: steps: list[Step] = field(default_factory=list) reward: float = 0.0 - def to_dict(self) -> dict[str, Any]: - """Convert trajectory to dictionary format for serialization. - Useful for logging, storage, or transmission over APIs. All Step objects +@dataclass +class TrajectoryItem: + """Represents an item within a Trajectory, potentially for pairing or grouping. + + Attributes: + pair_index (int): Index for pairing. + group_id (collections.abc.Hashable): Identifier for grouping trajectories. + episode_id (int): Unique identifier for the episode. + start_step (int): The starting step index within the full trajectory. + traj (Trajectory): The Trajectory object itself. + metadata (Dict[str, Any]): Additional metadata. 
+ """ + + pair_index: int + group_id: Hashable + episode_id: int + start_step: int + traj: Trajectory + metadata: Dict[str, Any] = field(default_factory=dict) + + +def to_dict(self) -> dict[str, Any]: + """Convert trajectory to dictionary format for serialization. + + Useful for logging, storage, or transmission over APIs. All Step objects are recursively converted to dictionaries using dataclass serialization. - Returns: - dict: Serializable dictionary representation of the trajectory - """ - return { - "steps": [asdict(step) for step in self.steps], - "reward": float(self.reward), - } + Args: + self: The Trajectory object to convert. + + Returns: + dict: Serializable dictionary representation of the trajectory + """ + return { + "steps": [asdict(step) for step in self.steps], + "reward": float(self.reward), + }