diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index ec6579c58..be5337f0d 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -194,7 +194,7 @@ def format_request_messages( formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) formatted_messages.extend(cls._format_regular_messages(messages)) - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + return [message for message in formatted_messages if "content" in message or "tool_calls" in message] @override def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index ab421e6c7..2b217ad91 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -204,10 +204,18 @@ def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> ], ) + formatted_contents = [cls.format_request_message_content(content) for content in contents] + + # If single text content, use string format for better model compatibility + if len(formatted_contents) == 1 and formatted_contents[0].get("type") == "text": + content: str | list[dict[str, Any]] = formatted_contents[0]["text"] + else: + content = formatted_contents + return { "role": "tool", "tool_call_id": tool_result["toolUseId"], - "content": [cls.format_request_message_content(content) for content in contents], + "content": content, } @classmethod @@ -369,18 +377,21 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic formatted_message = { "role": message["role"], - "content": formatted_contents, + **({"content": formatted_contents} if formatted_contents else {}), **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), } formatted_messages.append(formatted_message) # Process tool messages to extract images into separate user messages # OpenAI API requires images to be in user role messages only + # All tool messages must be grouped together before any user messages with images + user_messages_with_images = [] for tool_msg in formatted_tool_messages: tool_msg_clean, user_msg_with_images = cls._split_tool_message_images(tool_msg) formatted_messages.append(tool_msg_clean) if user_msg_with_images: - formatted_messages.append(user_msg_with_images) + user_messages_with_images.append(user_msg_with_images) + formatted_messages.extend(user_messages_with_images) return formatted_messages @@ -407,7 +418,7 @@ def format_request_messages( formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) formatted_messages.extend(cls._format_regular_messages(messages)) - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + return [message for message in formatted_messages if "content" in message or "tool_calls" in message] def format_request( self, diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index f5e1837bf..9bb0e09ca 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -812,3 +812,39 @@ def __init__(self, usage): assert metadata_events[0]["metadata"]["usage"]["inputTokens"] == 10 assert metadata_events[0]["metadata"]["usage"]["outputTokens"] == 5 assert metadata_events[0]["metadata"]["usage"]["totalTokens"] == 15 + + +def test_format_request_messages_with_tool_calls_no_content(): + """Test that assistant messages with only tool calls are included and have no content field.""" + messages = [ + {"role": "user", "content": [{"text": "Use the calculator"}]}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + }, + ] + + tru_result = LiteLLMModel.format_request_messages(messages) + + exp_result = [ + {"role": "user", "content": [{"text": "Use the calculator", "type": "text"}]}, + { + "role": "assistant", + "tool_calls": [ + { + "function": {"arguments": '{"expression": "2+2"}', "name": "calculator"}, + "id": "c1", + "type": "function", + } + ], + }, + ] + assert tru_result == exp_result diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 4f8652632..241c22b64 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -180,6 +180,23 @@ def test_format_request_tool_message(): assert tru_result == exp_result +def test_format_request_tool_message_single_text_returns_string(): + """Test that single text content is returned as string for model compatibility.""" + tool_result = { + "content": [{"text": '{"result": "success"}'}], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + exp_result = { + "content": '{"result": "success"}', + "role": "tool", + "tool_call_id": "c1", + } + assert tru_result == exp_result + + def test_split_tool_message_images_with_image(): """Test that images are extracted from tool messages.""" tool_message = { @@ -441,7 +458,7 @@ def test_format_request_messages(system_prompt): ], }, { - "content": [{"text": "4", "type": "text"}], + "content": "4", "role": "tool", "tool_call_id": "c1", }, @@ -1397,3 +1414,122 @@ def test_format_request_filters_location_source_document(model, caplog): assert len(formatted_content) == 1 assert formatted_content[0]["type"] == "text" assert "Location sources are not supported by OpenAI" in caplog.text + + +def test_format_request_messages_with_tool_calls_no_content(): + """Test that assistant messages with only tool calls are included and have no content field.""" + messages = [ + {"role": "user", "content": [{"text": "Use the calculator"}]}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + }, + ] + + tru_result = OpenAIModel.format_request_messages(messages) + + exp_result = [ + {"role": "user", "content": [{"text": "Use the calculator", "type": "text"}]}, + { + "role": "assistant", + "tool_calls": [ + { + "function": {"arguments": '{"expression": "2+2"}', "name": "calculator"}, + "id": "c1", + "type": "function", + } + ], + }, + ] + assert tru_result == exp_result + + +def test_format_request_messages_multiple_tool_calls_with_images(): + """Test that multiple tool calls with image results are formatted correctly. + + OpenAI requires all tool response messages to immediately follow the assistant + message with tool_calls, before any other messages. When tools return images, + the images are moved to user messages, but these must come after ALL tool messages. + """ + messages = [ + {"role": "user", "content": [{"text": "Run the tools"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"input": {}, "name": "tool1", "toolUseId": "call_1"}}, + {"toolUse": {"input": {}, "name": "tool2", "toolUseId": "call_2"}}, + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "call_1", + "content": [{"image": {"format": "png", "source": {"bytes": b"img1"}}}], + "status": "success", + } + }, + { + "toolResult": { + "toolUseId": "call_2", + "content": [{"image": {"format": "png", "source": {"bytes": b"img2"}}}], + "status": "success", + } + }, + ], + }, + ] + + tru_result = OpenAIModel.format_request_messages(messages) + + image_placeholder = ( + "Tool successfully returned an image. The image is being provided in the following user message." + ) + exp_result = [ + {"role": "user", "content": [{"text": "Run the tools", "type": "text"}]}, + { + "role": "assistant", + "tool_calls": [ + {"function": {"arguments": "{}", "name": "tool1"}, "id": "call_1", "type": "function"}, + {"function": {"arguments": "{}", "name": "tool2"}, "id": "call_2", "type": "function"}, + ], + }, + { + "role": "tool", + "tool_call_id": "call_1", + "content": [{"type": "text", "text": image_placeholder}], + }, + { + "role": "tool", + "tool_call_id": "call_2", + "content": [{"type": "text", "text": image_placeholder}], + }, + { + "role": "user", + "content": [ + { + "image_url": {"detail": "auto", "format": "image/png", "url": ""}, + "type": "image_url", + } + ], + }, + { + "role": "user", + "content": [ + { + "image_url": {"detail": "auto", "format": "image/png", "url": ""}, + "type": "image_url", + } + ], + }, + ] + assert tru_result == exp_result diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index e9738d3d9..a244bf753 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -113,6 +113,7 @@ def capture_first_node(self, event): return VerifyHook() +@pytest.mark.timeout(120) def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent, hook_provider): """Test swarm execution with string input.""" # Create the swarm