Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from litellm import ContextWindowExceededError

import dspy
from dspy.adapters.types.history import History
from dspy.adapters.types.tool import Tool
from dspy.primitives.example import Example
from dspy.primitives.module import Module
from dspy.signatures.signature import ensure_signature

Expand Down Expand Up @@ -93,6 +95,34 @@ def _format_trajectory(self, trajectory: dict[str, Any]):
trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x")
return adapter.format_user_message_content(trajectory_signature, trajectory)

def _format_history_with_trajectory(self, history: History | None) -> History | None:
if history is None or not history.messages:
return history

formatted_messages = []
changed = False

for message in history.messages:
if isinstance(message, Example):
message_data = dict(message.items())
elif isinstance(message, dict):
message_data = dict(message)
else:
formatted_messages.append(message)
continue

trajectory_value = message_data.get("trajectory")
if trajectory_value and not isinstance(trajectory_value, str):
message_data["trajectory"] = self._format_trajectory(trajectory_value)
changed = True

formatted_messages.append(message_data)

if not changed:
return history

return History(messages=formatted_messages)

def forward(self, **input_args):
trajectory = {}
max_iters = input_args.pop("max_iters", self.max_iters)
Expand Down Expand Up @@ -146,21 +176,29 @@ async def aforward(self, **input_args):
def _call_with_potential_trajectory_truncation(self, module, trajectory, **input_args):
for _ in range(3):
try:
return module(
**input_args,
trajectory=self._format_trajectory(trajectory),
)
call_kwargs = dict(input_args)
call_kwargs["trajectory"] = self._format_trajectory(trajectory)

history_value = call_kwargs.get("history")
if isinstance(history_value, History):
call_kwargs["history"] = self._format_history_with_trajectory(history_value)

return module(**call_kwargs)
except ContextWindowExceededError:
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
trajectory = self.truncate_trajectory(trajectory)

async def _async_call_with_potential_trajectory_truncation(self, module, trajectory, **input_args):
for _ in range(3):
try:
return await module.acall(
**input_args,
trajectory=self._format_trajectory(trajectory),
)
call_kwargs = dict(input_args)
call_kwargs["trajectory"] = self._format_trajectory(trajectory)

history_value = call_kwargs.get("history")
if isinstance(history_value, History):
call_kwargs["history"] = self._format_history_with_trajectory(history_value)

return await module.acall(**call_kwargs)
except ContextWindowExceededError:
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
trajectory = self.truncate_trajectory(trajectory)
Expand Down
67 changes: 67 additions & 0 deletions tests/predict/test_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,73 @@ def make_images():
assert sum(1 for part in observation_content if isinstance(part, dict) and part.get("type") == "image_url") == 2


def test_history_trajectory_uses_chat_format():
def math_tool(expression: str) -> str:
return str(eval(expression))

adapter = dspy.ChatAdapter()
lm = DummyLM(
[
{
"next_thought": "I should use the math tool.",
"next_tool_name": "math_tool",
"next_tool_args": {"expression": "2+2"},
},
{
"next_thought": "That answers the question; time to finish.",
"next_tool_name": "finish",
"next_tool_args": {},
},
{
"reasoning": "Computed 2+2 with the provided tool.",
"answer": "4",
},
{
"next_thought": "Reusing the math tool for another calculation.",
"next_tool_name": "math_tool",
"next_tool_args": {"expression": "3*4"},
},
{
"next_thought": "I have the second result and can finish now.",
"next_tool_name": "finish",
"next_tool_args": {},
},
{
"reasoning": "Computed 3*4 with the provided tool.",
"answer": "12",
},
],
adapter=adapter,
)

dspy.settings.configure(lm=lm, adapter=adapter)

class HistorySignature(dspy.Signature):
history: dspy.History = dspy.InputField()
question: str = dspy.InputField()
answer: str = dspy.OutputField()

react = dspy.ReAct(HistorySignature, tools=[math_tool])

history = dspy.History(messages=[])

q1 = "What is 2+2?"
first_outputs = react(history=history, question=q1)
history.messages.append({"question": q1, **first_outputs})

q2 = "Now compute 3*4."
react(history=history, question=q2)

messages = lm.history[-1]["messages"]
user_messages = [message for message in messages if message.get("role") == "user"]
assert len(user_messages) >= 2

first_user_content = user_messages[0]["content"]
assert "[[ ## trajectory ## ]]" in first_user_content
assert "[[ ## thought_0 ## ]]" in first_user_content
assert '"thought_0"' not in first_user_content


def test_tool_calling_with_pydantic_args():
class CalendarEvent(BaseModel):
name: str
Expand Down