diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index 41136b753..94b372aca 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -584,7 +584,7 @@ def __init__( validated_http_options = HttpOptions.model_validate(http_options) except ValidationError as e: raise ValueError('Invalid http_options') from e - elif isinstance(http_options, HttpOptions): + elif http_options and _common.is_duck_type_of(http_options, HttpOptions): validated_http_options = http_options if validated_http_options.base_url_resource_scope and not validated_http_options.base_url: @@ -592,6 +592,7 @@ def __init__( raise ValueError( 'base_url must be set when base_url_resource_scope is set.' ) + print('validated_http_options: ', validated_http_options) # Retrieve implicitly set values from the environment. env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None) @@ -649,7 +650,13 @@ def __init__( else None ) - if not self.location and not self.api_key and not self.custom_base_url: + if ( + not self.location + and not self.api_key + ): + if not self.custom_base_url: + self.location = 'global' + elif self.custom_base_url.endswith('.googleapis.com'): self.location = 'global' # Skip fetching project from ADC if base url is provided in http options. @@ -667,12 +674,16 @@ def __init__( if not has_sufficient_auth and not self.custom_base_url: # Skip sufficient auth check if base url is provided in http options. raise ValueError( - 'Project or API key must be set when using the Vertex ' - 'AI API.' + 'Project or API key must be set when using the Vertex AI API.' ) - if self.api_key or self.location == 'global': + if ( + self.api_key or self.location == 'global' + ) and not self.custom_base_url: self._http_options.base_url = f'https://aiplatform.googleapis.com/' - elif self.custom_base_url and not ((project and location) or api_key): + elif ( + self.custom_base_url + and not self.custom_base_url.endswith('.googleapis.com') + ) and not ((project and location) or api_key): # Avoid setting default base url and api version if base_url provided. # API gateway proxy can use the auth in custom headers, not url. # Enable custom url if auth is not sufficient. diff --git a/google/genai/_common.py b/google/genai/_common.py index 8923fca0f..f4a6984f7 100644 --- a/google/genai/_common.py +++ b/google/genai/_common.py @@ -814,3 +814,34 @@ def recursive_dict_update( target_dict[key] = value else: target_dict[key] = value + + +def is_duck_type_of(obj: Any, cls: type[pydantic.BaseModel]) -> bool: + """Checks if an object has all of the fields of a Pydantic model. + + This is a duck-typing alternative to `isinstance` to solve dual-import + problems. It returns False for dictionaries, which should be handled by + `isinstance(obj, dict)`. + + Args: + obj: The object to check. + cls: The Pydantic model class to duck-type against. + + Returns: + True if the object has all the fields defined in the Pydantic model, False + otherwise. + """ + if isinstance(obj, dict) or not hasattr(cls, 'model_fields'): + return False + + # Check if the object has all of the Pydantic model's defined fields. + all_matched = all(hasattr(obj, field) for field in cls.model_fields) + if not all_matched and isinstance(obj, pydantic.BaseModel): + # Check the other way around if obj is a Pydantic model. + # Check if the Pydantic model has all of the object's defined fields. + try: + obj_private = cls() + all_matched = all(hasattr(obj_private, f) for f in type(obj).model_fields) + except ValueError: + return False + return all_matched diff --git a/google/genai/_transformers.py b/google/genai/_transformers.py index be3303e6d..91a32a881 100644 --- a/google/genai/_transformers.py +++ b/google/genai/_transformers.py @@ -29,6 +29,7 @@ from typing import Any, GenericAlias, List, Optional, Sequence, Union # type: ignore[attr-defined] from ._mcp_utils import mcp_to_gemini_tool from ._common import get_value_by_path as getv +from ._common import is_duck_type_of if typing.TYPE_CHECKING: import PIL.Image @@ -72,37 +73,6 @@ metric_name_api_sdk_map = {v: k for k, v in metric_name_sdk_api_map.items()} -def _is_duck_type_of(obj: Any, cls: type[pydantic.BaseModel]) -> bool: - """Checks if an object has all of the fields of a Pydantic model. - - This is a duck-typing alternative to `isinstance` to solve dual-import - problems. It returns False for dictionaries, which should be handled by - `isinstance(obj, dict)`. - - Args: - obj: The object to check. - cls: The Pydantic model class to duck-type against. - - Returns: - True if the object has all the fields defined in the Pydantic model, False - otherwise. - """ - if isinstance(obj, dict) or not hasattr(cls, 'model_fields'): - return False - - # Check if the object has all of the Pydantic model's defined fields. - all_matched = all(hasattr(obj, field) for field in cls.model_fields) - if not all_matched and isinstance(obj, pydantic.BaseModel): - # Check the other way around if obj is a Pydantic model. - # Check if the Pydantic model has all of the object's defined fields. - try: - obj_private = cls() - all_matched = all(hasattr(obj_private, f) for f in type(obj).model_fields) - except ValueError: - return False - return all_matched - - def _resource_name( client: _api_client.BaseApiClient, resource_name: str, @@ -311,7 +281,7 @@ def t_function_response( raise ValueError('function_response is required.') if isinstance(function_response, dict): return types.FunctionResponse.model_validate(function_response) - elif _is_duck_type_of(function_response, types.FunctionResponse): + elif is_duck_type_of(function_response, types.FunctionResponse): return function_response else: raise TypeError( @@ -347,7 +317,7 @@ def t_blob(blob: types.BlobImageUnionDict) -> types.Blob: if not blob: raise ValueError('blob is required.') - if _is_duck_type_of(blob, types.Blob): + if is_duck_type_of(blob, types.Blob): return blob # type: ignore[return-value] if isinstance(blob, dict): @@ -388,7 +358,7 @@ def t_part(part: Optional[types.PartUnionDict]) -> types.Part: raise ValueError('content part is required.') if isinstance(part, str): return types.Part(text=part) - if _is_duck_type_of(part, types.File): + if is_duck_type_of(part, types.File): if not part.uri or not part.mime_type: # type: ignore[union-attr] raise ValueError('file uri and mime_type are required.') return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type) # type: ignore[union-attr] @@ -397,7 +367,7 @@ def t_part(part: Optional[types.PartUnionDict]) -> types.Part: return types.Part.model_validate(part) except pydantic.ValidationError: return types.Part(file_data=types.FileData.model_validate(part)) - if _is_duck_type_of(part, types.Part): + if is_duck_type_of(part, types.Part): return part # type: ignore[return-value] if 'image' in part.__class__.__name__.lower(): @@ -454,7 +424,7 @@ def t_content( ) -> types.Content: if content is None: raise ValueError('content is required.') - if _is_duck_type_of(content, types.Content): + if is_duck_type_of(content, types.Content): return content # type: ignore[return-value] if isinstance(content, dict): try: @@ -466,9 +436,9 @@ def t_content( if possible_part.function_call else types.UserContent(parts=[possible_part]) ) - if _is_duck_type_of(content, types.File): + if is_duck_type_of(content, types.File): return types.UserContent(parts=[t_part(content)]) # type: ignore[arg-type] - if _is_duck_type_of(content, types.Part): + if is_duck_type_of(content, types.Part): return ( types.ModelContent(parts=[content]) # type: ignore[arg-type] if content.function_call # type: ignore[union-attr] @@ -521,8 +491,8 @@ def _is_part( ) -> TypeGuard[types.PartUnionDict]: if ( isinstance(part, str) - or _is_duck_type_of(part, types.File) - or _is_duck_type_of(part, types.Part) + or is_duck_type_of(part, types.File) + or is_duck_type_of(part, types.Part) ): return True @@ -592,7 +562,7 @@ def _handle_current_part( # append to result # if list, we only accept a list of types.PartUnion for content in contents: - if _is_duck_type_of(content, types.Content) or isinstance(content, list): + if is_duck_type_of(content, types.Content) or isinstance(content, list): _append_accumulated_parts_as_content(result, accumulated_parts) if isinstance(content, list): result.append(types.UserContent(parts=content)) # type: ignore[arg-type] @@ -889,7 +859,7 @@ def t_schema( return types.Schema.model_validate(origin) if isinstance(origin, EnumMeta): return _process_enum(origin, client) - if _is_duck_type_of(origin, types.Schema): + if is_duck_type_of(origin, types.Schema): if dict(origin) == dict(types.Schema()): # type: ignore [arg-type] # response_schema value was coerced to an empty Schema instance because # it did not adhere to the Schema field annotation @@ -931,7 +901,7 @@ def t_speech_config( ) -> Optional[types.SpeechConfig]: if not origin: return None - if _is_duck_type_of(origin, types.SpeechConfig): + if is_duck_type_of(origin, types.SpeechConfig): return origin # type: ignore[return-value] if isinstance(origin, str): return types.SpeechConfig( @@ -948,7 +918,7 @@ def t_speech_config( def t_live_speech_config( origin: types.SpeechConfigOrDict, ) -> Optional[types.SpeechConfig]: - if _is_duck_type_of(origin, types.SpeechConfig): + if is_duck_type_of(origin, types.SpeechConfig): speech_config = origin if isinstance(origin, dict): speech_config = types.SpeechConfig.model_validate(origin) @@ -974,7 +944,7 @@ def t_tool( ) ] ) - elif McpTool is not None and _is_duck_type_of(origin, McpTool): + elif McpTool is not None and is_duck_type_of(origin, McpTool): return mcp_to_gemini_tool(origin) elif isinstance(origin, dict): return types.Tool.model_validate(origin) @@ -1017,7 +987,7 @@ def t_batch_job_source( ) -> types.BatchJobSource: if isinstance(src, dict): src = types.BatchJobSource(**src) - if _is_duck_type_of(src, types.BatchJobSource): + if is_duck_type_of(src, types.BatchJobSource): vertex_sources = sum( [src.gcs_uri is not None, src.bigquery_uri is not None] # type: ignore[union-attr] ) @@ -1068,7 +1038,7 @@ def t_embedding_batch_job_source( if isinstance(src, dict): src = types.EmbeddingsBatchJobSource(**src) - if _is_duck_type_of(src, types.EmbeddingsBatchJobSource): + if is_duck_type_of(src, types.EmbeddingsBatchJobSource): mldev_sources = sum([ src.inlined_requests is not None, src.file_name is not None, @@ -1103,7 +1073,7 @@ def t_batch_job_destination( ) else: raise ValueError(f'Unsupported destination: {dest}') - elif _is_duck_type_of(dest, types.BatchJobDestination): + elif is_duck_type_of(dest, types.BatchJobDestination): return dest else: raise ValueError(f'Unsupported destination: {dest}') @@ -1203,11 +1173,11 @@ def t_file_name( name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]], ) -> str: # Remove the files/ prefix since it's added to the url path. - if _is_duck_type_of(name, types.File): + if is_duck_type_of(name, types.File): name = name.name # type: ignore[union-attr] - elif _is_duck_type_of(name, types.Video): + elif is_duck_type_of(name, types.Video): name = name.uri # type: ignore[union-attr] - elif _is_duck_type_of(name, types.GeneratedVideo): + elif is_duck_type_of(name, types.GeneratedVideo): if name.video is not None: # type: ignore[union-attr] name = name.video.uri # type: ignore[union-attr] else: @@ -1252,7 +1222,7 @@ def t_tuning_job_status(status: str) -> Union[types.JobState, str]: def t_content_strict(content: types.ContentOrDict) -> types.Content: if isinstance(content, dict): return types.Content.model_validate(content) - elif _is_duck_type_of(content, types.Content): + elif is_duck_type_of(content, types.Content): return content else: raise ValueError( diff --git a/google/genai/tests/client/test_client_initialization.py b/google/genai/tests/client/test_client_initialization.py index abeb780b7..7b0136044 100644 --- a/google/genai/tests/client/test_client_initialization.py +++ b/google/genai/tests/client/test_client_initialization.py @@ -449,6 +449,42 @@ def test_vertexai_default_location_to_global_with_explicit_project_and_env_apike assert not client.models._api_client.api_key +def test_vertexai_default_location_to_global_with_vertexai_base_url( + monkeypatch, +): + # Test case 4: When project and vertex base url are set + project_id = "env_project_id" + + with monkeypatch.context() as m: + m.delenv("GOOGLE_CLOUD_LOCATION", raising=False) + m.setenv("GOOGLE_CLOUD_PROJECT", project_id) + client = Client( + vertexai=True, + http_options={'base_url': 'https://fake-url.googleapis.com'}, + ) + # Implicit project takes precedence over implicit api_key + assert client.models._api_client.location == "global" + assert client.models._api_client.project == project_id + + +def test_vertexai_default_location_to_global_with_arbitrary_base_url( + monkeypatch, +): + # Test case 5: When project and arbitrary base url (proxy) are set + project_id = "env_project_id" + + with monkeypatch.context() as m: + m.delenv("GOOGLE_CLOUD_LOCATION", raising=False) + m.setenv("GOOGLE_CLOUD_PROJECT", project_id) + client = Client( + vertexai=True, + http_options={'base_url': 'https://fake-url.com'}, + ) + # Implicit project takes precedence over implicit api_key + assert not client.models._api_client.location + assert not client.models._api_client.project + + def test_vertexai_default_location_to_global_with_env_project_and_env_apikey( monkeypatch, ): diff --git a/google/genai/tests/transformers/test_duck_type.py b/google/genai/tests/common/test_duck_type.py similarity index 76% rename from google/genai/tests/transformers/test_duck_type.py rename to google/genai/tests/common/test_duck_type.py index ca2da5934..bddee485c 100644 --- a/google/genai/tests/transformers/test_duck_type.py +++ b/google/genai/tests/common/test_duck_type.py @@ -2,7 +2,7 @@ import pydantic -from ... import _transformers +from ... import _common class TestIsDuckTypeOf(unittest.TestCase): @@ -28,7 +28,7 @@ def test_is_duck_type_of_true_for_pydantic_object(self): field4="c", field5="d", ) - self.assertTrue(_transformers._is_duck_type_of(obj, self.FakePydanticModel)) + self.assertTrue(_common.is_duck_type_of(obj, self.FakePydanticModel)) def test_is_duck_type_of_true_for_duck_typed_object(self): class DuckTypedObject: @@ -41,7 +41,7 @@ def __init__(self): self.field5 = "d" obj = DuckTypedObject() - self.assertTrue(_transformers._is_duck_type_of(obj, self.FakePydanticModel)) + self.assertTrue(_common.is_duck_type_of(obj, self.FakePydanticModel)) def test_is_duck_type_of_false_for_different_many_fields(self): class DifferentFieldsObject: @@ -51,7 +51,7 @@ def __init__(self): obj = DifferentFieldsObject() self.assertFalse( - _transformers._is_duck_type_of(obj, self.FakePydanticModel) + _common.is_duck_type_of(obj, self.FakePydanticModel) ) def test_is_duck_type_of_false_for_missing_fields(self): @@ -59,12 +59,12 @@ def test_is_duck_type_of_false_for_missing_fields(self): obj = self.FakePydanticModelWithLessFields( field1="a", field2=1, field3="b", field4="c" ) - self.assertFalse(_transformers._is_duck_type_of(obj, self.FakePydanticModel)) + self.assertFalse(_common.is_duck_type_of(obj, self.FakePydanticModel)) def test_is_duck_type_of_false_for_dict(self): obj = {"field1": "a", "field2": 1} self.assertFalse( - _transformers._is_duck_type_of(obj, self.FakePydanticModel) + _common.is_duck_type_of(obj, self.FakePydanticModel) ) def test_is_duck_type_of_false_for_non_pydantic_class(self): @@ -75,7 +75,7 @@ class SomeObject: pass obj = SomeObject() - self.assertFalse(_transformers._is_duck_type_of(obj, NonPydanticModel)) + self.assertFalse(_common.is_duck_type_of(obj, NonPydanticModel)) def test_is_duck_type_of_true_with_extra_fields(self): class ExtraFieldsObject: @@ -89,7 +89,7 @@ def __init__(self): self.field6 = "extra" obj = ExtraFieldsObject() - self.assertTrue(_transformers._is_duck_type_of(obj, self.FakePydanticModel)) + self.assertTrue(_common.is_duck_type_of(obj, self.FakePydanticModel)) if __name__ == "__main__":