
Commit f6a2ec9

allow arbitrary models
1 parent ef40b85 commit f6a2ec9

5 files changed: 391 additions, 345 deletions

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 30 additions & 21 deletions
@@ -26,6 +26,7 @@
 from pydantic_graph.nodes import End, NodeRunEndT

 from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
+from ._run_context import CURRENT_RUN_CONTEXT
 from .exceptions import ToolRetryError
 from .output import OutputDataT, OutputSpec
 from .settings import ModelSettings

@@ -438,25 +439,29 @@ async def stream(
         assert not self._did_stream, 'stream() should only be called once per node'

         model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx)
-        async with ctx.deps.model.request_stream(
-            message_history, model_settings, model_request_parameters, run_context
-        ) as streamed_response:
-            self._did_stream = True
-            ctx.state.usage.requests += 1
-            agent_stream = result.AgentStream[DepsT, T](
-                _raw_stream_response=streamed_response,
-                _output_schema=ctx.deps.output_schema,
-                _model_request_parameters=model_request_parameters,
-                _output_validators=ctx.deps.output_validators,
-                _run_ctx=build_run_context(ctx),
-                _usage_limits=ctx.deps.usage_limits,
-                _tool_manager=ctx.deps.tool_manager,
-            )
-            yield agent_stream
-            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
-            # otherwise usage won't be properly counted:
-            async for _ in agent_stream:
-                pass
+        token = CURRENT_RUN_CONTEXT.set(run_context)
+        try:
+            async with ctx.deps.model.request_stream(
+                message_history, model_settings, model_request_parameters, run_context
+            ) as streamed_response:
+                self._did_stream = True
+                ctx.state.usage.requests += 1
+                agent_stream = result.AgentStream[DepsT, T](
+                    _raw_stream_response=streamed_response,
+                    _output_schema=ctx.deps.output_schema,
+                    _model_request_parameters=model_request_parameters,
+                    _output_validators=ctx.deps.output_validators,
+                    _run_ctx=build_run_context(ctx),
+                    _usage_limits=ctx.deps.usage_limits,
+                    _tool_manager=ctx.deps.tool_manager,
+                )
+                yield agent_stream
+                # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+                # otherwise usage won't be properly counted:
+                async for _ in agent_stream:
+                    pass
+        finally:
+            CURRENT_RUN_CONTEXT.reset(token)

         model_response = streamed_response.get()

@@ -469,8 +474,12 @@ async def _make_request(
         if self._result is not None:
             return self._result  # pragma: no cover

-        model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
-        model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
+        model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx)
+        token = CURRENT_RUN_CONTEXT.set(run_context)
+        try:
+            model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
+        finally:
+            CURRENT_RUN_CONTEXT.reset(token)
         ctx.state.usage.requests += 1

         return self._finish_handling(ctx, model_response)
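
The second and third hunks apply the same pattern: the run context is published to a context variable for the duration of the model request, and the token returned by set() is used in a finally block to restore whatever value was there before, even if the request raises. Below is a minimal standalone sketch of that set/reset pattern; the names CURRENT_VALUE, with_value and do_work are illustrative only and are not part of pydantic_ai.

from contextvars import ContextVar

# Illustrative stand-ins for CURRENT_RUN_CONTEXT and the work done during a model request.
CURRENT_VALUE: ContextVar[str | None] = ContextVar('current_value', default=None)


def do_work() -> str:
    # Code running while the variable is set can read it without it being passed in explicitly.
    return f'current value: {CURRENT_VALUE.get()}'


def with_value(value: str) -> str:
    token = CURRENT_VALUE.set(value)  # the token remembers the previous value
    try:
        return do_work()
    finally:
        CURRENT_VALUE.reset(token)  # restore the previous value even if do_work() raises


assert with_value('abc') == 'current value: abc'
assert CURRENT_VALUE.get() is None  # reset() restored the default after the call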

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 9 additions & 1 deletion
@@ -2,8 +2,9 @@

 import dataclasses
 from collections.abc import Sequence
+from contextvars import ContextVar
 from dataclasses import field
-from typing import TYPE_CHECKING, Generic
+from typing import TYPE_CHECKING, Any, Generic

 from opentelemetry.trace import NoOpTracer, Tracer
 from typing_extensions import TypeVar

@@ -69,3 +70,10 @@ def last_attempt(self) -> bool:
         return self.retry == self.max_retries

     __repr__ = _utils.dataclasses_no_defaults_repr
+
+
+CURRENT_RUN_CONTEXT: ContextVar[RunContext[Any] | None] = ContextVar(
+    'pydantic_ai.current_run_context',
+    default=None,
+)
+"""Context variable storing the current [`RunContext`][pydantic_ai._run_context.RunContext]."""
