Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def _event_hook(request: httpx.Request) -> None:
)
request.url = URL(url)
request.headers["host"] = request.url.host
headers["host"] = request.url.host

if endpoint == "rerank":
body["api_version"] = get_api_version(version=api_version)
Expand Down
29 changes: 24 additions & 5 deletions src/cohere/manually_maintained/cohere_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,29 @@ class Client:
def __init__(
    self,
    aws_region: typing.Optional[str] = None,
    mode: Mode = Mode.SAGEMAKER,
):
    """
    Create the boto3 clients that match the requested mode.

    By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
    `aws configure set region us-west-2` or override it with the `aws_region` parameter.

    Args:
        aws_region: Region forwarded as `region_name` to every boto3 client. When it is given and
            `AWS_DEFAULT_REGION` is not already set, it is also exported to that environment variable.
        mode: `Mode.SAGEMAKER` (default) builds the "sagemaker-runtime"/"sagemaker" clients and a
            SageMaker session; `Mode.BEDROCK` builds the "bedrock-runtime"/"bedrock" clients and no session.
    """
    # Honour the caller's choice; the clients below are selected from it.
    self.mode = mode

    # os.environ only accepts str values, so never export a missing region.
    if aws_region is not None and os.environ.get('AWS_DEFAULT_REGION') is None:
        os.environ['AWS_DEFAULT_REGION'] = aws_region

    if self.mode == Mode.SAGEMAKER:
        self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
        self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
        self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
    elif self.mode == Mode.BEDROCK:
        self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region)
        self._service_client = lazy_boto3().client("bedrock", region_name=aws_region)
        # No SageMaker session is created in Bedrock mode.
        self._sess = None

    # No endpoint is connected until connect_to_endpoint() assigns one.
    self._endpoint_name = None

def _require_sagemaker(self) -> None:
    """Guard for SageMaker-only methods.

    Raises:
        CohereError: If the client's mode is not `Mode.SAGEMAKER`.
    """
    wrong_mode = self.mode != Mode.SAGEMAKER
    if wrong_mode:
        raise CohereError("This method is only supported in SageMaker mode.")

def _does_endpoint_exist(self, endpoint_name: str) -> bool:
try:
Expand All @@ -50,6 +60,7 @@ def connect_to_endpoint(self, endpoint_name: str) -> None:
Raises:
CohereError: Connection to the endpoint failed.
"""
self._require_sagemaker()
if not self._does_endpoint_exist(endpoint_name):
raise CohereError(f"Endpoint {endpoint_name} does not exist.")
self._endpoint_name = endpoint_name
Expand Down Expand Up @@ -137,6 +148,7 @@ def create_endpoint(
will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
"""
self._require_sagemaker()
# First, check if endpoint already exists
if self._does_endpoint_exist(endpoint_name):
if recreate:
Expand Down Expand Up @@ -550,11 +562,15 @@ def embed(
variant: Optional[str] = None,
input_type: Optional[str] = None,
model_id: Optional[str] = None,
output_dimension: Optional[int] = None,
embedding_types: Optional[List[str]] = None,
) -> Embeddings:
json_params = {
'texts': texts,
'truncate': truncate,
"input_type": input_type
"input_type": input_type,
"output_dimension": output_dimension,
"embedding_types": embedding_types,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Embed response parsing breaks with embedding_types parameter

Medium Severity

Adding embedding_types to the request parameters changes the Cohere API response format, but response parsing isn't updated. When embedding_types is specified, the API returns embeddings as a dict (e.g., {"float": [[...]], "int8": [[...]]}) with response_type: "embeddings_by_type", instead of a flat List[List[float]]. Both _bedrock_embed and _sagemaker_embed unconditionally pass response['embeddings'] to the Embeddings constructor, which expects a list. This causes len() and iteration on the returned Embeddings object to silently produce wrong results.

Additional Locations (1)

Fix in Cursor Fix in Web

}
for key, value in list(json_params.items()):
if value is None:
Expand Down Expand Up @@ -805,6 +821,7 @@ def export_finetune(
This should work when one uses the client inside SageMaker. If this errors out,
the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker.
"""
self._require_sagemaker()
if name == "model":
raise ValueError("name cannot be 'model'")

Expand Down Expand Up @@ -948,6 +965,7 @@ def summarize(
additional_command: Optional[str] = "",
variant: Optional[str] = None
) -> Summary:
self._require_sagemaker()

if self._endpoint_name is None:
raise CohereError("No endpoint connected. "
Expand Down Expand Up @@ -989,6 +1007,7 @@ def summarize(


def delete_endpoint(self) -> None:
self._require_sagemaker()
if self._endpoint_name is None:
raise CohereError("No endpoint connected.")
try:
Expand Down
208 changes: 208 additions & 0 deletions tests/test_aws_client_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
"""
Unit tests (mocked, no AWS credentials needed) for AWS client fixes.

Covers:
- Fix 1: SigV4 signing uses the correct host header after URL rewrite
- Fix 2: cohere_aws.Client conditionally initializes based on mode
- Fix 3: embed() accepts and passes output_dimension and embedding_types
"""

import inspect
import json
import os
import unittest
from unittest.mock import MagicMock, patch

import httpx

from cohere.manually_maintained.cohere_aws.mode import Mode


class TestSigV4HostHeader(unittest.TestCase):
    """Fix 1: the headers handed to AWSRequest for SigV4 signing must carry the
    rewritten Bedrock/SageMaker host rather than the original api.cohere.com."""

    def test_sigv4_signs_with_correct_host(self) -> None:
        bedrock_host = "bedrock-runtime.us-east-1.amazonaws.com"
        recorded_kwargs: dict = {}

        def record_aws_request(**kwargs):  # type: ignore
            # Remember what the hook passed to AWSRequest so it can be inspected below.
            recorded_kwargs.update(kwargs)
            fake_request = MagicMock()
            fake_request.prepare.return_value = MagicMock(headers={"host": bedrock_host})
            return fake_request

        fake_botocore = MagicMock()
        fake_botocore.awsrequest.AWSRequest = MagicMock(side_effect=record_aws_request)
        fake_botocore.auth.SigV4Auth.return_value = MagicMock()

        fake_session = MagicMock()
        fake_session.region_name = "us-east-1"
        fake_session.get_credentials.return_value = MagicMock()
        fake_boto3 = MagicMock()
        fake_boto3.Session.return_value = fake_session

        botocore_patch = patch("cohere.aws_client.lazy_botocore", return_value=fake_botocore)
        boto3_patch = patch("cohere.aws_client.lazy_boto3", return_value=fake_boto3)
        with botocore_patch, boto3_patch:
            from cohere.aws_client import map_request_to_bedrock

            hook = map_request_to_bedrock(service="bedrock", aws_region="us-east-1")

            request = httpx.Request(
                method="POST",
                url="https://api.cohere.com/v1/chat",
                headers={"connection": "keep-alive"},
                json={"model": "cohere.command-r-plus-v1:0", "message": "hello"},
            )

            # Sanity check: the request starts out pointed at the Cohere API host.
            self.assertEqual(request.url.host, "api.cohere.com")

            hook(request)

            # The URL was rewritten to the Bedrock runtime host ...
            self.assertIn(bedrock_host, str(request.url))

            # ... and the headers given to AWSRequest carry that same host.
            signed_headers = recorded_kwargs["headers"]
            self.assertEqual(signed_headers["host"], bedrock_host)


class TestModeConditionalInit(unittest.TestCase):
    """Fix 2: cohere_aws.Client builds a different pair of boto3 clients per
    mode and keeps SAGEMAKER as the default for backwards compatibility."""

    @staticmethod
    def _build_client(region, **client_kwargs):  # type: ignore
        """Construct a Client with boto3/sagemaker mocked out.

        Returns the client, the boto3 service names it requested, and the
        sagemaker mock.
        """
        fake_boto3 = MagicMock()
        fake_sagemaker = MagicMock()
        module = "cohere.manually_maintained.cohere_aws.client"

        with patch(f"{module}.lazy_boto3", return_value=fake_boto3), \
             patch(f"{module}.lazy_sagemaker", return_value=fake_sagemaker), \
             patch.dict(os.environ, {"AWS_DEFAULT_REGION": region}):

            from cohere.manually_maintained.cohere_aws.client import Client

            client = Client(aws_region=region, **client_kwargs)

        # Each recorded call is an (args, kwargs) pair; the service name is the first positional.
        requested = [args[0] for args, _kwargs in fake_boto3.client.call_args_list]
        return client, requested, fake_sagemaker

    def test_sagemaker_mode_creates_sagemaker_clients(self) -> None:
        client, requested, fake_sagemaker = self._build_client("us-east-1")

        self.assertEqual(client.mode, Mode.SAGEMAKER)

        for expected in ("sagemaker-runtime", "sagemaker"):
            self.assertIn(expected, requested)
        for unexpected in ("bedrock-runtime", "bedrock"):
            self.assertNotIn(unexpected, requested)

        fake_sagemaker.Session.assert_called_once()

    def test_bedrock_mode_creates_bedrock_clients(self) -> None:
        client, requested, fake_sagemaker = self._build_client("us-west-2", mode=Mode.BEDROCK)

        self.assertEqual(client.mode, Mode.BEDROCK)

        for expected in ("bedrock-runtime", "bedrock"):
            self.assertIn(expected, requested)
        for unexpected in ("sagemaker-runtime", "sagemaker"):
            self.assertNotIn(unexpected, requested)

        fake_sagemaker.Session.assert_not_called()

    def test_default_mode_is_sagemaker(self) -> None:
        from cohere.manually_maintained.cohere_aws.client import Client

        mode_default = inspect.signature(Client.__init__).parameters["mode"].default
        self.assertEqual(mode_default, Mode.SAGEMAKER)


class TestEmbedV4Params(unittest.TestCase):
    """Fix 3: embed() takes output_dimension and embedding_types, forwards them
    in the request body, and leaves them out when they are None."""

    @staticmethod
    def _make_bedrock_client():  # type: ignore
        """Return (boto3 mock, botocore mock, dict that receives the invoke_model body)."""
        sent_body: dict = {}

        def fake_invoke_model(**kwargs):  # type: ignore
            # Record the JSON payload and answer with a fixed flat embeddings response.
            sent_body.update(json.loads(kwargs["body"]))
            response_stream = MagicMock()
            response_stream.read.return_value = json.dumps({"embeddings": [[0.1, 0.2]]}).encode()
            return {"body": response_stream}

        runtime_client = MagicMock()
        runtime_client.invoke_model.side_effect = fake_invoke_model

        def fake_boto3_client(service_name, **kwargs):  # type: ignore
            # Only the runtime client needs the recording behaviour.
            return runtime_client if service_name == "bedrock-runtime" else MagicMock()

        fake_boto3 = MagicMock()
        fake_boto3.client.side_effect = fake_boto3_client
        fake_botocore = MagicMock()
        return fake_boto3, fake_botocore, sent_body

    def _embed_and_capture(self, **extra_embed_kwargs):  # type: ignore
        """Call embed() on a mocked Bedrock-mode client and return the body it sent."""
        fake_boto3, fake_botocore, sent_body = self._make_bedrock_client()
        module = "cohere.manually_maintained.cohere_aws.client"

        with patch(f"{module}.lazy_boto3", return_value=fake_boto3), \
             patch(f"{module}.lazy_botocore", return_value=fake_botocore), \
             patch(f"{module}.lazy_sagemaker", return_value=MagicMock()), \
             patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):

            from cohere.manually_maintained.cohere_aws.client import Client

            client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
            client.embed(
                texts=["hello world"],
                input_type="search_document",
                model_id="cohere.embed-english-v3",
                **extra_embed_kwargs,
            )

        return sent_body

    def test_embed_accepts_new_params(self) -> None:
        from cohere.manually_maintained.cohere_aws.client import Client

        params = inspect.signature(Client.embed).parameters
        self.assertIn("output_dimension", params)
        self.assertIn("embedding_types", params)
        self.assertIsNone(params["output_dimension"].default)
        self.assertIsNone(params["embedding_types"].default)

    def test_embed_passes_params_to_bedrock(self) -> None:
        sent_body = self._embed_and_capture(
            output_dimension=256,
            embedding_types=["float", "int8"],
        )

        self.assertEqual(sent_body["output_dimension"], 256)
        self.assertEqual(sent_body["embedding_types"], ["float", "int8"])

    def test_embed_omits_none_params(self) -> None:
        sent_body = self._embed_and_capture()

        self.assertNotIn("output_dimension", sent_body)
        self.assertNotIn("embedding_types", sent_body)
Loading