Skip to content

Commit e642d7d

Browse files
committed
enhance input tool_calls obtain
1 parent 0d10489 commit e642d7d

File tree

1 file changed

+33
-6
lines changed

1 file changed

+33
-6
lines changed

cozeloop/integration/langchain/trace_model/llm_model.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
from typing import List, Optional, Union, Dict, Any
77
from pydantic.dataclasses import dataclass
8-
from langchain_core.messages import BaseMessage, ToolMessage, AIMessageChunk
8+
from langchain_core.messages import BaseMessage, ToolMessage, AIMessageChunk, AIMessage
99
from langchain_core.outputs import Generation, ChatGeneration
1010

1111

@@ -122,11 +122,30 @@ def __init__(self, messages: List[Union[BaseMessage, List[BaseMessage]]], invoca
122122
elif isinstance(inner_messages, List):
123123
for message in inner_messages:
124124
process_messages.append(message)
125+
126+
tool_call_id_name_map = {}
127+
for message in process_messages:
128+
if isinstance(message, (AIMessageChunk, AIMessage)):
129+
for tool_call in message.additional_kwargs.get('tool_calls', []):
130+
if tool_call.get('id', ''):
131+
tool_call_id_name_map[tool_call.get('id', '')] = tool_call.get('function', {}).get('name', '')
132+
for tool_call in message.get('tool_calls', []):
133+
if tool_call.get('id', ''):
134+
tool_call_id_name_map[tool_call.get('id', '')] = tool_call.get('name', '')
135+
125136
for message in process_messages:
126-
if isinstance(message, AIMessageChunk):
127-
self._messages.append(Message(role=message.type, content=message.content, tool_calls=convert_tool_calls(message.additional_kwargs.get('tool_calls', []))))
137+
if isinstance(message, (AIMessageChunk, AIMessage)):
138+
tool_calls = convert_tool_calls_by_additional_kwargs(message.additional_kwargs.get('tool_calls', []))
139+
if len(tool_calls) == 0:
140+
tool_calls = convert_tool_calls_by_raw(message.get('tool_calls', []))
141+
self._messages.append(Message(role=message.type, content=message.content, tool_calls=tool_calls))
128142
elif isinstance(message, ToolMessage):
129-
tool_call = ToolCall(id=message.tool_call_id, type=message.type, function= ToolFunction(name=message.additional_kwargs.get('name', '')))
143+
name = ''
144+
if tool_call_id_name_map.get(message.tool_call_id, None) is not None:
145+
name = tool_call_id_name_map[message.tool_call_id]
146+
if message.additional_kwargs.get('name', ''):
147+
name = message.additional_kwargs.get('name', '')
148+
tool_call = ToolCall(id=message.tool_call_id, type=message.type, function=ToolFunction(name=name))
130149
self._messages.append(Message(role=message.type, content=message.content, tool_calls=[tool_call]))
131150
else:
132151
self._messages.append(Message(role=message.type, content=message.content))
@@ -161,7 +180,7 @@ def to_json(self):
161180
for i, generation in enumerate(self.generations):
162181
choice: Choice = None
163182
if isinstance(generation, ChatGeneration):
164-
tool_calls = convert_tool_calls(generation.message.additional_kwargs.get('tool_calls', []))
183+
tool_calls = convert_tool_calls_by_additional_kwargs(generation.message.additional_kwargs.get('tool_calls', []))
165184
if len(tool_calls) == 0 and 'function_call' in generation.message.additional_kwargs:
166185
function_call = generation.message.additional_kwargs.get('function_call', {})
167186
function = ToolFunction(name=function_call.get('name', ''), arguments=json.loads(function_call.get('arguments', {})))
@@ -178,7 +197,15 @@ def to_json(self):
178197
ensure_ascii=False)
179198

180199

181-
def convert_tool_calls(tool_calls: list) -> List[ToolCall]:
200+
def convert_tool_calls_by_raw(tool_calls: list) -> List[ToolCall]:
201+
format_tool_calls: List[ToolCall] = []
202+
for tool_call in tool_calls:
203+
function = ToolFunction(name=tool_call.get('name', ''), arguments=json.loads(tool_call.get('args', {})))
204+
format_tool_calls.append(ToolCall(id=tool_call.get('id', ''), type=tool_call.get('type', ''), function=function))
205+
return format_tool_calls
206+
207+
208+
def convert_tool_calls_by_additional_kwargs(tool_calls: list) -> List[ToolCall]:
182209
format_tool_calls: List[ToolCall] = []
183210
for tool_call in tool_calls:
184211
function = ToolFunction(name=tool_call.get('function', {}).get('name', ''), arguments=json.loads(tool_call.get('function', {}).get('arguments', {})))

0 commit comments

Comments
 (0)