4 changes: 2 additions & 2 deletions litellm/llms/cohere/common_utils.py
@@ -1,5 +1,5 @@
 import json
-from typing import List, Optional, Literal
+from typing import List, Optional, Literal, Tuple
 
 from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
@@ -281,7 +281,7 @@ def _parse_citation_start(self, chunk: dict) -> Optional[dict]:
return {"citations": [citation_data]}
return None

def _parse_message_end(self, chunk: dict) -> tuple[bool, str, Optional[ChatCompletionUsageBlock]]:
def _parse_message_end(self, chunk: dict) -> Tuple[bool, str, Optional[ChatCompletionUsageBlock]]:
"""Parse message-end events to extract finish info and usage."""
data = chunk.get("data", {})
delta = data.get("delta", {})
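Note: the tuple → Tuple change above is a Python-version fix, not a behavior change. Subscripting the built-in tuple in an annotation (tuple[bool, str]) only works on Python 3.9+, while typing.Tuple evaluates fine on older interpreters. A minimal sketch of the failure mode, assuming the library still supports Python 3.8:

    # On Python 3.8, this definition raises "TypeError: 'type' object is
    # not subscriptable" when the annotation is evaluated at def time:
    #     def f(chunk: dict) -> tuple[bool, str]: ...
    # The typing-module generic works on 3.8 and later:
    from typing import Tuple

    def f(chunk: dict) -> Tuple[bool, str]:
        return bool(chunk), chunk.get("type", "")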
@@ -462,7 +462,7 @@ def _create_vertex_response_logging_payload_for_generate_content(
         return kwargs
 
     @staticmethod
-    def batch_prediction_jobs_handler(
+    def batch_prediction_jobs_handler( # noqa: PLR0915
         httpx_response: httpx.Response,
         logging_obj: LiteLLMLoggingObj,
         url_route: str,
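Note: PLR0915 is the "too many statements" rule from Ruff's Pylint-compatible rule set; the inline comment added above silences it for this one function instead of raising the statement limit project-wide. A minimal sketch of the mechanism (the function name here is illustrative):

    # An inline noqa suppresses the named rule only on this line,
    # leaving PLR0915 enforced everywhere else in the codebase.
    def batch_handler(response): # noqa: PLR0915
        ...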
58 changes: 32 additions & 26 deletions tests/llm_translation/test_cohere.py
@@ -507,13 +507,13 @@ async def test_cohere_v2_chat_completion(sync_mode):

     if sync_mode:
         response = completion(
-            model="cohere_chat/v2/command-r-plus",
+            model="cohere_chat/v2/command-a-03-2025",
             messages=messages,
             max_tokens=50
         )
     else:
         response = await litellm.acompletion(
-            model="cohere_chat/v2/command-r-plus",
+            model="cohere_chat/v2/command-a-03-2025",
             messages=messages,
             max_tokens=50
         )
@@ -544,7 +544,7 @@ async def test_cohere_v2_streaming(stream):
     ]
 
     response = await litellm.acompletion(
-        model="cohere_chat/v2/command-r-plus",
+        model="cohere_chat/v2/command-a-03-2025",
         messages=messages,
         max_tokens=100,
         stream=stream
@@ -605,7 +605,7 @@ def test_cohere_v2_tool_calling():
     ]
 
     response = completion(
-        model="cohere_chat/v2/command-r-plus",
+        model="cohere_chat/v2/command-a-03-2025",
         messages=messages,
         tools=tools,
         tool_choice="auto",
@@ -648,17 +648,21 @@ async def test_cohere_v2_citations(stream):

     documents = [
         {
-            "title": "Renewable Energy Benefits",
-            "text": "Renewable energy sources like solar and wind power reduce greenhouse gas emissions and provide sustainable energy solutions."
+            "data": {
+                "title": "Test Document 1",
+                "snippet": "This is test content 1"
+            }
         },
         {
-            "title": "Environmental Impact",
-            "text": "Solar panels and wind turbines have minimal environmental impact compared to fossil fuel power plants."
+            "data": {
+                "title": "Test Document 2",
+                "snippet": "This is test content 2"
+            }
         }
     ]
 
     response = await litellm.acompletion(
-        model="cohere_chat/v2/command-r-plus",
+        model="cohere_chat/v2/command-a-03-2025",
         messages=messages,
         documents=documents,
         max_tokens=100,
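Note: the rewritten fixture above uses Cohere's v2 document shape, where the document fields are nested under a "data" key, rather than the v1-style top-level "title"/"text" pair. A minimal sketch of the payload, assuming free-form field names inside "data":

    # Cohere v2-style documents: content nested under "data".
    documents = [
        {"data": {"title": "Test Document 1", "snippet": "This is test content 1"}},
        {"data": {"title": "Test Document 2", "snippet": "This is test content 2"}},
    ]
    # litellm forwards them as a provider-specific parameter:
    # completion(model="cohere_chat/v2/command-a-03-2025",
    #            messages=messages, documents=documents)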
@@ -703,7 +707,7 @@ def test_cohere_v2_parameter_mapping():

     # Test various parameters that should be mapped correctly
     response = completion(
-        model="cohere_chat/v2/command-r-plus",
+        model="cohere_chat/v2/command-a-03-2025",
         messages=messages,
         temperature=0.7,
         max_tokens=50,
@@ -737,7 +741,7 @@ async def test_cohere_v2_multiple_generations():
     ]
 
     response = await litellm.acompletion(
-        model="cohere_chat/v2/command-r-plus",
+        model="cohere_chat/v2/command-a-03-2025",
         messages=messages,
         n=3, # Request 3 generations
         max_tokens=30
@@ -776,7 +780,7 @@ def test_cohere_v2_error_handling():
     # Test with empty messages
     try:
         response = completion(
-            model="cohere_chat/v2/command-r-plus",
+            model="cohere_chat/v2/command-a-03-2025",
             messages=[], # Empty messages
             max_tokens=10
         )
@@ -790,9 +794,9 @@ def test_cohere_v2_error_handling():


 @pytest.mark.asyncio
-async def test_cohere_documents_citation_options_in_request_body():
+async def test_cohere_documents_options_in_request_body():
     """
-    Test that documents and citation_options parameters are properly included
+    Test that the documents parameter is properly included
     in the request body after transformation (sent via extra_body).
     """
     # Create a mock response
@@ -809,19 +813,23 @@ async def test_cohere_documents_citation_options_in_request_body():
     try:
         # Test documents and citation_options parameters
         test_documents = [
-            {"title": "Test Document 1", "text": "This is test content 1"},
-            {"title": "Test Document 2", "text": "This is test content 2"}
+            {
+                "data": {
+                    "title": "Test Document 1",
+                    "snippet": "This is test content 1"
+                }
+            },
+            {
+                "data": {
+                    "title": "Test Document 2",
+                    "snippet": "This is test content 2"
+                }
+            }
         ]
-        test_citation_options = {
-            "return_citations": True,
-            "return_confidence": True
-        }
 
         await litellm.acompletion(
-            model="cohere_chat/command-r",
+            model="cohere_chat/command-a-03-2025",
             messages=[{"role": "user", "content": "Test message"}],
             documents=test_documents,
-            citation_options=test_citation_options
         )
     except Exception:
         pass # We only care about the request body validation
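Note: the collapsed setup in this test stubs out the HTTP layer so the call never reaches Cohere, then inspects the JSON body the transformation produced. A rough sketch of that pattern — the patch target and the "data" keyword are assumptions here, not the test's exact wiring:

    # Rough sketch (assumed mock wiring, not the test's exact code):
    import json
    from unittest.mock import AsyncMock, patch

    with patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
        new_callable=AsyncMock,
    ) as mock_post:
        ...  # run the acompletion call shown above
        request_data = json.loads(mock_post.call_args.kwargs["data"])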
@@ -836,8 +844,6 @@ async def test_cohere_documents_citation_options_in_request_body():
     # Validate that documents and citation_options are in the request body
     assert "documents" in request_data
     assert request_data["documents"] == test_documents
-    assert "citation_options" in request_data
-    assert request_data["citation_options"] == test_citation_options
 
 
 @pytest.mark.asyncio
@@ -853,7 +859,7 @@ async def test_cohere_v2_conversation_history():
     ]
 
     response = await litellm.acompletion(
-        model="cohere_chat/v2/command-r-plus",
+        model="cohere_chat/v2/command-a-03-2025",
         messages=messages,
         max_tokens=50
     )