-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Allow multiple models in temporal agent #3537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@DouweM let me know if this is more along the lines of what you are thinking. I did test this locally as well in our repo and it worked as expected. |
e089f1e to
ef40b85
Compare
|
@mattbrandman Thanks for working on this Matt! A few high level thoughts:
|
f6a2ec9 to
9445096
Compare
|
@DouweM changed the PR to be more inline with your comments above |
|
Confirmed this works locally. One thing that does appear to need updating but I'm not entirely sure where is that telemetry is printing the default model registered |
d1d85b9 to
4c7e487
Compare
| *, | ||
| name: str | None = None, | ||
| additional_models: Mapping[str, Model | models.KnownModelName | str] | ||
| | Sequence[models.KnownModelName | str] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Model strings that will just be passed to infer_model don't really need to be pre-registered, do they? So we can probably drop this type from the union
| - When providing a sequence, only provider/model strings are allowed and they will be referenced by their literal string when calling `run(model=...)`. | ||
| Model instances must be registered via the mapping form so they can be referenced by name. | ||
| provider_factory: | ||
| Optional callable used when instantiating models from provider strings (both pre-registered and those supplied at runtime). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this also used on the model name the agent is initially created with? I would kind of expect it to
|
|
||
| self._registered_model_instances: dict[str, Model] = {'default': wrapped_model} | ||
| self._registered_model_names: dict[str, str] = {} | ||
| self._default_selection = ModelSelection(model_key='default', model_name=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if we need a whole ModelSelection object. I'd just store a single string, check if it's a key in the model instance map, and if not, call infer_model on it. But there may be another purpose to this that I'm missing
| self._register_additional_model(key, value) | ||
| else: | ||
| for value in additional_models: | ||
| if not isinstance(value, str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type already makes this impossible right? We typically don't check things the type checker would've caught already
| deps_type=self.deps_type, | ||
| run_context_type=self.run_context_type, | ||
| event_stream_handler=self.event_stream_handler, | ||
| model_selection_var=self._model_selection_var, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of having a context var that's owned at this level and then passed into the TemporalModel, could TemporalModel own the context var and have a contextmanager method that we could use like with self._temporal_model.using_model('...'):?
| key = self._normalize_model_name(name) | ||
| self._registered_model_instances[key] = model | ||
|
|
||
| def _register_string_model(self, name: str, model_identifier: str) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably not necessary; see above
| self._provider_factory = provider_factory | ||
|
|
||
| request_activity_name = f'{activity_name_prefix}__model_request' | ||
| request_stream_activity_name = f'{activity_name_prefix}__model_request_stream' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we don't edit this anymore, I'd rather move them back down where they were
| self.request_activity.__annotations__['deps'] = deps_type | ||
|
|
||
| async def request_stream_activity(params: _RequestParams, deps: AgentDepsT) -> ModelResponse: | ||
| async def request_stream_activity(params: _RequestParams, deps: Any) -> ModelResponse: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unnecessary change?
25c4692 to
bcd5b79
Compare
|
@DouweM made updates based on feedback. |
501af27 to
26269ac
Compare
This pull request introduces enhanced support for model selection and management within Temporal agents, as well as improved handling of run context propagation. The main changes allow registering multiple models with a Temporal agent, selecting models by name or provider string at runtime (inside workflows), and ensuring the current run context is properly tracked across async boundaries. These improvements make it easier to use and configure multiple models in Temporal workflows, while maintaining safety and clarity in model selection.
Model selection and registration for Temporal agents:
Added support for registering multiple models with a Temporal agent via the new
additional_modelsargument, and for selecting a model by name or provider string at runtime within workflows. This includes validation to prevent duplicate or invalid model names and ensures that only registered models or provider strings can be selected during workflow execution. (pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py, [1] [2] [3] [4] [5] [6] [7] [8]Introduced the
TemporalProviderFactorytype and support for passing a provider factory to Temporal agents and models, enabling custom provider instantiation logic (e.g., injecting API keys from dependencies) when resolving models from provider strings. (pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py, [1] [2];pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py, [3] [4]Model selection logic in Temporal model activities:
pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py, pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.pyR81-R112)Run context propagation improvements:
CURRENT_RUN_CONTEXTcontext variable to track the current run context across asynchronous boundaries, and updated agent graph methods to set and reset this variable during model requests and streaming. This ensures that context-dependent logic (such as provider factories) has access to the correct run context throughout execution. (pydantic_ai_slim/pydantic_ai/_run_context.py, [1] [2];pydantic_ai_slim/pydantic_ai/_agent_graph.py, [3] [4] [5] [6]Other improvements and minor changes:
pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py, [1];pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py, [2] [3]These changes collectively make Temporal agents more flexible, robust, and easier to configure for advanced use cases involving multiple models and dynamic provider selection.