Skip to content

Commit bae1ba4

Browse files
committed
update
1 parent 53fcaa3 commit bae1ba4

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

cozeloop/integration/langchain/trace_model/llm_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,15 @@ def __init__(self, messages: List[Union[BaseMessage, List[BaseMessage]]], invoca
129129
for tool_call in message.additional_kwargs.get('tool_calls', []):
130130
if tool_call.get('id', ''):
131131
tool_call_id_name_map[tool_call.get('id', '')] = tool_call.get('function', {}).get('name', '')
132-
for tool_call in message.get('tool_calls', []):
132+
for tool_call in message.tool_calls:
133133
if tool_call.get('id', ''):
134134
tool_call_id_name_map[tool_call.get('id', '')] = tool_call.get('name', '')
135135

136136
for message in process_messages:
137137
if isinstance(message, (AIMessageChunk, AIMessage)):
138138
tool_calls = convert_tool_calls_by_additional_kwargs(message.additional_kwargs.get('tool_calls', []))
139139
if len(tool_calls) == 0:
140-
tool_calls = convert_tool_calls_by_raw(message.get('tool_calls', []))
140+
tool_calls = convert_tool_calls_by_raw(message.tool_calls)
141141
self._messages.append(Message(role=message.type, content=message.content, tool_calls=tool_calls))
142142
elif isinstance(message, ToolMessage):
143143
name = ''
@@ -200,14 +200,14 @@ def to_json(self):
200200
def convert_tool_calls_by_raw(tool_calls: list) -> List[ToolCall]:
201201
format_tool_calls: List[ToolCall] = []
202202
for tool_call in tool_calls:
203-
function = ToolFunction(name=tool_call.get('name', ''), arguments=json.loads(tool_call.get('args', {})))
203+
function = ToolFunction(name=tool_call.get('name', ''), arguments=tool_call.get('args', {}))
204204
format_tool_calls.append(ToolCall(id=tool_call.get('id', ''), type=tool_call.get('type', ''), function=function))
205205
return format_tool_calls
206206

207207

208208
def convert_tool_calls_by_additional_kwargs(tool_calls: list) -> List[ToolCall]:
209209
format_tool_calls: List[ToolCall] = []
210210
for tool_call in tool_calls:
211-
function = ToolFunction(name=tool_call.get('function', {}).get('name', ''), arguments=json.loads(tool_call.get('function', {}).get('arguments', {})))
211+
function = ToolFunction(name=tool_call.get('function', {}).get('name', ''), arguments=json.loads(tool_call.get('function', {}).get('arguments', '{}')))
212212
format_tool_calls.append(ToolCall(id=tool_call.get('id', ''), type=tool_call.get('type', ''), function=function))
213213
return format_tool_calls

0 commit comments

Comments
 (0)