Skip to content

Commit 25c4692

Browse files
committed
fixes types
1 parent 5d35547 commit 25c4692

File tree

3 files changed

+93
-118
lines changed

3 files changed

+93
-118
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
ToolFuncEither,
4040
)
4141

42-
from ._model import ModelSelection, TemporalModel, TemporalProviderFactory
42+
from ._model import TemporalModel, TemporalProviderFactory
4343
from ._run_context import TemporalRunContext
4444
from ._toolset import TemporalWrapperToolset, temporalize_toolset
4545

@@ -124,11 +124,6 @@ def __init__(
124124
)
125125

126126
self._registered_model_instances: dict[str, Model] = {'default': wrapped_model}
127-
self._default_selection = ModelSelection(model_key='default', model_name=None)
128-
self._model_selection_var: ContextVar[ModelSelection] = ContextVar(
129-
'_temporal_model_selection',
130-
default=self._default_selection,
131-
)
132127

133128
# start_to_close_timeout is required
134129
activity_config = activity_config or ActivityConfig(start_to_close_timeout=timedelta(seconds=60))
@@ -182,7 +177,6 @@ async def streamed_response():
182177
deps_type=self.deps_type,
183178
run_context_type=self.run_context_type,
184179
event_stream_handler=self.event_stream_handler,
185-
model_selection_var=self._model_selection_var,
186180
model_instances=self._registered_model_instances,
187181
provider_factory=provider_factory,
188182
)
@@ -229,24 +223,22 @@ def _normalize_model_name(self, name: str) -> str:
229223
raise UserError(f'Duplicate model name {normalized!r} provided to `additional_models`.')
230224
return normalized
231225

232-
def _select_model(self, model: models.Model | models.KnownModelName | str | None = None) -> ModelSelection:
233-
"""Select the appropriate model based on the runtime parameter."""
234-
if model is None:
235-
return self._default_selection
226+
def _select_model(self, model: models.Model | models.KnownModelName | str | None = None) -> str | None:
227+
"""Select the appropriate model based on the runtime parameter.
228+
229+
Returns a string that will be checked against registered model instances,
230+
or passed to infer_model if not found. Returns None to use the default model.
231+
"""
232+
if model is None or model == 'default':
233+
return None
236234

237235
if isinstance(model, Model):
238236
raise UserError(
239237
'Model instances cannot be selected at runtime inside a Temporal workflow. '
240238
'Register the model via the mapping form of `additional_models` and reference it by name.'
241239
)
242240

243-
if model == 'default':
244-
return self._default_selection
245-
246-
if model in self._registered_model_instances:
247-
return ModelSelection(model_key=model, model_name=None)
248-
249-
return ModelSelection(model_key=None, model_name=model)
241+
return model
250242

251243
@property
252244
def name(self) -> str | None:
@@ -299,11 +291,13 @@ def temporal_activities(self) -> list[Callable[..., Any]]:
299291
return self._temporal_activities
300292

301293
@contextmanager
302-
def _temporal_overrides(self, selection: ModelSelection | None = None) -> Iterator[None]:
294+
def _temporal_overrides(self, selection: str | None = None) -> Iterator[None]:
303295
# We reset tools here as the temporalized function toolset is already in self._toolsets.
304-
with super().override(model=self._temporal_model, toolsets=self._toolsets, tools=[]):
296+
with (
297+
super().override(model=self._temporal_model, toolsets=self._toolsets, tools=[]),
298+
self._temporal_model.using_model(selection),
299+
):
305300
token = self._temporal_overrides_active.set(True)
306-
selection_token = self._model_selection_var.set(selection or self._default_selection)
307301
try:
308302
yield
309303
except PydanticSerializationError as e:
@@ -312,7 +306,6 @@ def _temporal_overrides(self, selection: ModelSelection | None = None) -> Iterat
312306
) from e
313307
finally:
314308
self._temporal_overrides_active.reset(token)
315-
self._model_selection_var.reset(selection_token)
316309

317310
@overload
318311
async def run(
@@ -416,12 +409,12 @@ async def main():
416409
'Event stream handler cannot be set at agent run time inside a Temporal workflow, it must be set at agent creation time.'
417410
)
418411

419-
selection = self._default_selection
412+
selection: str | None = None
420413
if workflow.in_workflow() and model is not None:
421414
selection = self._select_model(model)
422415
model = None
423416

424-
with self._temporal_overrides(selection if workflow.in_workflow() else self._default_selection):
417+
with self._temporal_overrides(selection if workflow.in_workflow() else None):
425418
return await super().run(
426419
user_prompt,
427420
output_type=output_type,

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from collections.abc import AsyncIterator, Callable, Mapping
4-
from contextlib import asynccontextmanager
3+
from collections.abc import AsyncIterator, Callable, Iterator, Mapping
4+
from contextlib import asynccontextmanager, contextmanager
55
from contextvars import ContextVar
66
from dataclasses import dataclass
77
from datetime import datetime
@@ -33,14 +33,7 @@ class _RequestParams:
3333
model_settings: dict[str, Any] | None
3434
model_request_parameters: ModelRequestParameters
3535
serialized_run_context: Any | None = None
36-
model_key: str | None = None
37-
model_name: str | None = None
38-
39-
40-
@dataclass(frozen=True)
41-
class ModelSelection:
42-
model_key: str | None = None
43-
model_name: str | None = None
36+
model_selection: str | None = None
4437

4538

4639
TemporalProviderFactory = Callable[[str, RunContext[Any] | None, Any | None], Provider[Any]]
@@ -85,22 +78,18 @@ def __init__(
8578
deps_type: type[AgentDepsT],
8679
run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
8780
event_stream_handler: EventStreamHandler[Any] | None = None,
88-
model_selection_var: ContextVar[ModelSelection] | None = None,
8981
model_instances: Mapping[str, Model] | None = None,
9082
provider_factory: TemporalProviderFactory | None = None,
9183
):
9284
super().__init__(model)
9385
self.activity_config = activity_config
9486
self.run_context_type = run_context_type
9587
self.event_stream_handler = event_stream_handler
96-
self._model_selection_var = model_selection_var
88+
self._model_selection_var: ContextVar[str | None] = ContextVar('_temporal_model_selection', default=None)
9789
self._model_instances = dict(model_instances or {})
9890
self._provider_factory = provider_factory
9991

100-
request_activity_name = f'{activity_name_prefix}__model_request'
101-
request_stream_activity_name = f'{activity_name_prefix}__model_request_stream'
102-
103-
@activity.defn(name=request_activity_name)
92+
@activity.defn(name=f'{activity_name_prefix}__model_request')
10493
async def request_activity(params: _RequestParams, deps: Any) -> ModelResponse:
10594
model_for_request = self._resolve_model(params, deps)
10695
return await model_for_request.request(
@@ -112,7 +101,7 @@ async def request_activity(params: _RequestParams, deps: Any) -> ModelResponse:
112101
self.request_activity = request_activity
113102
self.request_activity.__annotations__['deps'] = deps_type
114103

115-
async def request_stream_activity(params: _RequestParams, deps: Any) -> ModelResponse:
104+
async def request_stream_activity(params: _RequestParams, deps: AgentDepsT) -> ModelResponse:
116105
# An error is raised in `request_stream` if no `event_stream_handler` is set.
117106
assert self.event_stream_handler is not None
118107

@@ -135,7 +124,9 @@ async def request_stream_activity(params: _RequestParams, deps: Any) -> ModelRes
135124
# Set type hint explicitly so that Temporal can take care of serialization and deserialization
136125
request_stream_activity.__annotations__['deps'] = deps_type
137126

138-
self.request_stream_activity = activity.defn(name=request_stream_activity_name)(request_stream_activity)
127+
self.request_stream_activity = activity.defn(name=f'{activity_name_prefix}__model_request_stream')(
128+
request_stream_activity
129+
)
139130

140131
@property
141132
def temporal_activities(self) -> list[Callable[..., Any]]:
@@ -168,8 +159,7 @@ async def request(
168159
model_settings=cast(dict[str, Any] | None, model_settings),
169160
model_request_parameters=model_request_parameters,
170161
serialized_run_context=serialized_run_context,
171-
model_key=selection.model_key,
172-
model_name=selection.model_name,
162+
model_selection=selection,
173163
),
174164
deps,
175165
],
@@ -212,8 +202,7 @@ async def request_stream(
212202
model_settings=cast(dict[str, Any] | None, model_settings),
213203
model_request_parameters=model_request_parameters,
214204
serialized_run_context=serialized_run_context,
215-
model_key=selection.model_key,
216-
model_name=selection.model_name,
205+
model_selection=selection,
217206
),
218207
run_context.deps,
219208
],
@@ -225,27 +214,27 @@ def _validate_model_request_parameters(self, model_request_parameters: ModelRequ
225214
if model_request_parameters.allow_image_output:
226215
raise UserError('Image output is not supported with Temporal because of the 2MB payload size limit.')
227216

228-
def _current_selection(self) -> ModelSelection:
229-
if self._model_selection_var is None:
230-
return ModelSelection(model_key='default', model_name=None)
231-
selection = self._model_selection_var.get()
232-
if selection.model_key is None and selection.model_name is None:
233-
return ModelSelection(model_key='default', model_name=None)
234-
return selection
217+
@contextmanager
218+
def using_model(self, selection: str | None) -> Iterator[None]:
219+
"""Context manager to set the model selection for the duration of a block."""
220+
token = self._model_selection_var.set(selection)
221+
try:
222+
yield
223+
finally:
224+
self._model_selection_var.reset(token)
225+
226+
def _current_selection(self) -> str | None:
227+
return self._model_selection_var.get()
235228

236229
def _resolve_model(self, params: _RequestParams, deps: Any | None) -> Model:
237-
if params.model_key:
238-
if params.model_key in self._model_instances:
239-
return self._model_instances[params.model_key]
240-
raise UserError(
241-
f'Model "{params.model_key}" is not registered with this TemporalAgent. '
242-
'Register model instances using `additional_models` when constructing the agent.'
243-
)
230+
selection = params.model_selection
231+
if selection is None:
232+
return self.wrapped
244233

245-
if params.model_name:
246-
return self._infer_model(params.model_name, params, deps)
234+
if selection in self._model_instances:
235+
return self._model_instances[selection]
247236

248-
return self.wrapped
237+
return self._infer_model(selection, params, deps)
249238

250239
def _infer_model(self, model_name: str, params: _RequestParams, deps: Any | None) -> Model:
251240
run_context: RunContext[Any] | None = None

0 commit comments

Comments
 (0)