@@ -129,15 +129,15 @@ def __init__(self, messages: List[Union[BaseMessage, List[BaseMessage]]], invoca
129129 for tool_call in message .additional_kwargs .get ('tool_calls' , []):
130130 if tool_call .get ('id' , '' ):
131131 tool_call_id_name_map [tool_call .get ('id' , '' )] = tool_call .get ('function' , {}).get ('name' , '' )
132- for tool_call in message .get ( ' tool_calls' , []) :
132+ for tool_call in message .tool_calls :
133133 if tool_call .get ('id' , '' ):
134134 tool_call_id_name_map [tool_call .get ('id' , '' )] = tool_call .get ('name' , '' )
135135
136136 for message in process_messages :
137137 if isinstance (message , (AIMessageChunk , AIMessage )):
138138 tool_calls = convert_tool_calls_by_additional_kwargs (message .additional_kwargs .get ('tool_calls' , []))
139139 if len (tool_calls ) == 0 :
140- tool_calls = convert_tool_calls_by_raw (message .get ( ' tool_calls' , []) )
140+ tool_calls = convert_tool_calls_by_raw (message .tool_calls )
141141 self ._messages .append (Message (role = message .type , content = message .content , tool_calls = tool_calls ))
142142 elif isinstance (message , ToolMessage ):
143143 name = ''
@@ -200,14 +200,14 @@ def to_json(self):
200200def convert_tool_calls_by_raw (tool_calls : list ) -> List [ToolCall ]:
201201 format_tool_calls : List [ToolCall ] = []
202202 for tool_call in tool_calls :
203- function = ToolFunction (name = tool_call .get ('name' , '' ), arguments = json . loads ( tool_call .get ('args' , {}) ))
203+ function = ToolFunction (name = tool_call .get ('name' , '' ), arguments = tool_call .get ('args' , {}))
204204 format_tool_calls .append (ToolCall (id = tool_call .get ('id' , '' ), type = tool_call .get ('type' , '' ), function = function ))
205205 return format_tool_calls
206206
207207
208208def convert_tool_calls_by_additional_kwargs (tool_calls : list ) -> List [ToolCall ]:
209209 format_tool_calls : List [ToolCall ] = []
210210 for tool_call in tool_calls :
211- function = ToolFunction (name = tool_call .get ('function' , {}).get ('name' , '' ), arguments = json .loads (tool_call .get ('function' , {}).get ('arguments' , {} )))
211+ function = ToolFunction (name = tool_call .get ('function' , {}).get ('name' , '' ), arguments = json .loads (tool_call .get ('function' , {}).get ('arguments' , '{}' )))
212212 format_tool_calls .append (ToolCall (id = tool_call .get ('id' , '' ), type = tool_call .get ('type' , '' ), function = function ))
213213 return format_tool_calls
0 commit comments