Skip to content

Commit 52b3830

Browse files
committed
Add tool timeout support
1 parent b01098c commit 52b3830

File tree

9 files changed

+282
-2
lines changed

9 files changed

+282
-2
lines changed

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dataclasses import dataclass, field, replace
88
from typing import Any, Generic
99

10+
import anyio
1011
from opentelemetry.trace import Tracer
1112
from pydantic import ValidationError
1213
from typing_extensions import assert_never
@@ -35,6 +36,8 @@ class ToolManager(Generic[AgentDepsT]):
3536
"""The cached tools for this run step."""
3637
failed_tools: set[str] = field(default_factory=set)
3738
"""Names of tools that failed in this run step."""
39+
default_timeout: float | None = None
40+
"""Default timeout in seconds for tool execution. None means no timeout."""
3841

3942
@classmethod
4043
@contextmanager
@@ -62,6 +65,7 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe
6265
toolset=self.toolset,
6366
ctx=ctx,
6467
tools=await self.toolset.get_tools(ctx),
68+
default_timeout=self.default_timeout,
6569
)
6670

6771
@property
@@ -172,7 +176,23 @@ async def _call_tool(
172176
call.args or {}, allow_partial=pyd_allow_partial, context=ctx.validation_context
173177
)
174178

175-
result = await self.toolset.call_tool(name, args_dict, ctx, tool)
179+
# Determine effective timeout: per-tool timeout takes precedence over default
180+
effective_timeout = tool.timeout if tool.timeout is not None else self.default_timeout
181+
182+
if effective_timeout is not None:
183+
try:
184+
with anyio.fail_after(effective_timeout):
185+
result = await self.toolset.call_tool(name, args_dict, ctx, tool)
186+
except TimeoutError:
187+
m = _messages.RetryPromptPart(
188+
tool_name=name,
189+
content=f"Tool '{name}' timed out after {effective_timeout} seconds. Please try a different approach.",
190+
tool_call_id=call.tool_call_id,
191+
)
192+
self.failed_tools.add(name)
193+
raise ToolRetryError(m) from None
194+
else:
195+
result = await self.toolset.call_tool(name, args_dict, ctx, tool)
176196

177197
return result
178198
except (ValidationError, ModelRetry) as e:

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
147147
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
148148
_max_result_retries: int = dataclasses.field(repr=False)
149149
_max_tool_retries: int = dataclasses.field(repr=False)
150+
_tool_timeout: float | None = dataclasses.field(repr=False)
150151
_validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)
151152

152153
_event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
@@ -179,6 +180,7 @@ def __init__(
179180
instrument: InstrumentationSettings | bool | None = None,
180181
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
181182
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
183+
tool_timeout: float | None = None,
182184
) -> None: ...
183185

184186
@overload
@@ -206,6 +208,7 @@ def __init__(
206208
instrument: InstrumentationSettings | bool | None = None,
207209
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
208210
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
211+
tool_timeout: float | None = None,
209212
) -> None: ...
210213

211214
def __init__(
@@ -231,6 +234,7 @@ def __init__(
231234
instrument: InstrumentationSettings | bool | None = None,
232235
history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
233236
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
237+
tool_timeout: float | None = None,
234238
**_deprecated_kwargs: Any,
235239
):
236240
"""Create an agent.
@@ -285,6 +289,9 @@ def __init__(
285289
Each processor takes a list of messages and returns a modified list of messages.
286290
Processors can be sync or async and are applied in sequence.
287291
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools.
292+
tool_timeout: Default timeout in seconds for tool execution. If a tool takes longer than this,
293+
a retry prompt is returned to the model. Individual tools can override this with their own timeout.
294+
Defaults to None (no timeout).
288295
"""
289296
if model is None or defer_model_check:
290297
self._model = model
@@ -318,6 +325,7 @@ def __init__(
318325

319326
self._max_result_retries = output_retries if output_retries is not None else retries
320327
self._max_tool_retries = retries
328+
self._tool_timeout = tool_timeout
321329

322330
self._validation_context = validation_context
323331

@@ -569,7 +577,7 @@ async def main():
569577
output_toolset.max_retries = self._max_result_retries
570578
output_toolset.output_validators = output_validators
571579
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
572-
tool_manager = ToolManager[AgentDepsT](toolset)
580+
tool_manager = ToolManager[AgentDepsT](toolset, default_timeout=self._tool_timeout)
573581

574582
# Build the graph
575583
graph = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
@@ -1031,6 +1039,7 @@ def tool(
10311039
sequential: bool = False,
10321040
requires_approval: bool = False,
10331041
metadata: dict[str, Any] | None = None,
1042+
timeout: float | None = None,
10341043
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
10351044

10361045
def tool(
@@ -1049,6 +1058,7 @@ def tool(
10491058
sequential: bool = False,
10501059
requires_approval: bool = False,
10511060
metadata: dict[str, Any] | None = None,
1061+
timeout: float | None = None,
10521062
) -> Any:
10531063
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
10541064
@@ -1098,6 +1108,8 @@ async def spam(ctx: RunContext[str], y: float) -> float:
10981108
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
10991109
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
11001110
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
1111+
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
1112+
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
11011113
"""
11021114

11031115
def tool_decorator(
@@ -1118,6 +1130,7 @@ def tool_decorator(
11181130
sequential=sequential,
11191131
requires_approval=requires_approval,
11201132
metadata=metadata,
1133+
timeout=timeout,
11211134
)
11221135
return func_
11231136

@@ -1142,6 +1155,7 @@ def tool_plain(
11421155
sequential: bool = False,
11431156
requires_approval: bool = False,
11441157
metadata: dict[str, Any] | None = None,
1158+
timeout: float | None = None,
11451159
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
11461160

11471161
def tool_plain(
@@ -1160,6 +1174,7 @@ def tool_plain(
11601174
sequential: bool = False,
11611175
requires_approval: bool = False,
11621176
metadata: dict[str, Any] | None = None,
1177+
timeout: float | None = None,
11631178
) -> Any:
11641179
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
11651180
@@ -1209,6 +1224,8 @@ async def spam(ctx: RunContext[str]) -> float:
12091224
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
12101225
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
12111226
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
1227+
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
1228+
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
12121229
"""
12131230

12141231
def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
@@ -1227,6 +1244,7 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams
12271244
sequential=sequential,
12281245
requires_approval=requires_approval,
12291246
metadata=metadata,
1247+
timeout=timeout,
12301248
)
12311249
return func_
12321250

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ class Tool(Generic[ToolAgentDepsT]):
262262
sequential: bool
263263
requires_approval: bool
264264
metadata: dict[str, Any] | None
265+
timeout: float | None
265266
function_schema: _function_schema.FunctionSchema
266267
"""
267268
The base JSON schema for the tool's parameters.
@@ -285,6 +286,7 @@ def __init__(
285286
sequential: bool = False,
286287
requires_approval: bool = False,
287288
metadata: dict[str, Any] | None = None,
289+
timeout: float | None = None,
288290
function_schema: _function_schema.FunctionSchema | None = None,
289291
):
290292
"""Create a new tool instance.
@@ -341,6 +343,8 @@ async def prep_my_tool(
341343
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
342344
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
343345
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
346+
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
347+
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
344348
function_schema: The function schema to use for the tool. If not provided, it will be generated.
345349
"""
346350
self.function = function
@@ -362,6 +366,7 @@ async def prep_my_tool(
362366
self.sequential = sequential
363367
self.requires_approval = requires_approval
364368
self.metadata = metadata
369+
self.timeout = timeout
365370

366371
@classmethod
367372
def from_schema(
@@ -417,6 +422,7 @@ def tool_def(self):
417422
strict=self.strict,
418423
sequential=self.sequential,
419424
metadata=self.metadata,
425+
timeout=self.timeout,
420426
kind='unapproved' if self.requires_approval else 'function',
421427
)
422428

@@ -503,6 +509,13 @@ class ToolDefinition:
503509
For MCP tools, this contains the `meta`, `annotations`, and `output_schema` fields from the tool definition.
504510
"""
505511

512+
timeout: float | None = None
513+
"""Timeout in seconds for tool execution.
514+
515+
If the tool takes longer than this, a retry prompt is returned to the model.
516+
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
517+
"""
518+
506519
@property
507520
def defer(self) -> bool:
508521
"""Whether calls to this tool will be deferred.

pydantic_ai_slim/pydantic_ai/toolsets/abstract.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ class ToolsetTool(Generic[AgentDepsT]):
5757
5858
For example, a [`pydantic.TypeAdapter(...).validator`](https://docs.pydantic.dev/latest/concepts/type_adapter/) or [`pydantic_core.SchemaValidator`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.SchemaValidator).
5959
"""
60+
timeout: float | None = None
61+
"""Timeout in seconds for tool execution.
62+
63+
If the tool takes longer than this, a retry prompt is returned to the model.
64+
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
65+
"""
6066

6167

6268
class AbstractToolset(ABC, Generic[AgentDepsT]):

pydantic_ai_slim/pydantic_ai/toolsets/combined.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
8080
args_validator=tool.args_validator,
8181
source_toolset=toolset,
8282
source_tool=tool,
83+
timeout=tool.timeout,
8384
)
8485
return all_tools
8586

pydantic_ai_slim/pydantic_ai/toolsets/function.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def tool(
119119
sequential: bool | None = None,
120120
requires_approval: bool | None = None,
121121
metadata: dict[str, Any] | None = None,
122+
timeout: float | None = None,
122123
) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ...
123124

124125
def tool(
@@ -137,6 +138,7 @@ def tool(
137138
sequential: bool | None = None,
138139
requires_approval: bool | None = None,
139140
metadata: dict[str, Any] | None = None,
141+
timeout: float | None = None,
140142
) -> Any:
141143
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
142144
@@ -193,6 +195,8 @@ async def spam(ctx: RunContext[str], y: float) -> float:
193195
If `None`, the default value is determined by the toolset.
194196
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
195197
If `None`, the default value is determined by the toolset. If provided, it will be merged with the toolset's metadata.
198+
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
199+
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
196200
"""
197201

198202
def tool_decorator(
@@ -213,6 +217,7 @@ def tool_decorator(
213217
sequential=sequential,
214218
requires_approval=requires_approval,
215219
metadata=metadata,
220+
timeout=timeout,
216221
)
217222
return func_
218223

@@ -233,6 +238,7 @@ def add_function(
233238
sequential: bool | None = None,
234239
requires_approval: bool | None = None,
235240
metadata: dict[str, Any] | None = None,
241+
timeout: float | None = None,
236242
) -> None:
237243
"""Add a function as a tool to the toolset.
238244
@@ -267,6 +273,8 @@ def add_function(
267273
If `None`, the default value is determined by the toolset.
268274
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
269275
If `None`, the default value is determined by the toolset. If provided, it will be merged with the toolset's metadata.
276+
timeout: Timeout in seconds for tool execution. If the tool takes longer, a retry prompt is returned to the model.
277+
Defaults to None (no timeout). Overrides the agent-level `tool_timeout` if set.
270278
"""
271279
if docstring_format is None:
272280
docstring_format = self.docstring_format
@@ -295,6 +303,7 @@ def add_function(
295303
sequential=sequential,
296304
requires_approval=requires_approval,
297305
metadata=metadata,
306+
timeout=timeout,
298307
)
299308
self.add_tool(tool)
300309

@@ -340,6 +349,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
340349
args_validator=tool.function_schema.validator,
341350
call_func=tool.function_schema.call,
342351
is_async=tool.function_schema.is_async,
352+
timeout=tool_def.timeout,
343353
)
344354
return tools
345355

tests/models/test_model_request_parameters.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def test_model_request_parameters_are_serializable():
6868
'sequential': False,
6969
'kind': 'function',
7070
'metadata': None,
71+
'timeout': None,
7172
}
7273
],
7374
'builtin_tools': [
@@ -131,6 +132,7 @@ def test_model_request_parameters_are_serializable():
131132
'sequential': False,
132133
'kind': 'function',
133134
'metadata': None,
135+
'timeout': None,
134136
}
135137
],
136138
'prompted_output_template': None,

tests/test_logfire.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ async def my_ret(x: int) -> str:
547547
'sequential': False,
548548
'kind': 'function',
549549
'metadata': None,
550+
'timeout': None,
550551
}
551552
],
552553
'builtin_tools': [],
@@ -994,6 +995,7 @@ class MyOutput:
994995
'sequential': False,
995996
'kind': 'output',
996997
'metadata': None,
998+
'timeout': None,
997999
}
9981000
],
9991001
'prompted_output_template': None,

0 commit comments

Comments
 (0)