Draft

16 commits
b8d691f  feat: Introduce get_trajectory function and enhance gather.py typing (bradhilton, Aug 21, 2025)
d1714de  refactor: Clean up code formatting and update server port in test not… (bradhilton, Aug 21, 2025)
0fc4500  refactor: Simplify server startup message in test notebook (bradhilton, Aug 21, 2025)
74ad944  feat: Add asyncio support with pytest-asyncio and implement trajector… (bradhilton, Aug 21, 2025)
e2cb601  feat: Enhance trajectory support in tests by integrating optional ART… (bradhilton, Aug 21, 2025)
583b777  feat: Introduce yield_trajectory and capture_yielded_trajectory for e… (bradhilton, Aug 21, 2025)
e31047f  Merge branch 'main' into feat/contextual-trajectory (bradhilton, Aug 22, 2025)
86d1b85  feat(auto-trajectory): add auto_trajectory and capture_auto_trajector… (bradhilton, Aug 22, 2025)
8a2deb3  feat(auto-trajectory): enhance trajectory handling with HTTPX respons… (bradhilton, Aug 22, 2025)
127d3b0  feat(tests): add unit tests for tokenize_trajectory_groups functional… (bradhilton, Aug 22, 2025)
553b76d  feat(tests): enhance auto_trajectory tests with tool integration and … (bradhilton, Aug 22, 2025)
4db00f3  feat(auto-trajectory): integrate synchronous chat completion streamin… (bradhilton, Aug 22, 2025)
4392bcf  refactor(tests): update comment for optional ART support in auto_traj… (bradhilton, Aug 22, 2025)
3349d7f  feat(tests): add unit tests for litellm auto trajectory handling and … (bradhilton, Aug 22, 2025)
07d92f2  feat(tests): add api_key parameter to litellm auto trajectory tests f… (bradhilton, Aug 22, 2025)
93e5ffd  refactor(gather): simplify type imports and clean up after_each param… (bradhilton, Aug 22, 2025)
536 changes: 536 additions & 0 deletions dev/data.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/art/__init__.py
@@ -17,6 +17,7 @@
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(conf)

from . import dev
from .auto_trajectory import auto_trajectory, capture_auto_trajectory
from .backend import Backend
from .batches import trajectory_group_batches
from .gather import gather_trajectories, gather_trajectory_groups
@@ -28,6 +29,8 @@

__all__ = [
"dev",
"auto_trajectory",
"capture_auto_trajectory",
"gather_trajectories",
"gather_trajectory_groups",
"trajectory_group_batches",
145 changes: 145 additions & 0 deletions src/art/auto_trajectory.py
@@ -0,0 +1,145 @@
import contextvars
import json
from typing import Any, AsyncIterator, Coroutine, Iterator, Literal, overload

import httpx._models
from openai import OpenAI
from openai._streaming import Stream
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from .openai import consume_sync_chat_completion_stream
from .trajectories import History, Trajectory


@overload
def auto_trajectory(*, required: Literal[True]) -> Trajectory: ...


@overload
def auto_trajectory(*, required: Literal[False] = False) -> Trajectory | None: ...


def auto_trajectory(*, required: bool = False) -> Trajectory | None:
context = auto_trajectory_context_var.get(None)
if context is None:
if required:
raise RuntimeError(
"No auto trajectory in context. `auto_trajectory(required=True)` must be called in a `capture_auto_trajectory(...)` scope."
)
return None
return context.trajectory


async def capture_auto_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory:
with AutoTrajectoryContext():
await coroutine
trajectory = auto_trajectory_context_var.get().trajectory
trajectory.finish()
return trajectory


class AutoTrajectoryContext:
def __init__(self) -> None:
self.trajectory = Trajectory(
messages_and_choices=[],
reward=0.0,
)
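        # Dummy client; used only to re-parse captured SSE bytes via Stream below.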
self.openai_client = OpenAI(api_key="")

def __enter__(self) -> None:
self.token = auto_trajectory_context_var.set(self)

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
auto_trajectory_context_var.reset(self.token)

def handle_httpx_response(self, response: httpx._models.Response) -> None:
try:
request_content = json.loads(getattr(response.request, "_content", b""))
messages = request_content["messages"]
tools = request_content.get("tools", None)
setattr(response, "_content", getattr(response, "_content_so_far", b""))
if request_content.get("stream", False):
choice = consume_sync_chat_completion_stream(
Stream(
cast_to=ChatCompletionChunk,
response=response,
client=self.openai_client,
)
).choices[0]
else:
choice = Choice(
**json.loads(getattr(response, "_content"))["choices"][0]
)
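            # Match this completion to the history it continues: the request's
            # messages must extend the history's messages and its tools must agree;
            # otherwise append a fresh History to additional_histories.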
history: Trajectory | History = self.trajectory
history_index = -1
while True:
history_messages = history.messages()
if history_messages == messages[: len(history_messages)] and (
history.tools == tools
or (history_messages == [] and history.tools is None)
):
break
history_index += 1
try:
history = self.trajectory.additional_histories[history_index]
except IndexError:
history = History(messages_and_choices=[])
self.trajectory.additional_histories.append(history)
break
history.messages_and_choices.extend(
messages[len(history.messages_and_choices) :]
)
history.messages_and_choices.append(choice)
history.tools = tools
        except Exception:
            # Best-effort capture; never let trajectory bookkeeping break the
            # caller's response handling.
            pass


auto_trajectory_context_var: contextvars.ContextVar[AutoTrajectoryContext] = (
contextvars.ContextVar("auto_trajectory_context")
)


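# Patch httpx so streamed response bodies are accumulated on the Response and,
# when the response closes, handed to the active AutoTrajectoryContext (if any).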
def patch_httpx() -> None:
original_iter_bytes = httpx._models.Response.iter_bytes
original_aiter_bytes = httpx._models.Response.aiter_bytes
original_close = httpx._models.Response.close
original_aclose = httpx._models.Response.aclose

def patched_iter_bytes(
self: httpx._models.Response, chunk_size: int | None = None
) -> Iterator[bytes]:
for chunk in original_iter_bytes(self, chunk_size):
setattr(
self, "_content_so_far", getattr(self, "_content_so_far", b"") + chunk
)
yield chunk

async def patched_aiter_bytes(
self: httpx._models.Response, chunk_size: int | None = None
) -> AsyncIterator[bytes]:
async for chunk in original_aiter_bytes(self, chunk_size):
setattr(
self, "_content_so_far", getattr(self, "_content_so_far", b"") + chunk
)
yield chunk

def patched_close(self: httpx._models.Response) -> None:
original_close(self)
if context := auto_trajectory_context_var.get(None):
context.handle_httpx_response(self)

async def patched_aclose(self: httpx._models.Response) -> None:
await original_aclose(self)
if context := auto_trajectory_context_var.get(None):
context.handle_httpx_response(self)

httpx._models.Response.iter_bytes = patched_iter_bytes
httpx._models.Response.aiter_bytes = patched_aiter_bytes
httpx._models.Response.close = patched_close
httpx._models.Response.aclose = patched_aclose


patch_httpx()
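
For context, a minimal usage sketch of the new API (the rollout function, model name, and reward value below are illustrative, not part of this diff): `capture_auto_trajectory` runs a coroutine and returns the `Trajectory` assembled from the OpenAI-compatible HTTP calls made inside it, while `auto_trajectory()` lets code under capture access the in-progress trajectory.

```python
import art
from openai import AsyncOpenAI


async def rollout() -> None:
    # Any OpenAI-compatible client works; its requests go through patched httpx.
    client = AsyncOpenAI()
    await client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "Say hello."}],
    )
    # Access the in-progress trajectory, e.g. to score the rollout.
    trajectory = art.auto_trajectory(required=True)
    trajectory.reward = 1.0


# In an async context:
trajectory = await art.capture_auto_trajectory(rollout())
```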
184 changes: 103 additions & 81 deletions src/art/openai.py
@@ -1,7 +1,7 @@
from typing import Any, AsyncIterator, Callable, cast

import openai
from openai import AsyncStream
from openai import AsyncStream, Stream
from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import (
@@ -82,86 +82,8 @@ async def consume_chat_completion_stream(
chat_completion: ChatCompletion | None = None
async for chunk in stream:
if chat_completion is None:
chat_completion = ChatCompletion(
id=chunk.id,
choices=[
Choice(
finish_reason="stop",
index=choice.index,
logprobs=(ChoiceLogprobs() if choice.logprobs else None),
message=ChatCompletionMessage(role="assistant"),
)
for choice in chunk.choices
],
created=chunk.created,
model=chunk.model,
object="chat.completion",
)
for choice, chunk_choice in zip(chat_completion.choices, chunk.choices):
choice.finish_reason = chunk_choice.finish_reason or "stop"
if chunk_choice.logprobs:
if choice.logprobs is None:
choice.logprobs = ChoiceLogprobs()
if chunk_choice.logprobs.content:
if choice.logprobs.content is None:
choice.logprobs.content = []
choice.logprobs.content.extend(chunk_choice.logprobs.content)
if chunk_choice.logprobs.refusal:
if choice.logprobs.refusal is None:
choice.logprobs.refusal = []
choice.logprobs.refusal.extend(chunk_choice.logprobs.refusal)
if chunk_choice.delta.content:
if choice.message.content is None:
choice.message.content = ""
choice.message.content += chunk_choice.delta.content
if chunk_choice.delta.refusal:
if choice.message.refusal is None:
choice.message.refusal = ""
choice.message.refusal += chunk_choice.delta.refusal
if chunk_choice.delta.function_call:
if choice.message.function_call is None:
choice.message.function_call = FunctionCall(arguments="", name="")
choice.message.function_call.name += (
chunk_choice.delta.function_call.name or ""
)
choice.message.function_call.arguments += (
chunk_choice.delta.function_call.arguments or ""
)
if chunk_choice.delta.tool_calls:
if choice.message.tool_calls is None:
choice.message.tool_calls = []
for tool_call in chunk_choice.delta.tool_calls:
while tool_call.index not in range(len(choice.message.tool_calls)):
choice.message.tool_calls.append(
ChatCompletionMessageToolCall(
id="",
function=Function(arguments="", name=""),
type="function",
)
)
if tool_call.id:
choice.message.tool_calls[tool_call.index].id = tool_call.id
if tool_call.function:
if tool_call.function.name:
choice.message.tool_calls[
tool_call.index
].function.name = tool_call.function.name
if tool_call.function.arguments:
choice.message.tool_calls[
tool_call.index
].function.arguments += tool_call.function.arguments
if getattr(chunk_choice.delta, "reasoning", None):
if not hasattr(choice.message, "reasoning"):
setattr(choice.message, "reasoning", "")
setattr(
choice.message,
"reasoning",
getattr(choice.message, "reasoning")
+ getattr(chunk_choice.delta, "reasoning"),
)
chat_completion.service_tier = chunk.service_tier
chat_completion.system_fingerprint = chunk.system_fingerprint
chat_completion.usage = chunk.usage
chat_completion = init_chat_completion(chunk)
update_chat_completion(chat_completion, chunk)
if on_chunk:
try:
on_chunk(chunk, chat_completion)
@@ -170,3 +92,103 @@ async def consume_chat_completion_stream(
break
assert chat_completion is not None
return chat_completion


def consume_sync_chat_completion_stream(
stream: Stream[ChatCompletionChunk],
) -> ChatCompletion:
chat_completion: ChatCompletion | None = None
for chunk in stream:
if chat_completion is None:
chat_completion = init_chat_completion(chunk)
update_chat_completion(chat_completion, chunk)
assert chat_completion is not None
return chat_completion


def init_chat_completion(chunk: ChatCompletionChunk) -> ChatCompletion:
return ChatCompletion(
id=chunk.id,
choices=[
Choice(
finish_reason="stop",
index=choice.index,
logprobs=(ChoiceLogprobs() if choice.logprobs else None),
message=ChatCompletionMessage(role="assistant"),
)
for choice in chunk.choices
],
created=chunk.created,
model=chunk.model,
object="chat.completion",
)


def update_chat_completion(
chat_completion: ChatCompletion, chunk: ChatCompletionChunk
) -> None:
for choice, chunk_choice in zip(chat_completion.choices, chunk.choices):
choice.finish_reason = chunk_choice.finish_reason or "stop"
if chunk_choice.logprobs:
if choice.logprobs is None:
choice.logprobs = ChoiceLogprobs()
if chunk_choice.logprobs.content:
if choice.logprobs.content is None:
choice.logprobs.content = []
choice.logprobs.content.extend(chunk_choice.logprobs.content)
if chunk_choice.logprobs.refusal:
if choice.logprobs.refusal is None:
choice.logprobs.refusal = []
choice.logprobs.refusal.extend(chunk_choice.logprobs.refusal)
if chunk_choice.delta.content:
if choice.message.content is None:
choice.message.content = ""
choice.message.content += chunk_choice.delta.content
if chunk_choice.delta.refusal:
if choice.message.refusal is None:
choice.message.refusal = ""
choice.message.refusal += chunk_choice.delta.refusal
if chunk_choice.delta.function_call:
if choice.message.function_call is None:
choice.message.function_call = FunctionCall(arguments="", name="")
choice.message.function_call.name += (
chunk_choice.delta.function_call.name or ""
)
choice.message.function_call.arguments += (
chunk_choice.delta.function_call.arguments or ""
)
if chunk_choice.delta.tool_calls:
if choice.message.tool_calls is None:
choice.message.tool_calls = []
for tool_call in chunk_choice.delta.tool_calls:
while tool_call.index not in range(len(choice.message.tool_calls)):
choice.message.tool_calls.append(
ChatCompletionMessageToolCall(
id="",
function=Function(arguments="", name=""),
type="function",
)
)
if tool_call.id:
choice.message.tool_calls[tool_call.index].id = tool_call.id
if tool_call.function:
if tool_call.function.name:
choice.message.tool_calls[
tool_call.index
].function.name = tool_call.function.name
if tool_call.function.arguments:
choice.message.tool_calls[
tool_call.index
].function.arguments += tool_call.function.arguments
if getattr(chunk_choice.delta, "reasoning", None):
if not hasattr(choice.message, "reasoning"):
setattr(choice.message, "reasoning", "")
setattr(
choice.message,
"reasoning",
getattr(choice.message, "reasoning")
+ getattr(chunk_choice.delta, "reasoning"),
)
chat_completion.service_tier = chunk.service_tier
chat_completion.system_fingerprint = chunk.system_fingerprint
chat_completion.usage = chunk.usage
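
The fold extracted here is what both consumers rely on: `init_chat_completion` seeds a `ChatCompletion` from the first chunk, and `update_chat_completion` merges each subsequent delta into it. A quick sketch with hand-built chunks (the payloads are illustrative):

```python
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from art.openai import init_chat_completion, update_chat_completion

chunks = [
    ChatCompletionChunk.model_validate(
        {
            "id": "chatcmpl-123",
            "object": "chat.completion.chunk",
            "created": 0,
            "model": "example-model",
            "choices": [
                {"index": 0, "delta": {"content": part}, "finish_reason": None}
            ],
        }
    )
    for part in ("Hel", "lo")
]
# Seed from the first chunk, then fold every delta in order.
completion = init_chat_completion(chunks[0])
for chunk in chunks:
    update_chat_completion(completion, chunk)
assert completion.choices[0].message.content == "Hello"
```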
3 changes: 3 additions & 0 deletions src/art/trajectories.py
@@ -23,6 +23,9 @@ class History(pydantic.BaseModel):
messages_and_choices: MessagesAndChoices
tools: Tools | None = None

def messages(self) -> Messages:
return get_messages(self.messages_and_choices)


class Trajectory(pydantic.BaseModel):
messages_and_choices: MessagesAndChoices
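
A quick sketch of the new helper (assuming, as `handle_httpx_response` does, that `get_messages` renders a `Choice` as its assistant message):

```python
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage

from art.trajectories import History

history = History(
    messages_and_choices=[
        {"role": "user", "content": "Hi"},
        Choice(
            finish_reason="stop",
            index=0,
            message=ChatCompletionMessage(role="assistant", content="Hello!"),
        ),
    ]
)
# messages() flattens mixed messages and Choices into plain chat messages,
# which is what the prefix matching in handle_httpx_response compares against.
assert history.messages()[-1]["role"] == "assistant"
```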
2 changes: 1 addition & 1 deletion src/art/yield_trajectory.py
@@ -29,5 +29,5 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:


yield_trajectory_context_var: contextvars.ContextVar[YieldTrajectoryContext] = (
contextvars.ContextVar("trajectory", default=YieldTrajectoryContext())
contextvars.ContextVar("yield_trajectory_context", default=YieldTrajectoryContext())
)