|
2 | 2 | import tempfile |
3 | 3 |
|
4 | 4 | import pytest |
5 | | -from chatlas import ChatOpenAI, Turn |
6 | 5 | from pydantic import BaseModel |
7 | 6 |
|
| 7 | +from chatlas import ChatOpenAI, ToolResult, Turn |
| 8 | + |
8 | 9 |
|
9 | 10 | def test_simple_batch_chat(): |
10 | 11 | chat = ChatOpenAI() |
@@ -91,6 +92,76 @@ def test_basic_export(snapshot): |
91 | 92 | assert snapshot == f.read() |
92 | 93 |
|
93 | 94 |
|
| 95 | +def test_tool_results(): |
| 96 | + chat = ChatOpenAI(system_prompt="Be very terse, not even punctuation.") |
| 97 | + |
| 98 | + def get_date(): |
| 99 | + """Gets the current date""" |
| 100 | + return ToolResult("2024-01-01", response_output=["Tool result..."]) |
| 101 | + |
| 102 | + chat.register_tool(get_date) |
| 103 | + chat.on_tool_request(lambda req: [f"Requesting tool {req.name}..."]) |
| 104 | + |
| 105 | + results = [] |
| 106 | + for chunk in chat.stream("What's the date?"): |
| 107 | + results.append(chunk) |
| 108 | + |
| 109 | + # Make sure values haven't been str()'d yet |
| 110 | + assert ["Requesting tool get_date..."] in results |
| 111 | + assert ["Tool result..."] in results |
| 112 | + |
| 113 | + response_str = "".join(str(chunk) for chunk in results) |
| 114 | + |
| 115 | + assert "Requesting tool get_date..." in response_str |
| 116 | + assert "Tool result..." in response_str |
| 117 | + assert "2024-01-01" in response_str |
| 118 | + |
| 119 | + chat.register_tool(get_date, on_request=lambda req: f"Calling {req.name}...") |
| 120 | + |
| 121 | + response = chat.chat("What's the date?") |
| 122 | + assert "Calling get_date..." in str(response) |
| 123 | + assert "Requesting tool get_date..." not in str(response) |
| 124 | + assert "Tool result..." in str(response) |
| 125 | + assert "2024-01-01" in str(response) |
| 126 | + |
| 127 | + |
| 128 | +@pytest.mark.asyncio |
| 129 | +async def test_tool_results_async(): |
| 130 | + chat = ChatOpenAI(system_prompt="Be very terse, not even punctuation.") |
| 131 | + |
| 132 | + async def get_date(): |
| 133 | + """Gets the current date""" |
| 134 | + import asyncio |
| 135 | + |
| 136 | + await asyncio.sleep(0.1) |
| 137 | + return ToolResult("2024-01-01", response_output=["Tool result..."]) |
| 138 | + |
| 139 | + chat.register_tool(get_date) |
| 140 | + chat.on_tool_request(lambda req: [f"Requesting tool {req.name}..."]) |
| 141 | + |
| 142 | + results = [] |
| 143 | + async for chunk in await chat.stream_async("What's the date?"): |
| 144 | + results.append(chunk) |
| 145 | + |
| 146 | + # Make sure values haven't been str()'d yet |
| 147 | + assert ["Requesting tool get_date..."] in results |
| 148 | + assert ["Tool result..."] in results |
| 149 | + |
| 150 | + response_str = "".join(str(chunk) for chunk in results) |
| 151 | + |
| 152 | + assert "Requesting tool get_date..." in response_str |
| 153 | + assert "Tool result..." in response_str |
| 154 | + assert "2024-01-01" in response_str |
| 155 | + |
| 156 | + chat.register_tool(get_date, on_request=lambda req: [f"Calling {req.name}..."]) |
| 157 | + |
| 158 | + response = await chat.chat_async("What's the date?") |
| 159 | + assert "Calling get_date..." in await response.get_content() |
| 160 | + assert "Requesting tool get_date..." not in await response.get_content() |
| 161 | + assert "Tool result..." in await response.get_content() |
| 162 | + assert "2024-01-01" in await response.get_content() |
| 163 | + |
| 164 | + |
94 | 165 | def test_extract_data(): |
95 | 166 | chat = ChatOpenAI() |
96 | 167 |
|
|
0 commit comments