diff --git a/src/cohere/aws_client.py b/src/cohere/aws_client.py index 8aea9d15c..12a168276 100644 --- a/src/cohere/aws_client.py +++ b/src/cohere/aws_client.py @@ -239,6 +239,7 @@ def _event_hook(request: httpx.Request) -> None: ) request.url = URL(url) request.headers["host"] = request.url.host + headers["host"] = request.url.host if endpoint == "rerank": body["api_version"] = get_api_version(version=api_version) diff --git a/src/cohere/manually_maintained/cohere_aws/client.py b/src/cohere/manually_maintained/cohere_aws/client.py index 065a5fd04..ee9b94ce8 100644 --- a/src/cohere/manually_maintained/cohere_aws/client.py +++ b/src/cohere/manually_maintained/cohere_aws/client.py @@ -20,19 +20,29 @@ class Client: def __init__( self, aws_region: typing.Optional[str] = None, + mode: Mode = Mode.SAGEMAKER, ): """ By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with `aws configure set region us-west-2` or override it with `region_name` parameter. """ - self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region) - self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region) + self.mode = mode if os.environ.get('AWS_DEFAULT_REGION') is None: os.environ['AWS_DEFAULT_REGION'] = aws_region - self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client) - self.mode = Mode.SAGEMAKER + if self.mode == Mode.SAGEMAKER: + self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region) + self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region) + self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client) + elif self.mode == Mode.BEDROCK: + self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region) + self._service_client = lazy_boto3().client("bedrock", region_name=aws_region) + self._sess = None + self._endpoint_name = None + def _require_sagemaker(self) -> None: + if self.mode != Mode.SAGEMAKER: + raise CohereError("This method is only supported in SageMaker mode.") def _does_endpoint_exist(self, endpoint_name: str) -> bool: try: @@ -50,6 +60,7 @@ def connect_to_endpoint(self, endpoint_name: str) -> None: Raises: CohereError: Connection to the endpoint failed. """ + self._require_sagemaker() if not self._does_endpoint_exist(endpoint_name): raise CohereError(f"Endpoint {endpoint_name} does not exist.") self._endpoint_name = endpoint_name @@ -137,6 +148,7 @@ def create_endpoint( will be used to get the role. This should work when one uses the client inside SageMaker. If this errors out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker. """ + self._require_sagemaker() # First, check if endpoint already exists if self._does_endpoint_exist(endpoint_name): if recreate: @@ -550,11 +562,15 @@ def embed( variant: Optional[str] = None, input_type: Optional[str] = None, model_id: Optional[str] = None, + output_dimension: Optional[int] = None, + embedding_types: Optional[List[str]] = None, ) -> Embeddings: json_params = { 'texts': texts, 'truncate': truncate, - "input_type": input_type + "input_type": input_type, + "output_dimension": output_dimension, + "embedding_types": embedding_types, } for key, value in list(json_params.items()): if value is None: @@ -805,6 +821,7 @@ def export_finetune( This should work when one uses the client inside SageMaker. If this errors out, the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker. """ + self._require_sagemaker() if name == "model": raise ValueError("name cannot be 'model'") @@ -948,6 +965,7 @@ def summarize( additional_command: Optional[str] = "", variant: Optional[str] = None ) -> Summary: + self._require_sagemaker() if self._endpoint_name is None: raise CohereError("No endpoint connected. " @@ -989,6 +1007,7 @@ def summarize( def delete_endpoint(self) -> None: + self._require_sagemaker() if self._endpoint_name is None: raise CohereError("No endpoint connected.") try: diff --git a/tests/test_aws_client_unit.py b/tests/test_aws_client_unit.py new file mode 100644 index 000000000..34d17c99f --- /dev/null +++ b/tests/test_aws_client_unit.py @@ -0,0 +1,208 @@ +""" +Unit tests (mocked, no AWS credentials needed) for AWS client fixes. + +Covers: +- Fix 1: SigV4 signing uses the correct host header after URL rewrite +- Fix 2: cohere_aws.Client conditionally initializes based on mode +- Fix 3: embed() accepts and passes output_dimension and embedding_types +""" + +import inspect +import json +import os +import unittest +from unittest.mock import MagicMock, patch + +import httpx + +from cohere.manually_maintained.cohere_aws.mode import Mode + + +class TestSigV4HostHeader(unittest.TestCase): + """Fix 1: The headers dict passed to AWSRequest for SigV4 signing must + contain the rewritten Bedrock/SageMaker host, not the stale api.cohere.com.""" + + def test_sigv4_signs_with_correct_host(self) -> None: + captured_aws_request_kwargs: dict = {} + + mock_aws_request_cls = MagicMock() + + def capture_aws_request(**kwargs): # type: ignore + captured_aws_request_kwargs.update(kwargs) + mock_req = MagicMock() + mock_req.prepare.return_value = MagicMock( + headers={"host": "bedrock-runtime.us-east-1.amazonaws.com"} + ) + return mock_req + + mock_aws_request_cls.side_effect = capture_aws_request + + mock_botocore = MagicMock() + mock_botocore.awsrequest.AWSRequest = mock_aws_request_cls + mock_botocore.auth.SigV4Auth.return_value = MagicMock() + + mock_boto3 = MagicMock() + mock_session = MagicMock() + mock_session.region_name = "us-east-1" + mock_session.get_credentials.return_value = MagicMock() + mock_boto3.Session.return_value = mock_session + + with patch("cohere.aws_client.lazy_botocore", return_value=mock_botocore), \ + patch("cohere.aws_client.lazy_boto3", return_value=mock_boto3): + + from cohere.aws_client import map_request_to_bedrock + + hook = map_request_to_bedrock(service="bedrock", aws_region="us-east-1") + + request = httpx.Request( + method="POST", + url="https://api.cohere.com/v1/chat", + headers={"connection": "keep-alive"}, + json={"model": "cohere.command-r-plus-v1:0", "message": "hello"}, + ) + + self.assertEqual(request.url.host, "api.cohere.com") + + hook(request) + + self.assertIn("bedrock-runtime.us-east-1.amazonaws.com", str(request.url)) + + signed_headers = captured_aws_request_kwargs["headers"] + self.assertEqual( + signed_headers["host"], + "bedrock-runtime.us-east-1.amazonaws.com", + ) + + +class TestModeConditionalInit(unittest.TestCase): + """Fix 2: cohere_aws.Client should initialize different boto3 clients + depending on mode, and default to SAGEMAKER for backwards compat.""" + + def test_sagemaker_mode_creates_sagemaker_clients(self) -> None: + mock_boto3 = MagicMock() + mock_sagemaker = MagicMock() + + with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \ + patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=mock_sagemaker), \ + patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}): + + from cohere.manually_maintained.cohere_aws.client import Client + + client = Client(aws_region="us-east-1") + + self.assertEqual(client.mode, Mode.SAGEMAKER) + + service_names = [c[0][0] for c in mock_boto3.client.call_args_list] + self.assertIn("sagemaker-runtime", service_names) + self.assertIn("sagemaker", service_names) + self.assertNotIn("bedrock-runtime", service_names) + self.assertNotIn("bedrock", service_names) + + mock_sagemaker.Session.assert_called_once() + + def test_bedrock_mode_creates_bedrock_clients(self) -> None: + mock_boto3 = MagicMock() + mock_sagemaker = MagicMock() + + with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \ + patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=mock_sagemaker), \ + patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-west-2"}): + + from cohere.manually_maintained.cohere_aws.client import Client + + client = Client(aws_region="us-west-2", mode=Mode.BEDROCK) + + self.assertEqual(client.mode, Mode.BEDROCK) + + service_names = [c[0][0] for c in mock_boto3.client.call_args_list] + self.assertIn("bedrock-runtime", service_names) + self.assertIn("bedrock", service_names) + self.assertNotIn("sagemaker-runtime", service_names) + self.assertNotIn("sagemaker", service_names) + + mock_sagemaker.Session.assert_not_called() + + def test_default_mode_is_sagemaker(self) -> None: + from cohere.manually_maintained.cohere_aws.client import Client + + sig = inspect.signature(Client.__init__) + self.assertEqual(sig.parameters["mode"].default, Mode.SAGEMAKER) + + +class TestEmbedV4Params(unittest.TestCase): + """Fix 3: embed() should accept output_dimension and embedding_types, + pass them through to the request body, and strip them when None.""" + + @staticmethod + def _make_bedrock_client(): # type: ignore + mock_boto3 = MagicMock() + mock_botocore = MagicMock() + captured_body: dict = {} + + def fake_invoke_model(**kwargs): # type: ignore + captured_body.update(json.loads(kwargs["body"])) + mock_body = MagicMock() + mock_body.read.return_value = json.dumps({"embeddings": [[0.1, 0.2]]}).encode() + return {"body": mock_body} + + mock_bedrock_client = MagicMock() + mock_bedrock_client.invoke_model.side_effect = fake_invoke_model + + def fake_boto3_client(service_name, **kwargs): # type: ignore + if service_name == "bedrock-runtime": + return mock_bedrock_client + return MagicMock() + + mock_boto3.client.side_effect = fake_boto3_client + return mock_boto3, mock_botocore, captured_body + + def test_embed_accepts_new_params(self) -> None: + from cohere.manually_maintained.cohere_aws.client import Client + + sig = inspect.signature(Client.embed) + self.assertIn("output_dimension", sig.parameters) + self.assertIn("embedding_types", sig.parameters) + self.assertIsNone(sig.parameters["output_dimension"].default) + self.assertIsNone(sig.parameters["embedding_types"].default) + + def test_embed_passes_params_to_bedrock(self) -> None: + mock_boto3, mock_botocore, captured_body = self._make_bedrock_client() + + with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \ + patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \ + patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \ + patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}): + + from cohere.manually_maintained.cohere_aws.client import Client + + client = Client(aws_region="us-east-1", mode=Mode.BEDROCK) + client.embed( + texts=["hello world"], + input_type="search_document", + model_id="cohere.embed-english-v3", + output_dimension=256, + embedding_types=["float", "int8"], + ) + + self.assertEqual(captured_body["output_dimension"], 256) + self.assertEqual(captured_body["embedding_types"], ["float", "int8"]) + + def test_embed_omits_none_params(self) -> None: + mock_boto3, mock_botocore, captured_body = self._make_bedrock_client() + + with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \ + patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \ + patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \ + patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}): + + from cohere.manually_maintained.cohere_aws.client import Client + + client = Client(aws_region="us-east-1", mode=Mode.BEDROCK) + client.embed( + texts=["hello world"], + input_type="search_document", + model_id="cohere.embed-english-v3", + ) + + self.assertNotIn("output_dimension", captured_body) + self.assertNotIn("embedding_types", captured_body) diff --git a/tests/test_bedrock_client.py b/tests/test_bedrock_client.py index d588ca38c..819399c8c 100644 --- a/tests/test_bedrock_client.py +++ b/tests/test_bedrock_client.py @@ -10,6 +10,17 @@ aws_region = os.getenv("AWS_REGION") endpoint_type = os.getenv("ENDPOINT_TYPE") + +def _setup_boto3_env(): + """Bridge custom test env vars to standard boto3 credential env vars.""" + if aws_access_key: + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key + if aws_secret_key: + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_key + if aws_session_token: + os.environ["AWS_SESSION_TOKEN"] = aws_session_token + + @unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set") class TestClient(unittest.TestCase): platform: str = "bedrock" @@ -109,3 +120,101 @@ def test_chat_stream(self) -> None: self.assertIsNotNone(event.response.text) self.assertSetEqual(response_types, {"text-generation", "stream-end"}) + + +@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set") +class TestBedrockClientV2(unittest.TestCase): + """Integration tests for BedrockClientV2 (httpx-based). + + Fix 1 validation: If these pass, SigV4 signing uses the correct host header, + since the request would fail with a signature mismatch otherwise. + """ + + client: cohere.ClientV2 = cohere.BedrockClientV2( + aws_access_key=aws_access_key, + aws_secret_key=aws_secret_key, + aws_session_token=aws_session_token, + aws_region=aws_region, + ) + + def test_embed(self) -> None: + response = self.client.embed( + model="cohere.embed-multilingual-v3", + texts=["I love Cohere!"], + input_type="search_document", + embedding_types=["float"], + ) + self.assertIsNotNone(response) + + def test_embed_with_output_dimension(self) -> None: + response = self.client.embed( + model="cohere.embed-english-v3", + texts=["I love Cohere!"], + input_type="search_document", + embedding_types=["float"], + output_dimension=256, + ) + self.assertIsNotNone(response) + + +@unittest.skipIf(None == os.getenv("TEST_AWS"), "tests skipped because TEST_AWS is not set") +class TestCohereAwsBedrockClient(unittest.TestCase): + """Integration tests for cohere_aws.Client in Bedrock mode (boto3-based). + + Validates: + - Fix 2: Client can be initialized with mode=BEDROCK without importing sagemaker + - Fix 3: embed() accepts output_dimension and embedding_types + """ + client: typing.Any = None + + @classmethod + def setUpClass(cls) -> None: + _setup_boto3_env() + from cohere.manually_maintained.cohere_aws.client import Client + from cohere.manually_maintained.cohere_aws.mode import Mode + cls.client = Client(aws_region=aws_region, mode=Mode.BEDROCK) + + def test_client_is_bedrock_mode(self) -> None: + from cohere.manually_maintained.cohere_aws.mode import Mode + self.assertEqual(self.client.mode, Mode.BEDROCK) + + def test_embed(self) -> None: + response = self.client.embed( + texts=["I love Cohere!"], + input_type="search_document", + model_id="cohere.embed-multilingual-v3", + ) + self.assertIsNotNone(response) + self.assertIsNotNone(response.embeddings) + self.assertGreater(len(response.embeddings), 0) + + def test_embed_with_embedding_types(self) -> None: + response = self.client.embed( + texts=["I love Cohere!"], + input_type="search_document", + model_id="cohere.embed-multilingual-v3", + embedding_types=["float"], + ) + self.assertIsNotNone(response) + self.assertIsNotNone(response.embeddings) + + def test_embed_with_output_dimension(self) -> None: + response = self.client.embed( + texts=["I love Cohere!"], + input_type="search_document", + model_id="cohere.embed-english-v3", + output_dimension=256, + embedding_types=["float"], + ) + self.assertIsNotNone(response) + self.assertIsNotNone(response.embeddings) + + def test_embed_without_new_params(self) -> None: + """Backwards compat: embed() still works without the new v4 params.""" + response = self.client.embed( + texts=["I love Cohere!"], + input_type="search_document", + model_id="cohere.embed-multilingual-v3", + ) + self.assertIsNotNone(response) + self.assertIsNotNone(response.embeddings)