56 changes: 53 additions & 3 deletions libs/core/langchain_core/messages/utils.py
@@ -1391,6 +1391,54 @@ def convert_to_openai_messages(
return oai_messages


def _remove_orphaned_tool_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
"""Remove ToolMessages that don't have a corresponding AIMessage with tool_calls.

When trimming messages, we may accidentally orphan ToolMessages by removing
the AIMessage that made the tool call. This function cleans up such orphans
to maintain valid message history.

Args:
messages: List of messages to clean.

Returns:
List of messages with orphaned ToolMessages removed.
"""
if not messages:
return messages

# Build a set of valid tool_call_ids from AIMessages
valid_tool_call_ids: set[str] = set()
for msg in messages:
if isinstance(msg, AIMessage):
# Check tool_calls attribute
if msg.tool_calls:
for tool_call in msg.tool_calls:
if tool_call_id := tool_call.get("id"):
valid_tool_call_ids.add(tool_call_id)
# Also check content blocks for Anthropic format (tool_use blocks)
if isinstance(msg.content, list):
for block in msg.content:
if (
isinstance(block, dict)
and block.get("type") == "tool_use"
and block.get("id")
):
valid_tool_call_ids.add(block["id"])

# Filter out ToolMessages with invalid tool_call_ids
cleaned_messages: list[BaseMessage] = []
for msg in messages:
if isinstance(msg, ToolMessage):
if msg.tool_call_id in valid_tool_call_ids:
cleaned_messages.append(msg)
# else: skip orphaned ToolMessage
else:
cleaned_messages.append(msg)

return cleaned_messages


def _first_max_tokens(
messages: Sequence[BaseMessage],
*,
@@ -1413,7 +1461,7 @@ def _first_max_tokens(
messages.pop()
else:
break
return messages
return _remove_orphaned_tool_messages(messages)

# Use binary search to find the maximum number of messages within token limit
left, right = 0, len(messages)
@@ -1504,7 +1552,9 @@ def _first_max_tokens(
else:
break

return messages[:idx]
trimmed = messages[:idx]
# Remove any orphaned ToolMessages that lost their corresponding AIMessage
return _remove_orphaned_tool_messages(trimmed)


def _last_max_tokens(
@@ -1559,7 +1609,7 @@ def _last_max_tokens(
if system_message:
result = [system_message, *result]

return result
return _remove_orphaned_tool_messages(result)


_MSG_CHUNK_MAP: dict[type[BaseMessage], type[BaseMessageChunk]] = {
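For reference, a minimal usage sketch of the behavior this change introduces (not part of the diff; it assumes the helper is wired into _first_max_tokens and _last_max_tokens as shown above, and the message contents simply mirror the unit tests below):

from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    ToolMessage,
    trim_messages,
)

messages = [
    HumanMessage("What's the weather in Florida?"),
    AIMessage(
        "Checking the weather...",
        tool_calls=[
            {
                "name": "get_weather",
                "args": {"location": "Florida"},
                "id": "abc123",
                "type": "tool_call",
            },
        ],
    ),
    ToolMessage("It's sunny.", name="get_weather", tool_call_id="abc123"),
    HumanMessage("I see"),
    AIMessage("Do you want to know anything else?"),
]

# Keep only the last 3 messages by count. The AIMessage that issued tool call
# "abc123" falls outside that window, so the matching ToolMessage would be
# orphaned; with this change it is dropped as well.
trimmed = trim_messages(messages, strategy="last", token_counter=len, max_tokens=3)
assert not any(isinstance(m, ToolMessage) for m in trimmed)
assert len(trimmed) == 2  # HumanMessage("I see") and the follow-up AIMessage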
267 changes: 267 additions & 0 deletions libs/core/tests/unit_tests/messages/test_utils.py
@@ -1582,3 +1582,270 @@ def test_convert_to_openai_messages_reasoning_content() -> None:
],
}
assert mixed_result == expected_mixed


def test_trim_messages_removes_orphaned_tool_messages() -> None:
"""Test that trim_messages removes orphaned ToolMessages.

When the corresponding AIMessage with tool_calls is trimmed away, the
ToolMessage becomes orphaned and should be removed.

This is the exact scenario from the bug report:
https://github.com/langchain-ai/langchain/issues/xxxxx
"""
messages = [
HumanMessage("What's the weather in Florida?"),
AIMessage(
[
{"type": "text", "text": "Let's check the weather in Florida"},
{
"type": "tool_use",
"id": "abc123",
"name": "get_weather",
"input": {"location": "Florida"},
},
],
tool_calls=[
{
"name": "get_weather",
"args": {"location": "Florida"},
"id": "abc123",
"type": "tool_call",
},
],
),
ToolMessage(
"It's sunny.",
name="get_weather",
tool_call_id="abc123",
),
HumanMessage("I see"),
AIMessage("Do you want to know anything else?"),
HumanMessage("No, thanks"),
AIMessage("You're welcome! Have a great day!"),
]

# Trim to last 5 messages (by count)
trimmed_messages = trim_messages(
messages,
strategy="last",
token_counter=len,
max_tokens=5,
)

# Should NOT include the orphaned ToolMessage
    # The AIMessage with tool_calls was removed, so the ToolMessage is dropped too
assert len(trimmed_messages) == 4
assert isinstance(trimmed_messages[0], HumanMessage)
assert trimmed_messages[0].content == "I see"
assert isinstance(trimmed_messages[1], AIMessage)
assert trimmed_messages[1].content == "Do you want to know anything else?"
assert isinstance(trimmed_messages[2], HumanMessage)
assert trimmed_messages[2].content == "No, thanks"
assert isinstance(trimmed_messages[3], AIMessage)
assert trimmed_messages[3].content == "You're welcome! Have a great day!"

# Verify no ToolMessages in result
assert not any(isinstance(msg, ToolMessage) for msg in trimmed_messages)


def test_trim_messages_preserves_valid_tool_calls() -> None:
"""Test that valid tool call sequences are preserved when they fit in the budget."""
messages = [
HumanMessage("What's 2+2?"),
AIMessage(
"Let me calculate that",
tool_calls=[
{
"name": "calculator",
"args": {"expression": "2+2"},
"id": "calc1",
"type": "tool_call",
},
],
),
ToolMessage("4", name="calculator", tool_call_id="calc1"),
AIMessage("The answer is 4"),
]

# Trim to include all messages
trimmed_messages = trim_messages(
messages,
strategy="last",
token_counter=len,
max_tokens=10,
)

# All messages should be preserved, including valid tool call sequence
assert len(trimmed_messages) == 4
assert isinstance(trimmed_messages[0], HumanMessage)
assert isinstance(trimmed_messages[1], AIMessage)
assert trimmed_messages[1].tool_calls[0]["id"] == "calc1"
assert isinstance(trimmed_messages[2], ToolMessage)
assert trimmed_messages[2].tool_call_id == "calc1"
assert isinstance(trimmed_messages[3], AIMessage)


def test_trim_messages_multiple_tool_calls() -> None:
"""Test handling of multiple tool calls in sequence."""
messages = [
HumanMessage("Get me weather and news"),
AIMessage(
"Fetching both...",
tool_calls=[
{
"name": "get_weather",
"args": {"location": "NYC"},
"id": "tool1",
"type": "tool_call",
},
{
"name": "get_news",
"args": {"topic": "tech"},
"id": "tool2",
"type": "tool_call",
},
],
),
ToolMessage("Sunny", name="get_weather", tool_call_id="tool1"),
ToolMessage("AI news update", name="get_news", tool_call_id="tool2"),
HumanMessage("Thanks"),
AIMessage("You're welcome!"),
]

# Trim to last 2 messages (should remove all tool-related messages)
trimmed_messages = trim_messages(
messages,
strategy="last",
token_counter=len,
max_tokens=2,
)

assert len(trimmed_messages) == 2
assert isinstance(trimmed_messages[0], HumanMessage)
assert trimmed_messages[0].content == "Thanks"
assert isinstance(trimmed_messages[1], AIMessage)
assert trimmed_messages[1].content == "You're welcome!"
# No ToolMessages should remain
assert not any(isinstance(msg, ToolMessage) for msg in trimmed_messages)


def test_trim_messages_partial_tool_orphaning() -> None:
"""Test when some tool calls are preserved but others are orphaned."""
messages = [
HumanMessage("First question"),
AIMessage(
"Let me check",
tool_calls=[
{
"name": "tool1",
"args": {},
"id": "old_tool",
"type": "tool_call",
},
],
),
ToolMessage("Result 1", name="tool1", tool_call_id="old_tool"),
HumanMessage("Second question"),
AIMessage(
"Checking again",
tool_calls=[
{
"name": "tool2",
"args": {},
"id": "new_tool",
"type": "tool_call",
},
],
),
ToolMessage("Result 2", name="tool2", tool_call_id="new_tool"),
AIMessage("Done"),
]

# Trim to last 4 messages - should keep only the second tool call sequence
trimmed_messages = trim_messages(
messages,
strategy="last",
token_counter=len,
max_tokens=4,
)

assert len(trimmed_messages) == 4
# Should have the second AIMessage with tool_calls
ai_msgs_with_tools = [
msg for msg in trimmed_messages if isinstance(msg, AIMessage) and msg.tool_calls
]
assert len(ai_msgs_with_tools) == 1
assert ai_msgs_with_tools[0].tool_calls[0]["id"] == "new_tool"

# Should have matching ToolMessage
tool_msgs = [msg for msg in trimmed_messages if isinstance(msg, ToolMessage)]
assert len(tool_msgs) == 1
assert tool_msgs[0].tool_call_id == "new_tool"


def test_trim_messages_strategy_first_removes_orphans() -> None:
"""Test that strategy='first' also removes orphaned ToolMessages."""
messages = [
HumanMessage("Start"),
AIMessage("Response"),
HumanMessage("Question"),
AIMessage(
"Tool time",
tool_calls=[
{
"name": "tool",
"args": {},
"id": "tool123",
"type": "tool_call",
},
],
),
ToolMessage("Result", name="tool", tool_call_id="tool123"),
AIMessage("Final"),
]

# Trim to first 4 messages - cuts off in the middle of tool call sequence
trimmed_messages = trim_messages(
messages,
strategy="first",
token_counter=len,
max_tokens=4,
)

    # The AIMessage with tool_calls is included, but the ToolMessage is cut off
# However, since the ToolMessage isn't in the trimmed result, no orphaning occurs
assert len(trimmed_messages) == 4
assert isinstance(trimmed_messages[3], AIMessage)

# Now test with an orphaned ToolMessage at the beginning
    # Trim to the first 3 messages, which include an orphan
messages_with_orphan = [
HumanMessage("A"),
ToolMessage("Orphaned", name="tool", tool_call_id="missing_id"),
HumanMessage("B"),
AIMessage(
"Using tool",
tool_calls=[
{
"name": "valid_tool",
"args": {},
"id": "valid123",
"type": "tool_call",
},
],
),
ToolMessage("Valid result", name="valid_tool", tool_call_id="valid123"),
]

trimmed = trim_messages(
messages_with_orphan,
strategy="first",
token_counter=len,
max_tokens=3,
)

# Should remove the orphaned ToolMessage but keep valid messages
assert len(trimmed) == 2 # HumanMessage("A") and HumanMessage("B")
tool_messages = [msg for msg in trimmed if isinstance(msg, ToolMessage)]
assert len(tool_messages) == 0