11from __future__ import annotations
22
3- from collections .abc import AsyncIterator , Callable , Mapping
4- from contextlib import asynccontextmanager
3+ from collections .abc import AsyncIterator , Callable , Iterator , Mapping
4+ from contextlib import asynccontextmanager , contextmanager
55from contextvars import ContextVar
66from dataclasses import dataclass
77from datetime import datetime
@@ -33,14 +33,7 @@ class _RequestParams:
3333 model_settings : dict [str , Any ] | None
3434 model_request_parameters : ModelRequestParameters
3535 serialized_run_context : Any | None = None
36- model_key : str | None = None
37- model_name : str | None = None
38-
39-
40- @dataclass (frozen = True )
41- class ModelSelection :
42- model_key : str | None = None
43- model_name : str | None = None
36+ model_selection : str | None = None
4437
4538
4639TemporalProviderFactory = Callable [[str , RunContext [Any ] | None , Any | None ], Provider [Any ]]
@@ -85,22 +78,18 @@ def __init__(
8578 deps_type : type [AgentDepsT ],
8679 run_context_type : type [TemporalRunContext [AgentDepsT ]] = TemporalRunContext [AgentDepsT ],
8780 event_stream_handler : EventStreamHandler [Any ] | None = None ,
88- model_selection_var : ContextVar [ModelSelection ] | None = None ,
8981 model_instances : Mapping [str , Model ] | None = None ,
9082 provider_factory : TemporalProviderFactory | None = None ,
9183 ):
9284 super ().__init__ (model )
9385 self .activity_config = activity_config
9486 self .run_context_type = run_context_type
9587 self .event_stream_handler = event_stream_handler
96- self ._model_selection_var = model_selection_var
88+ self ._model_selection_var : ContextVar [ str | None ] = ContextVar ( '_temporal_model_selection' , default = None )
9789 self ._model_instances = dict (model_instances or {})
9890 self ._provider_factory = provider_factory
9991
100- request_activity_name = f'{ activity_name_prefix } __model_request'
101- request_stream_activity_name = f'{ activity_name_prefix } __model_request_stream'
102-
103- @activity .defn (name = request_activity_name )
92+ @activity .defn (name = f'{ activity_name_prefix } __model_request' )
10493 async def request_activity (params : _RequestParams , deps : Any ) -> ModelResponse :
10594 model_for_request = self ._resolve_model (params , deps )
10695 return await model_for_request .request (
@@ -112,7 +101,7 @@ async def request_activity(params: _RequestParams, deps: Any) -> ModelResponse:
112101 self .request_activity = request_activity
113102 self .request_activity .__annotations__ ['deps' ] = deps_type
114103
115- async def request_stream_activity (params : _RequestParams , deps : Any ) -> ModelResponse :
104+ async def request_stream_activity (params : _RequestParams , deps : AgentDepsT ) -> ModelResponse :
116105 # An error is raised in `request_stream` if no `event_stream_handler` is set.
117106 assert self .event_stream_handler is not None
118107
@@ -135,7 +124,9 @@ async def request_stream_activity(params: _RequestParams, deps: Any) -> ModelRes
135124 # Set type hint explicitly so that Temporal can take care of serialization and deserialization
136125 request_stream_activity .__annotations__ ['deps' ] = deps_type
137126
138- self .request_stream_activity = activity .defn (name = request_stream_activity_name )(request_stream_activity )
127+ self .request_stream_activity = activity .defn (name = f'{ activity_name_prefix } __model_request_stream' )(
128+ request_stream_activity
129+ )
139130
140131 @property
141132 def temporal_activities (self ) -> list [Callable [..., Any ]]:
@@ -168,8 +159,7 @@ async def request(
168159 model_settings = cast (dict [str , Any ] | None , model_settings ),
169160 model_request_parameters = model_request_parameters ,
170161 serialized_run_context = serialized_run_context ,
171- model_key = selection .model_key ,
172- model_name = selection .model_name ,
162+ model_selection = selection ,
173163 ),
174164 deps ,
175165 ],
@@ -212,8 +202,7 @@ async def request_stream(
212202 model_settings = cast (dict [str , Any ] | None , model_settings ),
213203 model_request_parameters = model_request_parameters ,
214204 serialized_run_context = serialized_run_context ,
215- model_key = selection .model_key ,
216- model_name = selection .model_name ,
205+ model_selection = selection ,
217206 ),
218207 run_context .deps ,
219208 ],
@@ -225,27 +214,27 @@ def _validate_model_request_parameters(self, model_request_parameters: ModelRequ
225214 if model_request_parameters .allow_image_output :
226215 raise UserError ('Image output is not supported with Temporal because of the 2MB payload size limit.' )
227216
228- def _current_selection (self ) -> ModelSelection :
229- if self ._model_selection_var is None :
230- return ModelSelection (model_key = 'default' , model_name = None )
231- selection = self ._model_selection_var .get ()
232- if selection .model_key is None and selection .model_name is None :
233- return ModelSelection (model_key = 'default' , model_name = None )
234- return selection
217+ @contextmanager
218+ def using_model (self , selection : str | None ) -> Iterator [None ]:
219+ """Context manager to set the model selection for the duration of a block."""
220+ token = self ._model_selection_var .set (selection )
221+ try :
222+ yield
223+ finally :
224+ self ._model_selection_var .reset (token )
225+
226+ def _current_selection (self ) -> str | None :
227+ return self ._model_selection_var .get ()
235228
236229 def _resolve_model (self , params : _RequestParams , deps : Any | None ) -> Model :
237- if params .model_key :
238- if params .model_key in self ._model_instances :
239- return self ._model_instances [params .model_key ]
240- raise UserError (
241- f'Model "{ params .model_key } " is not registered with this TemporalAgent. '
242- 'Register model instances using `additional_models` when constructing the agent.'
243- )
230+ selection = params .model_selection
231+ if selection is None :
232+ return self .wrapped
244233
245- if params . model_name :
246- return self ._infer_model ( params . model_name , params , deps )
234+ if selection in self . _model_instances :
235+ return self ._model_instances [ selection ]
247236
248- return self .wrapped
237+ return self ._infer_model ( selection , params , deps )
249238
250239 def _infer_model (self , model_name : str , params : _RequestParams , deps : Any | None ) -> Model :
251240 run_context : RunContext [Any ] | None = None
0 commit comments