33
44from openai import AsyncAzureOpenAI , AsyncOpenAI
55from openai .types .chat import ChatCompletionMessageParam
6- from pydantic_ai import Agent , RunContext
6+ from pydantic_ai import Agent
77from pydantic_ai .messages import ModelMessagesTypeAdapter
88from pydantic_ai .models .openai import OpenAIModel
99from pydantic_ai .providers .openai import OpenAIProvider
1010from pydantic_ai .settings import ModelSettings
1111
1212from fastapi_app .api_models import (
1313 AIChatRoles ,
14- BrandFilter ,
1514 ChatRequestOverrides ,
1615 Filter ,
1716 ItemPublic ,
1817 Message ,
19- PriceFilter ,
2018 RAGContext ,
2119 RetrievalResponse ,
2220 RetrievalResponseDelta ,
21+ SearchArguments ,
2322 SearchResults ,
2423 ThoughtStep ,
2524)
@@ -59,7 +58,7 @@ def __init__(
5958 ),
6059 system_prompt = self .query_prompt_template ,
6160 tools = [self .search_database ],
62- output_type = SearchResults ,
61+ output_type = SearchArguments ,
6362 )
6463 self .answer_agent = Agent (
6564 pydantic_chat_model ,
@@ -73,10 +72,7 @@ def __init__(
7372
7473 async def search_database (
7574 self ,
76- ctx : RunContext [ChatParams ],
77- search_query : str ,
78- price_filter : Optional [PriceFilter ] = None ,
79- brand_filter : Optional [BrandFilter ] = None ,
75+ search_arguments : SearchArguments ,
8076 ) -> SearchResults :
8177 """
8278 Search PostgreSQL database for relevant products based on user query
@@ -91,52 +87,55 @@ async def search_database(
9187 """
9288 # Only send non-None filters
9389 filters : list [Filter ] = []
94- if price_filter :
95- filters .append (price_filter )
96- if brand_filter :
97- filters .append (brand_filter )
90+ if search_arguments . price_filter :
91+ filters .append (search_arguments . price_filter )
92+ if search_arguments . brand_filter :
93+ filters .append (search_arguments . brand_filter )
9894 results = await self .searcher .search_and_embed (
99- search_query ,
100- top = ctx . deps .top ,
101- enable_vector_search = ctx . deps .enable_vector_search ,
102- enable_text_search = ctx . deps .enable_text_search ,
95+ search_arguments . search_query ,
96+ top = self . chat_params .top ,
97+ enable_vector_search = self . chat_params .enable_vector_search ,
98+ enable_text_search = self . chat_params .enable_text_search ,
10399 filters = filters ,
104100 )
105101 return SearchResults (
106- query = search_query , items = [ItemPublic .model_validate (item .to_dict ()) for item in results ], filters = filters
102+ query = search_arguments .search_query ,
103+ items = [ItemPublic .model_validate (item .to_dict ()) for item in results ],
104+ filters = filters ,
107105 )
108106
109107 async def prepare_context (self ) -> tuple [list [ItemPublic ], list [ThoughtStep ]]:
110108 few_shots = ModelMessagesTypeAdapter .validate_json (self .query_fewshots )
111109 user_query = f"Find search results for user query: { self .chat_params .original_user_query } "
112- results = await self .search_agent .run (
110+ search_agent_runner = await self .search_agent .run (
113111 user_query ,
114112 message_history = few_shots + self .chat_params .past_messages ,
115- deps = self . chat_params ,
113+ output_type = SearchArguments ,
116114 )
117- items = results .output .items
115+ search_arguments = search_agent_runner .output
116+ search_results = await self .search_database (search_arguments = search_arguments )
118117 thoughts = [
119118 ThoughtStep (
120119 title = "Prompt to generate search arguments" ,
121- description = results .all_messages (),
120+ description = search_agent_runner .all_messages (),
122121 props = self .model_for_thoughts ,
123122 ),
124123 ThoughtStep (
125124 title = "Search using generated search arguments" ,
126- description = results . output .query ,
125+ description = search_results .query ,
127126 props = {
128127 "top" : self .chat_params .top ,
129128 "vector_search" : self .chat_params .enable_vector_search ,
130129 "text_search" : self .chat_params .enable_text_search ,
131- "filters" : results . output .filters ,
130+ "filters" : search_results .filters ,
132131 },
133132 ),
134133 ThoughtStep (
135134 title = "Search results" ,
136- description = items ,
135+ description = search_results . items ,
137136 ),
138137 ]
139- return items , thoughts
138+ return search_results . items , thoughts
140139
141140 async def answer (
142141 self ,
0 commit comments