Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions tests/rl/experimental/agentic/QueueManager/group_queue_manager_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for group_queue_manager."""

import asyncio

from absl.testing import absltest
from tunix.rl.experimental.agentic.QueueManager import group_queue_manager


def _create_item(
group_id: str, episode_id: int, pair_index: int = 0
) -> group_queue_manager.TrajectoryItem:
"""Helper to create a TrajectoryItem for testing."""
return group_queue_manager.TrajectoryItem(
pair_index=pair_index,
group_id=group_id,
episode_id=episode_id,
start_step=0,
traj=None,
)


class GroupQueueManagerTest(absltest.TestCase):

def test_put_and_get_simple_batch(self):
"""Tests basic put and get functionality."""
async def _run_test():
manager = group_queue_manager.GroupQueueManager(group_size=2)
item1 = _create_item("g1", 1)
item2 = _create_item("g1", 1)

await manager.put(item1)
self.assertEqual(manager._open_bucket_count(), 1)
self.assertEmpty(manager._ready_groups)

await manager.put(item2)
self.assertEqual(manager._open_bucket_count(), 0)
self.assertLen(manager._ready_groups, 1)

batch = await manager.get_batch(2)
self.assertLen(batch, 2)
self.assertCountEqual([item1, item2], batch)
asyncio.run(_run_test())

def test_get_batch_waits_for_items(self):
"""Tests that get_batch waits until a group is ready."""
async def _run_test():
manager = group_queue_manager.GroupQueueManager(group_size=2)
item1 = _create_item("g1", 1)
item2 = _create_item("g1", 1)

async def producer():
await asyncio.sleep(0.01)
await manager.put(item1)
await asyncio.sleep(0.01)
await manager.put(item2)

producer_task = asyncio.create_task(producer())
batch = await manager.get_batch(2)

self.assertLen(batch, 2)
await producer_task
asyncio.run(_run_test())

def test_batching_with_leftovers(self):
"""Tests batching where a group is split across two get_batch calls."""
async def _run_test():
manager = group_queue_manager.GroupQueueManager(group_size=3)
items = [_create_item("g1", 1, i) for i in range(3)]
for item in items:
await manager.put(item)

self.assertEmpty(manager._batch_buf)
batch1 = await manager.get_batch(2)
self.assertLen(batch1, 2)
self.assertCountEqual(items[:2], batch1)
self.assertLen(manager._batch_buf, 1)
self.assertEqual(manager._batch_buf[0], items[2])

batch2 = await manager.get_batch(1)
self.assertLen(batch2, 1)
self.assertEqual(batch2[0], items[2])
self.assertEmpty(manager._batch_buf)
asyncio.run(_run_test())

def test_max_open_buckets(self):
"""Tests that put blocks when max_open_buckets is reached."""
async def _run_test():
manager = group_queue_manager.GroupQueueManager(
group_size=2, max_open_buckets=1
)
item_g1 = _create_item("g1", 1)
item_g2 = _create_item("g2", 1)

await manager.put(item_g1)

put_task = asyncio.create_task(manager.put(item_g2))

await asyncio.sleep(0.01)
self.assertFalse(put_task.done())

await manager.put(_create_item("g1", 1))
await asyncio.sleep(0.01)

await put_task
self.assertEqual(manager._open_bucket_count(), 1)
self.assertIn(("g2", 1), manager._buckets)
asyncio.run(_run_test())

def test_put_exception(self):
"""Tests that an exception is propagated to put and get calls."""
async def _run_test():
manager = group_queue_manager.GroupQueueManager(group_size=2)
exc = ValueError("Test Exception")
manager.put_exception(exc)

with self.assertRaises(ValueError):
await manager.put(_create_item("g1", 1))

with self.assertRaises(ValueError):
await manager.get_batch(1)
asyncio.run(_run_test())


if __name__ == "__main__":
absltest.main()
174 changes: 174 additions & 0 deletions tunix/rl/experimental/agentic/QueueManager/group_queue_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Manages queues of trajectory items, grouping them by group_id and episode_id."""

from __future__ import annotations
import asyncio
import collections
from collections.abc import Hashable
import dataclasses
from typing import Any, Deque, Dict, List, Optional, Tuple
from tunix.rl.experimental.agentic.agents import agent_types

Trajectory = agent_types.Trajectory
field = dataclasses.field
dataclass = dataclasses.dataclass


@dataclass
class TrajectoryItem:
pair_index: int
group_id: Hashable
episode_id: int
start_step: int
traj: Trajectory
meta: Dict[str, Any] = field(default_factory=dict)


class GroupQueueManager:
"""Manages queues of trajectory items, grouping them by group_id and episode_id.
This class collects `TrajectoryItem` instances into buckets based on their
`(group_id, episode_id)`. Once a bucket reaches `group_size`, it becomes a
"ready group" and can be retrieved in batches. It also handles managing the
number of open buckets and provides mechanisms for clearing and handling
exceptions.
"""

def __init__(
self,
*,
group_size: int,
max_open_buckets: Optional[int] = None,
):
self.group_size = group_size
self.max_open_buckets = max_open_buckets or 0
self._buckets: Dict[Tuple[Hashable, int], List[TrajectoryItem]] = {}
self._ready_groups: Deque[List[TrajectoryItem]] = collections.deque()
self._clearing = False
self._exc: Optional[BaseException] = None
self._lock = asyncio.Lock()
self._capacity = asyncio.Condition(self._lock)
self._have_ready = asyncio.Event()
self._batch_buf: List[TrajectoryItem] = []
self._notify_all_task: Optional[asyncio.Task[None]] = None

def put_exception(self, exc: BaseException):
self._exc = exc
self._have_ready.set()

async def _notify_all():
async with self._capacity:
self._capacity.notify_all()

self._notify_all_task = asyncio.create_task(_notify_all())

async def prepare_clear(self):
self._clearing = True
self._have_ready.set()
async with self._capacity:
self._capacity.notify_all()

async def clear(self):
async with self._lock:
self._buckets.clear()
self._ready_groups.clear()
self._batch_buf.clear()
self._exc = None
self._clearing = False
self._have_ready = asyncio.Event()

async def put(self, item: TrajectoryItem):
"""Adds an item, grouping by `(group_id, episode_id)`.
Items are grouped in buckets. When a bucket reaches `self.group_size`, it's
moved to `_ready_groups`. Waits if `max_open_buckets` is exceeded.
Args:
item: The TrajectoryItem to add.
Raises:
BaseException: If an exception has been set via `put_exception`.
"""
if self._clearing:
return
if self._exc:
raise self._exc
key = (item.group_id, item.episode_id)
async with self._capacity:
new_bucket = key not in self._buckets
while (
(not self._clearing)
and (self.max_open_buckets > 0)
and new_bucket
and (self._open_bucket_count() >= self.max_open_buckets)
):
await self._capacity.wait()
if self._clearing:
return
if self._exc:
raise self._exc
bucket = self._buckets.setdefault(key, [])
bucket.append(item)
if len(bucket) == self.group_size:
self._ready_groups.append(bucket.copy())
del self._buckets[key]
self._capacity.notify_all()
self._have_ready.set()

async def _get_one_ready_group(self) -> List[TrajectoryItem]:
while True:
if self._exc:
raise self._exc
if self._clearing:
return []
if self._ready_groups:
return self._ready_groups.popleft()
await self._have_ready.wait()
self._have_ready.clear()

async def get_batch(self, batch_size: int) -> List[TrajectoryItem]:
"""Retrieves a batch of TrajectoryItems, waiting until enough are ready.
Items are taken from `_batch_buf` and then from `_ready_groups`. Excess
items from groups are buffered in `_batch_buf`.
Args:
batch_size: The desired number of TrajectoryItems.
Returns:
A list of `TrajectoryItem` instances, up to `batch_size`.
"""
out = []
if self._batch_buf:
take = min(batch_size, len(self._batch_buf))
out.extend(self._batch_buf[:take])
self._batch_buf = self._batch_buf[take:]
if len(out) == batch_size:
return out
while len(out) < batch_size:
group = await self._get_one_ready_group()
if not group:
break
room = batch_size - len(out)
if len(group) <= room:
out.extend(group)
else:
out.extend(group[:room])
self._batch_buf.extend(group[room:])
return out

def _open_bucket_count(self) -> int:
return len(self._buckets)