diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index 3b6562b51aef..5692bc0e6523 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -1327,95 +1327,6 @@ Here's how to use Vertex AI with the LiteLLM Proxy Server -## Authentication - vertex_project, vertex_location, etc. - -Set your vertex credentials via: -- dynamic params -OR -- env vars - - -### **Dynamic Params** - -You can set: -- `vertex_credentials` (str) - can be a json string or filepath to your vertex ai service account.json -- `vertex_location` (str) - place where vertex model is deployed (us-central1, asia-southeast1, etc.). Some models support the global location, please see [Vertex AI documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#supported_models) -- `vertex_project` Optional[str] - use if vertex project different from the one in vertex_credentials - -as dynamic params for a `litellm.completion` call. 
- - - - -```python -from litellm import completion -import json - -## GET CREDENTIALS -file_path = 'path/to/vertex_ai_service_account.json' - -# Load the JSON file -with open(file_path, 'r') as file: - vertex_credentials = json.load(file) - -# Convert to JSON string -vertex_credentials_json = json.dumps(vertex_credentials) - - -response = completion( - model="vertex_ai/gemini-2.5-pro", - messages=[{"content": "You are a good bot.","role": "system"}, {"content": "Hello, how are you?","role": "user"}], - vertex_credentials=vertex_credentials_json, - vertex_project="my-special-project", - vertex_location="my-special-location" -) -``` - - - - -```yaml -model_list: - - model_name: gemini-1.5-pro - litellm_params: - model: gemini-1.5-pro - vertex_credentials: os.environ/VERTEX_FILE_PATH_ENV_VAR # os.environ["VERTEX_FILE_PATH_ENV_VAR"] = "/path/to/service_account.json" - vertex_project: "my-special-project" - vertex_location: "my-special-location: -``` - - - - - - - -### **Environment Variables** - -You can set: -- `GOOGLE_APPLICATION_CREDENTIALS` - store the filepath for your service_account.json in here (used by vertex sdk directly). -- VERTEXAI_LOCATION - place where vertex model is deployed (us-central1, asia-southeast1, etc.) -- VERTEXAI_PROJECT - Optional[str] - use if vertex project different from the one in vertex_credentials - -1. GOOGLE_APPLICATION_CREDENTIALS - -```bash -export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service_account.json" -``` - -2. VERTEXAI_LOCATION - -```bash -export VERTEXAI_LOCATION="us-central1" # can be any vertex location -``` - -3. VERTEXAI_PROJECT - -```bash -export VERTEXAI_PROJECT="my-test-project" # ONLY use if model project is different from service account project -``` - - ## Specifying Safety Settings In certain use-cases you may need to make calls to the models and pass [safety settings](https://ai.google.dev/docs/safety_setting_gemini) different from the defaults. 
To do so, simply pass the `safety_settings` argument to `completion` or `acompletion`. For example: diff --git a/docs/my-website/docs/providers/vertex_auth.md b/docs/my-website/docs/providers/vertex_auth.md new file mode 100644 index 000000000000..b700764c2372 --- /dev/null +++ b/docs/my-website/docs/providers/vertex_auth.md @@ -0,0 +1,155 @@ +import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Vertex AI Authentication + +Set your Vertex credentials via: +- dynamic params +OR +- env vars + + +## **Dynamic Params** + +You can set: +- `vertex_credentials` (str) - a JSON string or a filepath to your Vertex AI service_account.json +- `vertex_location` (str) - the location where the Vertex model is deployed (us-central1, asia-southeast1, etc.). Some models support the `global` location; see the [Vertex AI documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#supported_models) +- `vertex_project` (Optional[str]) - use if the Vertex project is different from the one in `vertex_credentials` + +as dynamic params for a `litellm.completion` call. 
+ + + + +```python +from litellm import completion +import json + +## GET CREDENTIALS +file_path = 'path/to/vertex_ai_service_account.json' + +# Load the JSON file +with open(file_path, 'r') as file: + vertex_credentials = json.load(file) + +# Convert to JSON string +vertex_credentials_json = json.dumps(vertex_credentials) + + +response = completion( + model="vertex_ai/gemini-2.5-pro", + messages=[{"content": "You are a good bot.","role": "system"}, {"content": "Hello, how are you?","role": "user"}], + vertex_credentials=vertex_credentials_json, + vertex_project="my-special-project", + vertex_location="my-special-location" +) +``` + + + + +```yaml +model_list: + - model_name: gemini-1.5-pro + litellm_params: + model: vertex_ai/gemini-1.5-pro + vertex_credentials: os.environ/VERTEX_FILE_PATH_ENV_VAR # os.environ["VERTEX_FILE_PATH_ENV_VAR"] = "/path/to/service_account.json" + vertex_project: "my-special-project" + vertex_location: "my-special-location" +``` + + + + + + + +## **Environment Variables** + +You can set: +- `GOOGLE_APPLICATION_CREDENTIALS` - the filepath to your service_account.json (used by the Vertex SDK directly). +- `VERTEXAI_LOCATION` - the location where the Vertex model is deployed (us-central1, asia-southeast1, etc.) +- `VERTEXAI_PROJECT` (Optional[str]) - use if the Vertex project is different from the one in `vertex_credentials` + +1. GOOGLE_APPLICATION_CREDENTIALS + +```bash +export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service_account.json" +``` + +2. VERTEXAI_LOCATION + +```bash +export VERTEXAI_LOCATION="us-central1" # can be any vertex location +``` + +3. VERTEXAI_PROJECT + +```bash +export VERTEXAI_PROJECT="my-test-project" # ONLY use if model project is different from service account project +``` + +## AWS to GCP Federation (No Metadata Required) + +Use AWS credentials to access Vertex AI without EC2 metadata endpoints. Ideal when the metadata endpoint `http://169.254.169.254` is blocked. + +**Quick Setup:** + +1. 
Create a credential file with your AWS auth params: + +```json +{ + "type": "external_account", + "audience": "//iam.googleapis.com/projects/PROJECT_NUMBER/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID", + "subject_token_type": "urn:ietf:params:aws:token-type:aws4_request", + "token_url": "https://sts.googleapis.com/v1/token", + "service_account_impersonation_url": "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/SA_EMAIL:generateAccessToken", + "credential_source": { + "environment_id": "aws1" + }, + "aws_role_name": "arn:aws:iam::123456789012:role/MyRole", + "aws_region_name": "us-east-1" +} +``` + +2. Use it in your code: + + + + +```python +import litellm +import json + +with open('aws_gcp_credentials.json', 'r') as f: + credentials = json.load(f) + +response = litellm.completion( + model="vertex_ai/gemini-pro", + messages=[{"role": "user", "content": "Hello!"}], + vertex_credentials=credentials, + vertex_project="my-gcp-project", + vertex_location="us-central1" +) +``` + + + + +```yaml +model_list: + - model_name: gemini-pro + litellm_params: + model: vertex_ai/gemini-pro + vertex_credentials: /path/to/aws_gcp_credentials.json + vertex_project: my-gcp-project + vertex_location: us-central1 +``` + + + + +**Supported AWS auth methods:** `aws_role_name`, `aws_profile_name`, `aws_access_key_id`/`aws_secret_access_key`, `aws_web_identity_token` + +**Prerequisites:** You need a GCP Workload Identity Pool configured with an AWS provider. 
[Setup guide](https://cloud.google.com/iam/docs/workload-identity-federation-with-other-clouds#aws) \ No newline at end of file diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 1c17216ac354..333c7ab7b0b1 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -426,6 +426,7 @@ const sidebars = { label: "Vertex AI", items: [ "providers/vertex", + "providers/vertex_auth", "providers/vertex_partner", "providers/vertex_self_deployed", "providers/vertex_image", diff --git a/litellm/llms/vertex_ai/aws_credentials_supplier.py b/litellm/llms/vertex_ai/aws_credentials_supplier.py new file mode 100644 index 000000000000..01027a0da73c --- /dev/null +++ b/litellm/llms/vertex_ai/aws_credentials_supplier.py @@ -0,0 +1,122 @@ +""" +AWS Credentials Supplier for GCP Workload Identity Federation + +This module provides a custom AWS credentials supplier that uses boto3 credentials +instead of EC2 metadata endpoints, enabling AWS to GCP federation in environments +where metadata service access is blocked. +""" + +from typing import TYPE_CHECKING, Any, Mapping + +if TYPE_CHECKING: + from botocore.credentials import Credentials as BotoCredentials +else: + BotoCredentials = Any + + +class Boto3AwsSecurityCredentialsSupplier: + """ + Custom AWS credentials supplier that uses boto3 credentials instead of EC2 metadata endpoints. + + This allows AWS to GCP Workload Identity Federation without relying on the metadata service + (http://169.254.169.254). It wraps boto3 credentials obtained via BaseAWSLLM and provides + them to Google's aws.Credentials class. 
+ + Example: + ```python + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + from google.auth import aws + + # Get AWS credentials using BaseAWSLLM (supports all auth flows) + aws_llm = BaseAWSLLM() + boto3_creds = aws_llm.get_credentials( + aws_role_name="arn:aws:iam::123456789012:role/MyRole", + aws_session_name="my-session", + aws_region_name="us-east-1" + ) + + # Create custom supplier + supplier = Boto3AwsSecurityCredentialsSupplier( + boto3_credentials=boto3_creds, + aws_region="us-east-1" + ) + + # Use with Google's aws.Credentials (bypasses metadata) + gcp_credentials = aws.Credentials( + audience="//iam.googleapis.com/projects/.../providers/...", + subject_token_type="urn:ietf:params:aws:token-type:aws4_request", + token_url="https://sts.googleapis.com/v1/token", + aws_security_credentials_supplier=supplier, + credential_source=None, # Not using metadata + ) + ``` + """ + + def __init__( + self, boto3_credentials: BotoCredentials, aws_region: str = "us-east-1" + ) -> None: + """ + Initialize the AWS credentials supplier. + + Args: + boto3_credentials: botocore.credentials.Credentials object from boto3/BaseAWSLLM. + This can come from any AWS auth flow (role assumption, profile, + web identity token, explicit credentials, etc.) + aws_region: AWS region name. Defaults to "us-east-1" + """ + self._credentials = boto3_credentials + self._region = aws_region + + def get_aws_security_credentials( + self, context: Any, request: Any + ) -> Mapping[str, str]: + """ + Get AWS security credentials from the boto3 credentials object. + + This method is called by Google's aws.Credentials class to obtain AWS credentials + for the token exchange process. It extracts the credentials from the boto3 + Credentials object, handling both frozen and unfrozen credential formats. 
+ + Args: + context: Supplier context (unused, required by interface) + request: HTTP request object (unused, required by interface) + + Returns: + Dict containing: + - access_key_id: AWS access key ID + - secret_access_key: AWS secret access key + - security_token: AWS session token (or empty string if not present) + """ + # Refresh credentials if needed and get frozen credentials + # Frozen credentials are immutable snapshots of the current credential values + if hasattr(self._credentials, "get_frozen_credentials"): + frozen_creds = self._credentials.get_frozen_credentials() + return { + "access_key_id": frozen_creds.access_key, + "secret_access_key": frozen_creds.secret_key, + "security_token": frozen_creds.token or "", + } + else: + # Fallback for credentials that don't support get_frozen_credentials + return { + "access_key_id": self._credentials.access_key, + "secret_access_key": self._credentials.secret_key, + "security_token": getattr(self._credentials, "token", "") or "", + } + + def get_aws_region(self, context: Any, request: Any) -> str: + """ + Get the AWS region for credential verification. + + This method is called by Google's aws.Credentials class to determine which + AWS region to use for credential verification requests. 
+ + Args: + context: Supplier context (unused, required by interface) + request: HTTP request object (unused, required by interface) + + Returns: + AWS region name (e.g., "us-east-1", "us-west-2") + """ + return self._region + diff --git a/litellm/llms/vertex_ai/vertex_llm_base.py b/litellm/llms/vertex_ai/vertex_llm_base.py index 6d194d41addb..1c77293bcf83 100644 --- a/litellm/llms/vertex_ai/vertex_llm_base.py +++ b/litellm/llms/vertex_ai/vertex_llm_base.py @@ -12,6 +12,9 @@ from litellm._logging import verbose_logger from litellm.litellm_core_utils.asyncify import asyncify from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from litellm.llms.vertex_ai.aws_credentials_supplier import ( + Boto3AwsSecurityCredentialsSupplier, +) from litellm.secret_managers.main import get_secret_str from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES, VertexPartnerProvider @@ -89,7 +92,20 @@ def load_auth( else "" ) if isinstance(environment_id, str) and "aws" in environment_id: - creds = self._credentials_from_identity_pool_with_aws(json_obj) + # Check if explicit AWS auth parameters are provided + aws_params = self._extract_aws_params(json_obj) + if aws_params: + verbose_logger.debug( + "Explicit AWS auth parameters detected, using custom federation flow" + ) + creds = self._credentials_from_aws_with_explicit_auth( + json_obj=json_obj, aws_params=aws_params + ) + else: + verbose_logger.debug( + "No explicit AWS auth parameters, using default metadata-based flow" + ) + creds = self._credentials_from_identity_pool_with_aws(json_obj) else: creds = self._credentials_from_identity_pool(json_obj) # Check if the JSON object contains Authorized User configuration (via gcloud auth application-default login) @@ -140,6 +156,100 @@ def _credentials_from_identity_pool_with_aws(self, json_obj): return aws.Credentials.from_info(json_obj) + def _extract_aws_params(self, json_obj: dict) -> Optional[dict]: + """ + Extract AWS authentication parameters from credential 
config. + + Returns: + Dict of AWS auth params if any are present, None otherwise + """ + aws_keys = [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_role_name", + "aws_session_name", + "aws_profile_name", + "aws_web_identity_token", + "aws_region_name", + "aws_external_id", + "aws_sts_endpoint", + ] + + aws_params = {k: json_obj.get(k) for k in aws_keys if k in json_obj} + return aws_params if aws_params else None + + def _credentials_from_aws_with_explicit_auth( + self, json_obj: dict, aws_params: dict + ) -> GoogleCredentialsObject: + """ + Create GCP credentials using explicit AWS authentication (no metadata endpoints). + Reuses BaseAWSLLM for all AWS auth flows (roles, profiles, web identity tokens, etc.). + + Args: + json_obj: The external_account credential configuration + aws_params: Dict of AWS authentication parameters + + Returns: + Google credentials object configured with custom AWS supplier + """ + from google.auth import aws + + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + verbose_logger.debug( + "Using explicit AWS authentication for GCP federation (no metadata endpoints)" + ) + verbose_logger.debug(f"AWS parameters provided: {list(aws_params.keys())}") + + # Use BaseAWSLLM to get AWS credentials (handles all auth flows) + aws_llm = BaseAWSLLM() + boto3_credentials = aws_llm.get_credentials( + aws_access_key_id=aws_params.get("aws_access_key_id"), + aws_secret_access_key=aws_params.get("aws_secret_access_key"), + aws_session_token=aws_params.get("aws_session_token"), + aws_region_name=aws_params.get("aws_region_name"), + aws_session_name=aws_params.get("aws_session_name"), + aws_profile_name=aws_params.get("aws_profile_name"), + aws_role_name=aws_params.get("aws_role_name"), + aws_web_identity_token=aws_params.get("aws_web_identity_token"), + aws_sts_endpoint=aws_params.get("aws_sts_endpoint"), + aws_external_id=aws_params.get("aws_external_id"), + ) + + # Create custom supplier that uses boto3 
credentials (bypasses metadata) + supplier = Boto3AwsSecurityCredentialsSupplier( + boto3_credentials=boto3_credentials, + aws_region=aws_params.get("aws_region_name", "us-east-1"), + ) + + verbose_logger.debug( + "Created custom AWS credentials supplier, creating GCP credentials" + ) + + # Validate required fields for external account credentials + token_url = json_obj.get("token_url") + if not token_url: + raise ValueError( + "token_url is required for external account credentials with AWS federation" + ) + + # Create GCP credentials with custom supplier (bypasses metadata) + return aws.Credentials( + audience=json_obj.get("audience"), + subject_token_type=json_obj.get("subject_token_type"), + token_url=token_url, + credential_source=None, # Not using metadata endpoints + aws_security_credentials_supplier=supplier, # Custom supplier + service_account_impersonation_url=json_obj.get( + "service_account_impersonation_url" + ), + client_id=json_obj.get("client_id"), + client_secret=json_obj.get("client_secret"), + quota_project_id=json_obj.get("quota_project_id"), + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + def _credentials_from_authorized_user(self, json_obj, scopes): import google.oauth2.credentials diff --git a/tests/test_litellm/llms/vertex_ai/test_vertex_llm_base.py b/tests/test_litellm/llms/vertex_ai/test_vertex_llm_base.py index c85f60700848..54ef18a96b07 100644 --- a/tests/test_litellm/llms/vertex_ai/test_vertex_llm_base.py +++ b/tests/test_litellm/llms/vertex_ai/test_vertex_llm_base.py @@ -12,6 +12,9 @@ ) # Adds the parent directory to the system path import litellm +from litellm.llms.vertex_ai.aws_credentials_supplier import ( + Boto3AwsSecurityCredentialsSupplier, +) from litellm.llms.vertex_ai.vertex_llm_base import VertexBase @@ -878,3 +881,192 @@ def test_check_custom_proxy_streaming_parameter(self): expected_no_streaming_url = 
"https://proxy.example.com/generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-lite:generateContent" assert result_url_no_streaming == expected_no_streaming_url, f"Expected {expected_no_streaming_url}, got {result_url_no_streaming}" + + def test_aws_federation_with_explicit_credentials(self): + """ + Test AWS to GCP federation using explicit AWS credentials (no metadata endpoints). + + This test verifies that when AWS auth parameters are provided in the credential config, + LiteLLM uses BaseAWSLLM to get AWS credentials and creates a custom supplier for GCP. + """ + vertex_base = VertexBase() + + credentials = { + "type": "external_account", + "audience": "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/my-pool/providers/aws-provider", + "subject_token_type": "urn:ietf:params:aws:token-type:aws4_request", + "token_url": "https://sts.googleapis.com/v1/token", + "service_account_impersonation_url": "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/my-sa@project.iam.gserviceaccount.com:generateAccessToken", + "credential_source": { + "environment_id": "aws1", + "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" + }, + "aws_role_name": "arn:aws:iam::123456789012:role/MyRole", + "aws_session_name": "litellm-test", + "aws_region_name": "us-east-1" + } + + # Mock boto3 credentials + mock_boto3_creds = MagicMock() + mock_boto3_creds.access_key = "AKIAIOSFODNN7EXAMPLE" + mock_boto3_creds.secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + mock_boto3_creds.token = "fake-session-token" + + def mock_get_frozen_credentials(): + frozen = MagicMock() + frozen.access_key = "AKIAIOSFODNN7EXAMPLE" + frozen.secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + frozen.token = "fake-session-token" + return frozen + + mock_boto3_creds.get_frozen_credentials = mock_get_frozen_credentials + + # Mock GCP credentials + mock_gcp_creds = MagicMock() + 
mock_gcp_creds.token = "gcp-token" + mock_gcp_creds.expired = False + mock_gcp_creds.project_id = "test-project" + + with patch("litellm.llms.bedrock.base_aws_llm.BaseAWSLLM.get_credentials", return_value=mock_boto3_creds) as mock_aws_creds, \ + patch("google.auth.aws.Credentials", return_value=mock_gcp_creds) as mock_gcp_aws_creds, \ + patch.object(vertex_base, "refresh_auth") as mock_refresh: + + def mock_refresh_impl(creds): + creds.token = "refreshed-gcp-token" + + mock_refresh.side_effect = mock_refresh_impl + + # Call load_auth + creds, project_id = vertex_base.load_auth( + credentials=credentials, + project_id=None + ) + + # Verify BaseAWSLLM.get_credentials was called with correct params + mock_aws_creds.assert_called_once() + call_kwargs = mock_aws_creds.call_args[1] + assert call_kwargs["aws_role_name"] == "arn:aws:iam::123456789012:role/MyRole" + assert call_kwargs["aws_session_name"] == "litellm-test" + assert call_kwargs["aws_region_name"] == "us-east-1" + + # Verify google.auth.aws.Credentials was called + assert mock_gcp_aws_creds.called + call_kwargs = mock_gcp_aws_creds.call_args[1] + + # Verify custom supplier was used + assert call_kwargs["aws_security_credentials_supplier"] is not None + assert call_kwargs["credential_source"] is None # Not using metadata + + # Verify credentials were refreshed + assert mock_refresh.called + assert creds.token == "refreshed-gcp-token" + + def test_boto3_aws_security_credentials_supplier(self): + """ + Test Boto3AwsSecurityCredentialsSupplier correctly wraps boto3 credentials. 
+ """ + # Mock boto3 credentials with get_frozen_credentials + mock_boto3_creds = MagicMock() + + frozen_creds = MagicMock() + frozen_creds.access_key = "AKIATEST123" + frozen_creds.secret_key = "secretkey123" + frozen_creds.token = "session-token-123" + + mock_boto3_creds.get_frozen_credentials = MagicMock(return_value=frozen_creds) + + # Create supplier + supplier = Boto3AwsSecurityCredentialsSupplier( + boto3_credentials=mock_boto3_creds, + aws_region="us-west-2" + ) + + # Test get_aws_security_credentials + creds = supplier.get_aws_security_credentials(context=None, request=None) + + assert creds["access_key_id"] == "AKIATEST123" + assert creds["secret_access_key"] == "secretkey123" + assert creds["security_token"] == "session-token-123" + + # Test get_aws_region + region = supplier.get_aws_region(context=None, request=None) + assert region == "us-west-2" + + def test_extract_aws_params(self): + """ + Test _extract_aws_params correctly identifies AWS auth parameters. + """ + vertex_base = VertexBase() + + # Test with AWS role params + json_obj = { + "type": "external_account", + "audience": "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/aws", + "aws_role_name": "arn:aws:iam::123456789012:role/MyRole", + "aws_session_name": "my-session", + "aws_region_name": "us-east-1" + } + + aws_params = vertex_base._extract_aws_params(json_obj) + + assert aws_params is not None + assert aws_params["aws_role_name"] == "arn:aws:iam::123456789012:role/MyRole" + assert aws_params["aws_session_name"] == "my-session" + assert aws_params["aws_region_name"] == "us-east-1" + + # Test with no AWS params + json_obj_no_aws = { + "type": "external_account", + "audience": "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/aws" + } + + aws_params_none = vertex_base._extract_aws_params(json_obj_no_aws) + assert aws_params_none is None + + def test_aws_federation_fallback_to_metadata(self): + """ + Test that when 
no AWS auth params are provided, it falls back to metadata-based flow. + """ + vertex_base = VertexBase() + + credentials = { + "type": "external_account", + "audience": "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/my-pool/providers/aws-provider", + "subject_token_type": "urn:ietf:params:aws:token-type:aws4_request", + "token_url": "https://sts.googleapis.com/v1/token", + "credential_source": { + "environment_id": "aws1", + "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" + } + } + + mock_creds = MagicMock() + mock_creds.token = "token-from-metadata" + mock_creds.expired = False + mock_creds.project_id = "test-project" + + with patch.object( + vertex_base, "_credentials_from_identity_pool_with_aws", return_value=mock_creds + ) as mock_metadata_auth, \ + patch.object( + vertex_base, "_credentials_from_aws_with_explicit_auth" + ) as mock_explicit_auth, \ + patch.object(vertex_base, "refresh_auth") as mock_refresh: + + def mock_refresh_impl(creds): + creds.token = "refreshed-token" + + mock_refresh.side_effect = mock_refresh_impl + + # Call load_auth - should use metadata-based flow + creds, project_id = vertex_base.load_auth( + credentials=credentials, + project_id=None + ) + + # Verify metadata-based auth was used + assert mock_metadata_auth.called + # Verify explicit auth was NOT used + assert not mock_explicit_auth.called + # Verify credentials were returned + assert creds.token == "refreshed-token"
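Beyond the mocked tests above, the credential-extraction behavior introduced in `aws_credentials_supplier.py` can be sanity-checked without botocore or google-auth installed, by mirroring its logic against minimal stand-in objects. This is a sketch: `_FrozenCreds`, `_FakeBotoCredentials`, and `SupplierSketch` below are illustrative stand-ins, not real botocore or LiteLLM types.

```python
from typing import Any, Mapping, Optional


class _FrozenCreds:
    """Stand-in for botocore's immutable ReadOnlyCredentials snapshot."""

    def __init__(self, access_key: str, secret_key: str, token: Optional[str]):
        self.access_key = access_key
        self.secret_key = secret_key
        self.token = token


class _FakeBotoCredentials:
    """Stand-in for botocore.credentials.Credentials."""

    def __init__(self, access_key: str, secret_key: str, token: Optional[str] = None):
        self._frozen = _FrozenCreds(access_key, secret_key, token)

    def get_frozen_credentials(self) -> _FrozenCreds:
        return self._frozen


class SupplierSketch:
    """Mirrors the extraction logic of Boto3AwsSecurityCredentialsSupplier."""

    def __init__(self, boto3_credentials: Any, aws_region: str = "us-east-1"):
        self._credentials = boto3_credentials
        self._region = aws_region

    def get_aws_security_credentials(self, context: Any, request: Any) -> Mapping[str, str]:
        # Prefer the frozen (immutable snapshot) form when available
        if hasattr(self._credentials, "get_frozen_credentials"):
            frozen = self._credentials.get_frozen_credentials()
            return {
                "access_key_id": frozen.access_key,
                "secret_access_key": frozen.secret_key,
                "security_token": frozen.token or "",
            }
        # Fallback: read attributes directly off the credentials object
        return {
            "access_key_id": self._credentials.access_key,
            "secret_access_key": self._credentials.secret_key,
            "security_token": getattr(self._credentials, "token", "") or "",
        }

    def get_aws_region(self, context: Any, request: Any) -> str:
        return self._region


# Static-key credentials carry no session token; the supplier should map the
# missing token to an empty string rather than None.
supplier = SupplierSketch(_FakeBotoCredentials("AKIATEST123", "secretkey123"), aws_region="us-west-2")
creds = supplier.get_aws_security_credentials(context=None, request=None)
print(creds["access_key_id"])  # AKIATEST123
print(supplier.get_aws_region(context=None, request=None))  # us-west-2
```

The empty-string fallback matters because the token exchange serializes the credential fields; a literal `None` would not round-trip cleanly.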