diff --git a/dev/data.ipynb b/dev/data.ipynb new file mode 100644 index 00000000..2b1671b0 --- /dev/null +++ b/dev/data.ipynb @@ -0,0 +1,536 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "a7ff6842", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "01f78de0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mbradhilton\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.21.0" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/ubuntu/sky_workdir/dev/wandb/run-20250822_022145-test" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Resuming run test to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/bradhilton/tests" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/bradhilton/tests/runs/test" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 08-22 02:21:51 [__init__.py:235] Automatically detected platform cuda.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/sky_workdir/src/art/__init__.py:10: UserWarning: WARNING: Unsloth should be imported before transformers, peft to ensure all optimizations are applied. Your code may run slower or encounter memory issues without these optimizations.\n", + "\n", + "Please restructure your imports with 'import unsloth' at the top of your file.\n", + " import unsloth # type: ignore # noqa: F401\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n", + "INFO 08-22 02:21:59 [__init__.py:235] Automatically detected platform cuda.\n", + "🦥 Unsloth Zoo will now patch everything to make training faster!\n", + "Unsloth: Patching vLLM v1 graph capture\n", + "Unsloth: Patching vLLM v0 graph capture\n", + "==((====))== Unsloth 2025.8.6: Fast Qwen2 patching. Transformers: 4.53.2. vLLM: 0.10.0.\n", + " \\\\ /| NVIDIA H100 PCIe. Num GPUs = 1. Max memory: 79.189 GB. Platform: Linux.\n", + "O^O/ \\_/ \\ Torch: 2.7.1+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.1\n", + "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.31. 
FA2 = False]\n", + " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", + "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n", + "Unsloth: vLLM loading unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit with actual GPU utilization = 78.47%\n", + "Unsloth: Your GPU has CUDA compute capability 9.0 with VRAM = 79.19 GB.\n", + "Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 32768. Num Sequences = 368.\n", + "Unsloth: vLLM's KV Cache can use up to 56.27 GB. Also swap space = 6 GB.\n", + "Unsloth: Not an error, but `device` is not supported in vLLM. Skipping.\n", + "INFO 08-22 02:22:18 [config.py:1604] Using max model len 32768\n", + "Unsloth: vLLM Bitsandbytes config using kwargs = {'load_in_8bit': False, 'load_in_4bit': True, 'bnb_4bit_compute_dtype': 'bfloat16', 'bnb_4bit_quant_storage': 'uint8', 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_use_double_quant': True, 'llm_int8_enable_fp32_cpu_offload': False, 'llm_int8_has_fp16_weight': False, 'llm_int8_skip_modules': ['lm_head', 'multi_modal_projector', 'merger', 'modality_projection', 'model.layers.0.self_attn', 'model.layers.1.self_attn', 'model.layers.2.mlp', 'model.layers.3.mlp', 'model.layers.4.mlp', 'model.layers.25.mlp', 'model.layers.26.mlp'], 'llm_int8_threshold': 6.0}\n", + "INFO 08-22 02:22:18 [llm_engine.py:228] Initializing a V0 LLM engine (v0.10.0) with config: model='unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit', speculative_config=None, tokenizer='unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=bitsandbytes, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', 
disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=unsloth/qwen2.5-7b-instruct-unsloth-bnb-4bit, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=False, use_async_output_proc=True, pooler_config=None, compilation_config={\"level\":0,\"debug_dump_path\":\"\",\"cache_dir\":\"\",\"backend\":\"inductor\",\"custom_ops\":[],\"splitting_ops\":[],\"use_inductor\":true,\"compile_sizes\":[],\"inductor_compile_config\":{\"epilogue_fusion\":true,\"max_autotune\":false,\"shape_padding\":true,\"trace.enabled\":false,\"triton.cudagraphs\":true,\"debug\":false,\"dce\":true,\"memory_planning\":true,\"coordinate_descent_tuning\":true,\"trace.graph_diagram\":false,\"compile_threads\":26,\"group_fusion\":true,\"disable_progress\":false,\"verbose_progress\":true,\"triton.multi_kernel\":0,\"triton.use_block_ptr\":true,\"triton.enable_persistent_tma_matmul\":true,\"triton.autotune_at_compile_time\":false,\"triton.cooperative_reductions\":false,\"cuda.compile_opt_level\":\"-O2\",\"cuda.enable_cuda_lto\":true,\"combo_kernels\":false,\"benchmark_combo_kernel\":true,\"combo_kernel_foreach_dynamic_shapes\":true,\"enable_auto_functionalized_v2\":false},\"inductor_passes\":{},\"use_cudagraph\":true,\"cudagraph_num_of_warmups\":1,\"cudagraph_capture_sizes\":[368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],\"cudagraph_copy_inputs\":false,\"full_cuda_graph\":false,\"max_capture_size\":368,\"local_cache_dir\":null}, use_cached_outputs=False, \n", + "INFO 08-22 02:22:20 [cuda.py:398] Using Flash Attention backend.\n", + "INFO 08-22 02:22:20 [parallel_state.py:1102] rank 0 in world size 1 is 
# NOTE(review): this first overload is reconstructed — the hunk is truncated
# right before it in the captured diff. Confirm against the original file.
@overload
def auto_trajectory(*, required: Literal[True]) -> Trajectory: ...


@overload
def auto_trajectory(*, required: Literal[False] = False) -> Trajectory | None: ...


def auto_trajectory(*, required: bool = False) -> Trajectory | None:
    """Return the trajectory being captured in the current context, if any.

    Args:
        required: When true, raise instead of returning ``None`` if no
            capture context is active.

    Returns:
        The active context's trajectory, or ``None`` when there is no active
        context and ``required`` is false.

    Raises:
        RuntimeError: If ``required`` is true and no context is active.
    """
    context = auto_trajectory_context_var.get(None)
    if context is not None:
        return context.trajectory
    if required:
        raise RuntimeError(
            "No auto trajectory in context. `auto_trajectory(required=True)` must be called in a `capture_auto_trajectory(...)` scope."
        )
    return None


async def capture_auto_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory:
    """Await ``coroutine`` inside a fresh capture context and return its trajectory.

    The coroutine's own return value is discarded. The trajectory is finished
    (``Trajectory.finish()``) before it is returned.
    """
    context = AutoTrajectoryContext()
    with context:
        await coroutine
        # Use the context object we created rather than re-reading the context
        # variable, so the result cannot depend on what the coroutine left set.
        trajectory = context.trajectory
        trajectory.finish()
        return trajectory


class AutoTrajectoryContext:
    """Context manager that owns the trajectory filled in by the httpx hooks.

    While the context is entered it is published through
    ``auto_trajectory_context_var``; the patched ``Response.close`` /
    ``Response.aclose`` look it up there and call ``handle_httpx_response``.
    """

    def __init__(self) -> None:
        self.trajectory = Trajectory(
            messages_and_choices=[],
            reward=0.0,
        )
        # Only handed to `Stream(...)` when replaying a streamed body below.
        self.openai_client = OpenAI(api_key="")

    def __enter__(self) -> None:
        self.token = auto_trajectory_context_var.set(self)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        auto_trajectory_context_var.reset(self.token)

    def handle_httpx_response(self, response: httpx._models.Response) -> None:
        """Record the chat-completion exchange carried by ``response``.

        The request body is parsed as JSON to obtain ``messages`` and
        ``tools``; the response body (the bytes accumulated by the patched
        ``iter_bytes`` / ``aiter_bytes``) is turned into a ``Choice``. The
        exchange is appended to the first history (the trajectory itself, then
        each additional history in order) whose messages are a prefix of the
        request messages and whose tools match; if none matches, a new
        additional history is created.

        Recording is best-effort: any ``Exception`` (for example a response
        that is not a chat completion) is logged at debug level and otherwise
        ignored, so the HTTP call that triggered it is never disturbed.
        """
        try:
            request_content = json.loads(getattr(response.request, "_content", b""))
            messages = request_content["messages"]
            tools = request_content.get("tools", None)
            # Expose the bytes seen so far as the response content.
            # NOTE(review): presumably this lets `Stream` below re-iterate a
            # body that has already been consumed — confirm against httpx.
            setattr(response, "_content", getattr(response, "_content_so_far", b""))
            if request_content.get("stream", False):
                choice = consume_sync_chat_completion_stream(
                    Stream(
                        cast_to=ChatCompletionChunk,
                        response=response,
                        client=self.openai_client,
                    )
                ).choices[0]
            else:
                choice = Choice(
                    **json.loads(getattr(response, "_content"))["choices"][0]
                )
            history: Trajectory | History = self.trajectory
            history_index = -1
            while True:
                history_messages = history.messages()
                is_prefix = history_messages == messages[: len(history_messages)]
                tools_match = history.tools == tools or (
                    history_messages == [] and history.tools is None
                )
                if is_prefix and tools_match:
                    break
                history_index += 1
                try:
                    history = self.trajectory.additional_histories[history_index]
                except IndexError:
                    # No existing history fits: start a new one.
                    history = History(messages_and_choices=[])
                    self.trajectory.additional_histories.append(history)
                    break
            history.messages_and_choices.extend(
                messages[len(history.messages_and_choices) :]
            )
            history.messages_and_choices.append(choice)
            history.tools = tools
        except Exception:
            # Deliberately best-effort, but do not swallow KeyboardInterrupt /
            # SystemExit, and leave a trace for debugging.
            import logging

            logging.getLogger(__name__).debug(
                "Could not record httpx response in auto trajectory",
                exc_info=True,
            )


auto_trajectory_context_var: contextvars.ContextVar[AutoTrajectoryContext] = (
    contextvars.ContextVar("auto_trajectory_context")
)


def patch_httpx() -> None:
    """Monkey-patch ``httpx.Response`` so response bodies can be captured.

    ``iter_bytes`` / ``aiter_bytes`` are wrapped to accumulate every yielded
    chunk on the response as ``_content_so_far``; ``close`` / ``aclose`` are
    wrapped to pass the response to the active ``AutoTrajectoryContext`` (if
    any) after the original close has run.

    The patch is applied at most once per process: applying it again would
    wrap the already-patched methods, accumulating each chunk twice and
    recording each response twice.
    """
    response_cls = httpx._models.Response
    if getattr(response_cls, "_art_auto_trajectory_patched", False):
        return

    original_iter_bytes = response_cls.iter_bytes
    original_aiter_bytes = response_cls.aiter_bytes
    original_close = response_cls.close
    original_aclose = response_cls.aclose

    def remember_chunk(response: httpx._models.Response, chunk: bytes) -> None:
        setattr(
            response,
            "_content_so_far",
            getattr(response, "_content_so_far", b"") + chunk,
        )

    def patched_iter_bytes(
        self: httpx._models.Response, chunk_size: int | None = None
    ) -> Iterator[bytes]:
        for chunk in original_iter_bytes(self, chunk_size):
            remember_chunk(self, chunk)
            yield chunk

    async def patched_aiter_bytes(
        self: httpx._models.Response, chunk_size: int | None = None
    ) -> AsyncIterator[bytes]:
        async for chunk in original_aiter_bytes(self, chunk_size):
            remember_chunk(self, chunk)
            yield chunk

    def patched_close(self: httpx._models.Response) -> None:
        original_close(self)
        if context := auto_trajectory_context_var.get(None):
            context.handle_httpx_response(self)

    async def patched_aclose(self: httpx._models.Response) -> None:
        await original_aclose(self)
        if context := auto_trajectory_context_var.get(None):
            context.handle_httpx_response(self)

    response_cls.iter_bytes = patched_iter_bytes
    response_cls.aiter_bytes = patched_aiter_bytes
    response_cls.close = patched_close
    response_cls.aclose = patched_aclose
    setattr(response_cls, "_art_auto_trajectory_patched", True)


patch_httpx()
choice.logprobs.refusal is None: - choice.logprobs.refusal = [] - choice.logprobs.refusal.extend(chunk_choice.logprobs.refusal) - if chunk_choice.delta.content: - if choice.message.content is None: - choice.message.content = "" - choice.message.content += chunk_choice.delta.content - if chunk_choice.delta.refusal: - if choice.message.refusal is None: - choice.message.refusal = "" - choice.message.refusal += chunk_choice.delta.refusal - if chunk_choice.delta.function_call: - if choice.message.function_call is None: - choice.message.function_call = FunctionCall(arguments="", name="") - choice.message.function_call.name += ( - chunk_choice.delta.function_call.name or "" - ) - choice.message.function_call.arguments += ( - chunk_choice.delta.function_call.arguments or "" - ) - if chunk_choice.delta.tool_calls: - if choice.message.tool_calls is None: - choice.message.tool_calls = [] - for tool_call in chunk_choice.delta.tool_calls: - while tool_call.index not in range(len(choice.message.tool_calls)): - choice.message.tool_calls.append( - ChatCompletionMessageToolCall( - id="", - function=Function(arguments="", name=""), - type="function", - ) - ) - if tool_call.id: - choice.message.tool_calls[tool_call.index].id = tool_call.id - if tool_call.function: - if tool_call.function.name: - choice.message.tool_calls[ - tool_call.index - ].function.name = tool_call.function.name - if tool_call.function.arguments: - choice.message.tool_calls[ - tool_call.index - ].function.arguments += tool_call.function.arguments - if getattr(chunk_choice.delta, "reasoning", None): - if not hasattr(choice.message, "reasoning"): - setattr(choice.message, "reasoning", "") - setattr( - choice.message, - "reasoning", - getattr(choice.message, "reasoning") - + getattr(chunk_choice.delta, "reasoning"), - ) - chat_completion.service_tier = chunk.service_tier - chat_completion.system_fingerprint = chunk.system_fingerprint - chat_completion.usage = chunk.usage + chat_completion = 
init_chat_completion(chunk) + update_chat_completion(chat_completion, chunk) if on_chunk: try: on_chunk(chunk, chat_completion) @@ -170,3 +92,103 @@ async def consume_chat_completion_stream( break assert chat_completion is not None return chat_completion + + +def consume_sync_chat_completion_stream( + stream: Stream[ChatCompletionChunk], +) -> ChatCompletion: + chat_completion: ChatCompletion | None = None + for chunk in stream: + if chat_completion is None: + chat_completion = init_chat_completion(chunk) + update_chat_completion(chat_completion, chunk) + assert chat_completion is not None + return chat_completion + + +def init_chat_completion(chunk: ChatCompletionChunk) -> ChatCompletion: + return ChatCompletion( + id=chunk.id, + choices=[ + Choice( + finish_reason="stop", + index=choice.index, + logprobs=(ChoiceLogprobs() if choice.logprobs else None), + message=ChatCompletionMessage(role="assistant"), + ) + for choice in chunk.choices + ], + created=chunk.created, + model=chunk.model, + object="chat.completion", + ) + + +def update_chat_completion( + chat_completion: ChatCompletion, chunk: ChatCompletionChunk +) -> None: + for choice, chunk_choice in zip(chat_completion.choices, chunk.choices): + choice.finish_reason = chunk_choice.finish_reason or "stop" + if chunk_choice.logprobs: + if choice.logprobs is None: + choice.logprobs = ChoiceLogprobs() + if chunk_choice.logprobs.content: + if choice.logprobs.content is None: + choice.logprobs.content = [] + choice.logprobs.content.extend(chunk_choice.logprobs.content) + if chunk_choice.logprobs.refusal: + if choice.logprobs.refusal is None: + choice.logprobs.refusal = [] + choice.logprobs.refusal.extend(chunk_choice.logprobs.refusal) + if chunk_choice.delta.content: + if choice.message.content is None: + choice.message.content = "" + choice.message.content += chunk_choice.delta.content + if chunk_choice.delta.refusal: + if choice.message.refusal is None: + choice.message.refusal = "" + choice.message.refusal += 
chunk_choice.delta.refusal + if chunk_choice.delta.function_call: + if choice.message.function_call is None: + choice.message.function_call = FunctionCall(arguments="", name="") + choice.message.function_call.name += ( + chunk_choice.delta.function_call.name or "" + ) + choice.message.function_call.arguments += ( + chunk_choice.delta.function_call.arguments or "" + ) + if chunk_choice.delta.tool_calls: + if choice.message.tool_calls is None: + choice.message.tool_calls = [] + for tool_call in chunk_choice.delta.tool_calls: + while tool_call.index not in range(len(choice.message.tool_calls)): + choice.message.tool_calls.append( + ChatCompletionMessageToolCall( + id="", + function=Function(arguments="", name=""), + type="function", + ) + ) + if tool_call.id: + choice.message.tool_calls[tool_call.index].id = tool_call.id + if tool_call.function: + if tool_call.function.name: + choice.message.tool_calls[ + tool_call.index + ].function.name = tool_call.function.name + if tool_call.function.arguments: + choice.message.tool_calls[ + tool_call.index + ].function.arguments += tool_call.function.arguments + if getattr(chunk_choice.delta, "reasoning", None): + if not hasattr(choice.message, "reasoning"): + setattr(choice.message, "reasoning", "") + setattr( + choice.message, + "reasoning", + getattr(choice.message, "reasoning") + + getattr(chunk_choice.delta, "reasoning"), + ) + chat_completion.service_tier = chunk.service_tier + chat_completion.system_fingerprint = chunk.system_fingerprint + chat_completion.usage = chunk.usage diff --git a/src/art/trajectories.py b/src/art/trajectories.py index d7c340d6..4a2020a5 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -23,6 +23,9 @@ class History(pydantic.BaseModel): messages_and_choices: MessagesAndChoices tools: Tools | None = None + def messages(self) -> Messages: + return get_messages(self.messages_and_choices) + class Trajectory(pydantic.BaseModel): messages_and_choices: MessagesAndChoices diff --git 
a/src/art/yield_trajectory.py b/src/art/yield_trajectory.py index 26a1ebd5..5109d193 100644 --- a/src/art/yield_trajectory.py +++ b/src/art/yield_trajectory.py @@ -29,5 +29,5 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: yield_trajectory_context_var: contextvars.ContextVar[YieldTrajectoryContext] = ( - contextvars.ContextVar("trajectory", default=YieldTrajectoryContext()) + contextvars.ContextVar("yield_trajectory_context", default=YieldTrajectoryContext()) ) diff --git a/tests/unit/test_auto_trajectory.py b/tests/unit/test_auto_trajectory.py new file mode 100644 index 00000000..af170ad7 --- /dev/null +++ b/tests/unit/test_auto_trajectory.py @@ -0,0 +1,398 @@ +import warnings + +# Suppress pydantic warnings at module level +warnings.filterwarnings("ignore", category=DeprecationWarning, module="pydantic") + +import litellm +import litellm.litellm_core_utils.streaming_handler +import litellm.types.utils +import pytest +import pytest_asyncio +from aiohttp import web +from openai import AsyncOpenAI +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam +from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam + +import art +from art.utils.litellm import convert_litellm_choice_to_openai + +mock_response = { + "id": "chatcmpl-293ce9f37dba40e5be39448acaf6fb49", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "logprobs": { + "content": [ + { + "token": "token_id:9707", + "bytes": [72, 101, 108, 108, 111], + "logprob": -0.0017243054462596774, + "top_logprobs": [], + }, + { + "token": "token_id:0", + "bytes": [33], + "logprob": -0.007611795328557491, + "top_logprobs": [], + }, + { + "token": "token_id:2585", + "bytes": [32, 72, 111, 119], + "logprob": -0.03061593696475029, + "top_logprobs": [], + }, + { + "token": "token_id:646", + "bytes": [32, 99, 97, 110], + "logprob": -1.1920858014491387e-05, + "top_logprobs": [], + 
}, + { + "token": "token_id:358", + "bytes": [32, 73], + "logprob": -2.3841855067985307e-07, + "top_logprobs": [], + }, + { + "token": "token_id:7789", + "bytes": [32, 97, 115, 115, 105, 115, 116], + "logprob": -0.020548323169350624, + "top_logprobs": [], + }, + { + "token": "token_id:498", + "bytes": [32, 121, 111, 117], + "logprob": 0.0, + "top_logprobs": [], + }, + { + "token": "token_id:3351", + "bytes": [32, 116, 111, 100, 97, 121], + "logprob": -4.410734163684538e-06, + "top_logprobs": [], + }, + { + "token": "token_id:30", + "bytes": [63], + "logprob": -2.3841855067985307e-07, + "top_logprobs": [], + }, + { + "token": "token_id:151645", + "bytes": [], + "logprob": -0.0083366259932518, + "top_logprobs": [], + }, + ], + "refusal": None, + }, + "message": { + "content": "Hello! How can I assist you today?", + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": [], + "reasoning_content": None, + }, + "stop_reason": None, + } + ], + "created": 1755801745, + "model": "test", + "object": "chat.completion", + "service_tier": None, + "system_fingerprint": None, + "usage": { + "completion_tokens": 10, + "prompt_tokens": 31, + "total_tokens": 41, + "completion_tokens_details": None, + "prompt_tokens_details": None, + }, + "prompt_logprobs": None, + "kv_transfer_params": None, +} +mock_stream_response = b"""data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + +data: 
{"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"id":"chatcmpl-tool-29e663261e524fcfa2162f4f3d76a7f0","type":"function","index":0,"function":{"name":"get_current_weather","arguments":"{"}}]},"logprobs":{"content":[{"token":"token_id:314","logprob":-0.00015293381875380874,"bytes":[32,123],"top_logprobs":[]}]},"finish_reason":null}]} + +data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"id":"chatcmpl-tool-29e663261e524fcfa2162f4f3d76a7f0","type":"function","index":0,"function":{"name":"get_current_weather","arguments":"{"}}]},"logprobs":{"content":[{"token":"token_id:314","logprob":-0.00015293381875380874,"bytes":[32,123],"top_logprobs":[]}]},"finish_reason":null}]} + +data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":null,"arguments":"}"}}]},"logprobs":{"content":[{"token":"token_id:3417","logprob":-3.576278118089249e-7,"bytes":[125,125],"top_logprobs":[]}]},"finish_reason":null}]} + +data: {"id":"chatcmpl-aa0d1e3261414f53acafc2f8e53bf9d6","object":"chat.completion.chunk","created":1755831263,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":null,"arguments":"}"}}]},"logprobs":{"content":[{"token":"token_id:3417","logprob":-3.576278118089249e-7,"bytes":[125,125],"top_logprobs":[]}]},"finish_reason":null}]} + +data: [DONE] + +data: [DONE] + +""" +mock_stream_choice = Choice( + **{ + "finish_reason": "stop", + "index": 0, + "logprobs": { + "content": [ + { + "token": "token_id:314", + "bytes": [32, 123], + "logprob": -0.00015293381875380874, + "top_logprobs": [], + }, + { + "token": "token_id:314", + "bytes": [32, 123], + "logprob": 
-0.00015293381875380874, + "top_logprobs": [], + }, + { + "token": "token_id:3417", + "bytes": [125, 125], + "logprob": -3.576278118089249e-07, + "top_logprobs": [], + }, + { + "token": "token_id:3417", + "bytes": [125, 125], + "logprob": -3.576278118089249e-07, + "top_logprobs": [], + }, + ], + "refusal": None, + }, + "message": { + "content": None, + "refusal": None, + "role": "assistant", + "annotations": None, + "audio": None, + "function_call": None, + "tool_calls": [ + { + "id": "chatcmpl-tool-29e663261e524fcfa2162f4f3d76a7f0", + "function": {"arguments": "{{}}", "name": "get_current_weather"}, + "type": "function", + } + ], + }, + } +) + + +@pytest_asyncio.fixture +async def test_server(): + """Start a test server for the module.""" + + async def handler(request: web.Request) -> web.Response: + body = await request.json() + if body.get("stream", False): + return web.Response(body=mock_stream_response) + return web.json_response(mock_response) + + app = web.Application() + app.router.add_route("POST", "/v1/chat/completions", handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 8888) + await site.start() + print(f"Test server started on http://localhost:8888") + + yield # Tests run here + + print("Cleaning up test server...") + await runner.cleanup() + + +async def test_auto_trajectory(test_server: None) -> None: + message: ChatCompletionMessageParam = {"role": "user", "content": "Hi!"} + tools: list[ChatCompletionToolParam] = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + } + ] + + async def say_hi() -> str | None: + """A method that says hi to an assistant and returns the response.""" + client = AsyncOpenAI(base_url="http://localhost:8888/v1", api_key="default") + chat_completion = await client.chat.completions.create( + model="test", + 
messages=[message], + tools=tools, + ) + # test a follow up message + chat_completion = await client.chat.completions.create( + model="test", + messages=[ + message, + { + "role": "assistant", + "content": chat_completion.choices[0].message.content, + }, + message, + ], + tools=tools, + ) + # and another call without tools (should create a new history) + chat_completion = await client.chat.completions.create( + model="test", + messages=[ + message, + { + "role": "assistant", + "content": chat_completion.choices[0].message.content, + }, + message, + { + "role": "assistant", + "content": chat_completion.choices[0].message.content, + }, + message, + ], + ) + # and another call with tools, but limited messages (should create another history) + chat_completion = await client.chat.completions.create( + model="test", + messages=[message], + tools=tools, + ) + # and another call with tool_choice="required" & stream=True + async for _ in await client.chat.completions.create( + model="test", + messages=[message], + tool_choice="required", + tools=tools, + stream=True, + ): + pass + # Add ART support with a couple lines of optional code + if trajectory := art.auto_trajectory(): + trajectory.reward = 1.0 + return chat_completion.choices[0].message.content + + # Use the capture_auto_trajectory utility to capture a trajectory automatically + trajectory = await art.capture_auto_trajectory(say_hi()) + assert trajectory.messages_and_choices == [ + message, + Choice(**mock_response["choices"][0]), + message, + Choice(**mock_response["choices"][0]), + ] + assert trajectory.tools == tools + assert trajectory.additional_histories[0].messages_and_choices == [ + message, + { + "content": "Hello! How can I assist you today?", + "role": "assistant", + }, + message, + { + "content": "Hello! 
How can I assist you today?", + "role": "assistant", + }, + message, + Choice(**mock_response["choices"][0]), + ] + assert trajectory.additional_histories[0].tools is None + assert trajectory.additional_histories[1].messages_and_choices == [ + message, + Choice(**mock_response["choices"][0]), + ] + assert trajectory.additional_histories[1].tools == tools + assert trajectory.additional_histories[2].messages_and_choices == [ + message, + mock_stream_choice, + ] + assert trajectory.additional_histories[2].tools == tools + + +@pytest.mark.filterwarnings("ignore::UserWarning:pydantic") +async def test_litellm_auto_trajectory(test_server: None) -> None: + message: ChatCompletionMessageParam = {"role": "user", "content": "Hi!"} + tools: list[ChatCompletionToolParam] = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + } + ] + + async def say_hi() -> str | None: + """A method that says hi to an assistant and returns the response.""" + response = await litellm.acompletion( + model="openai/test", + messages=[message], + tools=tools, + base_url="http://localhost:8888/v1", + api_key="default", + ) + assert isinstance(response, litellm.types.utils.ModelResponse) + choice = convert_litellm_choice_to_openai(response.choices[0]) + # follow up message + response = await litellm.acompletion( + model="openai/test", + messages=[ + message, + {"role": "assistant", "content": choice.message.content}, + message, + ], + tools=tools, + base_url="http://localhost:8888/v1", + api_key="default", + ) + assert isinstance(response, litellm.types.utils.ModelResponse) + choice = convert_litellm_choice_to_openai(response.choices[0]) + # another call with tool_choice="required" & stream=True + stream = await litellm.acompletion( + model="openai/test", + messages=[message], + tool_choice="required", + tools=tools, + stream=True, + 
base_url="http://localhost:8888/v1", + api_key="default", + ) + assert isinstance( + stream, litellm.litellm_core_utils.streaming_handler.CustomStreamWrapper + ) + async for _ in stream: + pass + # Add ART support with a couple lines of optional code + if trajectory := art.auto_trajectory(): + trajectory.reward = 1.0 + return choice.message.content + + # Use the capture_auto_trajectory utility to capture a trajectory automatically + trajectory = await art.capture_auto_trajectory(say_hi()) + assert trajectory.messages_and_choices == [ + message, + Choice(**mock_response["choices"][0]), + message, + Choice(**mock_response["choices"][0]), + ] + assert trajectory.additional_histories[0].messages_and_choices == [ + message, + mock_stream_choice, + ] + assert trajectory.additional_histories[0].tools == tools diff --git a/tests/unit/test_yield_trajectory.py b/tests/unit/test_yield_trajectory.py index 95a53fe5..c730c4e7 100644 --- a/tests/unit/test_yield_trajectory.py +++ b/tests/unit/test_yield_trajectory.py @@ -128,7 +128,7 @@ async def handler(_: web.Request) -> web.Response: await runner.cleanup() -async def test_with_trajectory(test_server: None) -> None: +async def test_yield_trajectory(test_server: None) -> None: async def say_hi() -> str | None: """A method that says hi to an assistant and returns the response.""" client = AsyncOpenAI(base_url="http://localhost:8888/v1", api_key="default")