Skip to content

Commit 973b7aa

Browse files
committed
Add unit tests
1 parent 91ed99a commit 973b7aa

File tree

1 file changed

+72
-1
lines changed

1 file changed

+72
-1
lines changed

tests/test_chat.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import tempfile
33

44
import pytest
5-
from chatlas import ChatOpenAI, Turn
65
from pydantic import BaseModel
76

7+
from chatlas import ChatOpenAI, ToolResult, Turn
8+
89

910
def test_simple_batch_chat():
1011
chat = ChatOpenAI()
@@ -91,6 +92,76 @@ def test_basic_export(snapshot):
9192
assert snapshot == f.read()
9293

9394

95+
def test_tool_results():
96+
chat = ChatOpenAI(system_prompt="Be very terse, not even punctuation.")
97+
98+
def get_date():
99+
"""Gets the current date"""
100+
return ToolResult("2024-01-01", response_output=["Tool result..."])
101+
102+
chat.register_tool(get_date)
103+
chat.on_tool_request(lambda req: [f"Requesting tool {req.name}..."])
104+
105+
results = []
106+
for chunk in chat.stream("What's the date?"):
107+
results.append(chunk)
108+
109+
# Make sure values haven't been str()'d yet
110+
assert ["Requesting tool get_date..."] in results
111+
assert ["Tool result..."] in results
112+
113+
response_str = "".join(str(chunk) for chunk in results)
114+
115+
assert "Requesting tool get_date..." in response_str
116+
assert "Tool result..." in response_str
117+
assert "2024-01-01" in response_str
118+
119+
chat.register_tool(get_date, on_request=lambda req: f"Calling {req.name}...")
120+
121+
response = chat.chat("What's the date?")
122+
assert "Calling get_date..." in str(response)
123+
assert "Requesting tool get_date..." not in str(response)
124+
assert "Tool result..." in str(response)
125+
assert "2024-01-01" in str(response)
126+
127+
128+
@pytest.mark.asyncio
129+
async def test_tool_results_async():
130+
chat = ChatOpenAI(system_prompt="Be very terse, not even punctuation.")
131+
132+
async def get_date():
133+
"""Gets the current date"""
134+
import asyncio
135+
136+
await asyncio.sleep(0.1)
137+
return ToolResult("2024-01-01", response_output=["Tool result..."])
138+
139+
chat.register_tool(get_date)
140+
chat.on_tool_request(lambda req: [f"Requesting tool {req.name}..."])
141+
142+
results = []
143+
async for chunk in await chat.stream_async("What's the date?"):
144+
results.append(chunk)
145+
146+
# Make sure values haven't been str()'d yet
147+
assert ["Requesting tool get_date..."] in results
148+
assert ["Tool result..."] in results
149+
150+
response_str = "".join(str(chunk) for chunk in results)
151+
152+
assert "Requesting tool get_date..." in response_str
153+
assert "Tool result..." in response_str
154+
assert "2024-01-01" in response_str
155+
156+
chat.register_tool(get_date, on_request=lambda req: [f"Calling {req.name}..."])
157+
158+
response = await chat.chat_async("What's the date?")
159+
assert "Calling get_date..." in await response.get_content()
160+
assert "Requesting tool get_date..." not in await response.get_content()
161+
assert "Tool result..." in await response.get_content()
162+
assert "2024-01-01" in await response.get_content()
163+
164+
94165
def test_extract_data():
95166
chat = ChatOpenAI()
96167

0 commit comments

Comments
 (0)