-
Notifications
You must be signed in to change notification settings - Fork 82
fix: resolve AWS client SigV4 signing, forced SageMaker dep, and missing embed params #728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e542ef0
065e010
b20b617
a04c082
bae43fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,19 +20,29 @@ class Client: | |
| def __init__( | ||
| self, | ||
| aws_region: typing.Optional[str] = None, | ||
| mode: Mode = Mode.SAGEMAKER, | ||
| ): | ||
| """ | ||
| By default we assume the region configured in the AWS CLI (`aws configure get region`). You can change the region with | ||
| `aws configure set region us-west-2` or override it with the `aws_region` parameter. | ||
| """ | ||
| self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region) | ||
| self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region) | ||
| self.mode = mode | ||
| if os.environ.get('AWS_DEFAULT_REGION') is None: | ||
| os.environ['AWS_DEFAULT_REGION'] = aws_region | ||
| self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client) | ||
| self.mode = Mode.SAGEMAKER | ||
|
|
||
| if self.mode == Mode.SAGEMAKER: | ||
| self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region) | ||
| self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region) | ||
| self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client) | ||
| elif self.mode == Mode.BEDROCK: | ||
| self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region) | ||
| self._service_client = lazy_boto3().client("bedrock", region_name=aws_region) | ||
| self._sess = None | ||
| self._endpoint_name = None | ||
|
|
||
| def _require_sagemaker(self) -> None: | ||
| if self.mode != Mode.SAGEMAKER: | ||
| raise CohereError("This method is only supported in SageMaker mode.") | ||
|
|
||
| def _does_endpoint_exist(self, endpoint_name: str) -> bool: | ||
| try: | ||
|
|
@@ -50,6 +60,7 @@ def connect_to_endpoint(self, endpoint_name: str) -> None: | |
| Raises: | ||
| CohereError: Connection to the endpoint failed. | ||
| """ | ||
| self._require_sagemaker() | ||
| if not self._does_endpoint_exist(endpoint_name): | ||
| raise CohereError(f"Endpoint {endpoint_name} does not exist.") | ||
| self._endpoint_name = endpoint_name | ||
|
|
@@ -137,6 +148,7 @@ def create_endpoint( | |
| will be used to get the role. This should work when one uses the client inside SageMaker. If this errors | ||
| out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker. | ||
| """ | ||
| self._require_sagemaker() | ||
| # First, check if endpoint already exists | ||
| if self._does_endpoint_exist(endpoint_name): | ||
| if recreate: | ||
|
|
@@ -550,11 +562,15 @@ def embed( | |
| variant: Optional[str] = None, | ||
| input_type: Optional[str] = None, | ||
| model_id: Optional[str] = None, | ||
| output_dimension: Optional[int] = None, | ||
| embedding_types: Optional[List[str]] = None, | ||
| ) -> Embeddings: | ||
| json_params = { | ||
| 'texts': texts, | ||
| 'truncate': truncate, | ||
| "input_type": input_type | ||
| "input_type": input_type, | ||
| "output_dimension": output_dimension, | ||
| "embedding_types": embedding_types, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Embed response parsing breaks with embedding_types parameter. Medium Severity. Adding Additional Locations (1) |
||
| } | ||
| for key, value in list(json_params.items()): | ||
| if value is None: | ||
|
|
@@ -805,6 +821,7 @@ def export_finetune( | |
| This should work when one uses the client inside SageMaker. If this errors out, | ||
| the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker. | ||
| """ | ||
| self._require_sagemaker() | ||
| if name == "model": | ||
| raise ValueError("name cannot be 'model'") | ||
|
|
||
|
|
@@ -948,6 +965,7 @@ def summarize( | |
| additional_command: Optional[str] = "", | ||
| variant: Optional[str] = None | ||
| ) -> Summary: | ||
| self._require_sagemaker() | ||
|
|
||
| if self._endpoint_name is None: | ||
| raise CohereError("No endpoint connected. " | ||
|
|
@@ -989,6 +1007,7 @@ def summarize( | |
|
|
||
|
|
||
| def delete_endpoint(self) -> None: | ||
| self._require_sagemaker() | ||
| if self._endpoint_name is None: | ||
| raise CohereError("No endpoint connected.") | ||
| try: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,208 @@ | ||
| """ | ||
| Unit tests (mocked, no AWS credentials needed) for AWS client fixes. | ||
|
|
||
| Covers: | ||
| - Fix 1: SigV4 signing uses the correct host header after URL rewrite | ||
| - Fix 2: cohere_aws.Client conditionally initializes based on mode | ||
| - Fix 3: embed() accepts and passes output_dimension and embedding_types | ||
| """ | ||
|
|
||
| import inspect | ||
| import json | ||
| import os | ||
| import unittest | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import httpx | ||
|
|
||
| from cohere.manually_maintained.cohere_aws.mode import Mode | ||
|
|
||
|
|
||
| class TestSigV4HostHeader(unittest.TestCase): | ||
| """Fix 1: The headers dict passed to AWSRequest for SigV4 signing must | ||
| contain the rewritten Bedrock/SageMaker host, not the stale api.cohere.com.""" | ||
|
|
||
| def test_sigv4_signs_with_correct_host(self) -> None: | ||
| captured_aws_request_kwargs: dict = {} | ||
|
|
||
| mock_aws_request_cls = MagicMock() | ||
|
|
||
| def capture_aws_request(**kwargs): # type: ignore | ||
| captured_aws_request_kwargs.update(kwargs) | ||
| mock_req = MagicMock() | ||
| mock_req.prepare.return_value = MagicMock( | ||
| headers={"host": "bedrock-runtime.us-east-1.amazonaws.com"} | ||
| ) | ||
| return mock_req | ||
|
|
||
| mock_aws_request_cls.side_effect = capture_aws_request | ||
|
|
||
| mock_botocore = MagicMock() | ||
| mock_botocore.awsrequest.AWSRequest = mock_aws_request_cls | ||
| mock_botocore.auth.SigV4Auth.return_value = MagicMock() | ||
|
|
||
| mock_boto3 = MagicMock() | ||
| mock_session = MagicMock() | ||
| mock_session.region_name = "us-east-1" | ||
| mock_session.get_credentials.return_value = MagicMock() | ||
| mock_boto3.Session.return_value = mock_session | ||
|
|
||
| with patch("cohere.aws_client.lazy_botocore", return_value=mock_botocore), \ | ||
| patch("cohere.aws_client.lazy_boto3", return_value=mock_boto3): | ||
|
|
||
| from cohere.aws_client import map_request_to_bedrock | ||
|
|
||
| hook = map_request_to_bedrock(service="bedrock", aws_region="us-east-1") | ||
|
|
||
| request = httpx.Request( | ||
| method="POST", | ||
| url="https://api.cohere.com/v1/chat", | ||
| headers={"connection": "keep-alive"}, | ||
| json={"model": "cohere.command-r-plus-v1:0", "message": "hello"}, | ||
| ) | ||
|
|
||
| self.assertEqual(request.url.host, "api.cohere.com") | ||
|
|
||
| hook(request) | ||
|
|
||
| self.assertIn("bedrock-runtime.us-east-1.amazonaws.com", str(request.url)) | ||
|
|
||
| signed_headers = captured_aws_request_kwargs["headers"] | ||
| self.assertEqual( | ||
| signed_headers["host"], | ||
| "bedrock-runtime.us-east-1.amazonaws.com", | ||
| ) | ||
|
|
||
|
|
||
| class TestModeConditionalInit(unittest.TestCase): | ||
| """Fix 2: cohere_aws.Client should initialize different boto3 clients | ||
| depending on mode, and default to SAGEMAKER for backwards compat.""" | ||
|
|
||
| def test_sagemaker_mode_creates_sagemaker_clients(self) -> None: | ||
| mock_boto3 = MagicMock() | ||
| mock_sagemaker = MagicMock() | ||
|
|
||
| with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \ | ||
| patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=mock_sagemaker), \ | ||
| patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}): | ||
|
|
||
| from cohere.manually_maintained.cohere_aws.client import Client | ||
|
|
||
| client = Client(aws_region="us-east-1") | ||
|
|
||
| self.assertEqual(client.mode, Mode.SAGEMAKER) | ||
|
|
||
| service_names = [c[0][0] for c in mock_boto3.client.call_args_list] | ||
| self.assertIn("sagemaker-runtime", service_names) | ||
| self.assertIn("sagemaker", service_names) | ||
| self.assertNotIn("bedrock-runtime", service_names) | ||
| self.assertNotIn("bedrock", service_names) | ||
|
|
||
| mock_sagemaker.Session.assert_called_once() | ||
|
|
||
| def test_bedrock_mode_creates_bedrock_clients(self) -> None: | ||
| mock_boto3 = MagicMock() | ||
| mock_sagemaker = MagicMock() | ||
|
|
||
| with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \ | ||
| patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=mock_sagemaker), \ | ||
| patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-west-2"}): | ||
|
|
||
| from cohere.manually_maintained.cohere_aws.client import Client | ||
|
|
||
| client = Client(aws_region="us-west-2", mode=Mode.BEDROCK) | ||
|
|
||
| self.assertEqual(client.mode, Mode.BEDROCK) | ||
|
|
||
| service_names = [c[0][0] for c in mock_boto3.client.call_args_list] | ||
| self.assertIn("bedrock-runtime", service_names) | ||
| self.assertIn("bedrock", service_names) | ||
| self.assertNotIn("sagemaker-runtime", service_names) | ||
| self.assertNotIn("sagemaker", service_names) | ||
|
|
||
| mock_sagemaker.Session.assert_not_called() | ||
|
|
||
| def test_default_mode_is_sagemaker(self) -> None: | ||
| from cohere.manually_maintained.cohere_aws.client import Client | ||
|
|
||
| sig = inspect.signature(Client.__init__) | ||
| self.assertEqual(sig.parameters["mode"].default, Mode.SAGEMAKER) | ||
|
|
||
|
|
||
| class TestEmbedV4Params(unittest.TestCase): | ||
| """Fix 3: embed() should accept output_dimension and embedding_types, | ||
| pass them through to the request body, and strip them when None.""" | ||
|
|
||
| @staticmethod | ||
| def _make_bedrock_client(): # type: ignore | ||
| mock_boto3 = MagicMock() | ||
| mock_botocore = MagicMock() | ||
| captured_body: dict = {} | ||
|
|
||
| def fake_invoke_model(**kwargs): # type: ignore | ||
| captured_body.update(json.loads(kwargs["body"])) | ||
| mock_body = MagicMock() | ||
| mock_body.read.return_value = json.dumps({"embeddings": [[0.1, 0.2]]}).encode() | ||
| return {"body": mock_body} | ||
|
|
||
| mock_bedrock_client = MagicMock() | ||
| mock_bedrock_client.invoke_model.side_effect = fake_invoke_model | ||
|
|
||
| def fake_boto3_client(service_name, **kwargs): # type: ignore | ||
| if service_name == "bedrock-runtime": | ||
| return mock_bedrock_client | ||
| return MagicMock() | ||
|
|
||
| mock_boto3.client.side_effect = fake_boto3_client | ||
| return mock_boto3, mock_botocore, captured_body | ||
|
|
||
| def test_embed_accepts_new_params(self) -> None: | ||
| from cohere.manually_maintained.cohere_aws.client import Client | ||
|
|
||
| sig = inspect.signature(Client.embed) | ||
| self.assertIn("output_dimension", sig.parameters) | ||
| self.assertIn("embedding_types", sig.parameters) | ||
| self.assertIsNone(sig.parameters["output_dimension"].default) | ||
| self.assertIsNone(sig.parameters["embedding_types"].default) | ||
|
|
||
| def test_embed_passes_params_to_bedrock(self) -> None: | ||
| mock_boto3, mock_botocore, captured_body = self._make_bedrock_client() | ||
|
|
||
| with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \ | ||
| patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \ | ||
| patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \ | ||
| patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}): | ||
|
|
||
| from cohere.manually_maintained.cohere_aws.client import Client | ||
|
|
||
| client = Client(aws_region="us-east-1", mode=Mode.BEDROCK) | ||
| client.embed( | ||
| texts=["hello world"], | ||
| input_type="search_document", | ||
| model_id="cohere.embed-english-v3", | ||
| output_dimension=256, | ||
| embedding_types=["float", "int8"], | ||
| ) | ||
|
|
||
| self.assertEqual(captured_body["output_dimension"], 256) | ||
| self.assertEqual(captured_body["embedding_types"], ["float", "int8"]) | ||
|
|
||
| def test_embed_omits_none_params(self) -> None: | ||
| mock_boto3, mock_botocore, captured_body = self._make_bedrock_client() | ||
|
|
||
| with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \ | ||
| patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \ | ||
| patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \ | ||
| patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}): | ||
|
|
||
| from cohere.manually_maintained.cohere_aws.client import Client | ||
|
|
||
| client = Client(aws_region="us-east-1", mode=Mode.BEDROCK) | ||
| client.embed( | ||
| texts=["hello world"], | ||
| input_type="search_document", | ||
| model_id="cohere.embed-english-v3", | ||
| ) | ||
|
|
||
| self.assertNotIn("output_dimension", captured_body) | ||
| self.assertNotIn("embedding_types", captured_body) |


Uh oh!
There was an error while loading. Please reload this page.