Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,14 +584,15 @@ def __init__(
validated_http_options = HttpOptions.model_validate(http_options)
except ValidationError as e:
raise ValueError('Invalid http_options') from e
elif isinstance(http_options, HttpOptions):
elif http_options and _common.is_duck_type_of(http_options, HttpOptions):
validated_http_options = http_options

if validated_http_options.base_url_resource_scope and not validated_http_options.base_url:
# base_url_resource_scope is only valid when base_url is set.
raise ValueError(
'base_url must be set when base_url_resource_scope is set.'
)
print('validated_http_options: ', validated_http_options)

# Retrieve implicitly set values from the environment.
env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
Expand Down Expand Up @@ -649,7 +650,13 @@ def __init__(
else None
)

if not self.location and not self.api_key and not self.custom_base_url:
if (
not self.location
and not self.api_key
):
if not self.custom_base_url:
self.location = 'global'
elif self.custom_base_url.endswith('.googleapis.com'):
self.location = 'global'

# Skip fetching project from ADC if base url is provided in http options.
Expand All @@ -667,12 +674,16 @@ def __init__(
if not has_sufficient_auth and not self.custom_base_url:
# Skip sufficient auth check if base url is provided in http options.
raise ValueError(
'Project or API key must be set when using the Vertex '
'AI API.'
'Project or API key must be set when using the Vertex AI API.'
)
if self.api_key or self.location == 'global':
if (
self.api_key or self.location == 'global'
) and not self.custom_base_url:
self._http_options.base_url = f'https://aiplatform.googleapis.com/'
elif self.custom_base_url and not ((project and location) or api_key):
elif (
self.custom_base_url
and not self.custom_base_url.endswith('.googleapis.com')
) and not ((project and location) or api_key):
# Avoid setting default base url and api version if base_url provided.
# API gateway proxy can use the auth in custom headers, not url.
# Enable custom url if auth is not sufficient.
Expand Down
31 changes: 31 additions & 0 deletions google/genai/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,34 @@ def recursive_dict_update(
target_dict[key] = value
else:
target_dict[key] = value


def is_duck_type_of(obj: Any, cls: type[pydantic.BaseModel]) -> bool:
  """Duck-types *obj* against the fields of the Pydantic model *cls*.

  A structural alternative to `isinstance` that tolerates dual-import
  problems (the same model class imported under two module paths). Plain
  dictionaries are always rejected; callers are expected to handle those
  with `isinstance(obj, dict)`.

  Args:
    obj: The object to check.
    cls: The Pydantic model class to duck-type against.

  Returns:
    True if the object exposes all the fields defined in the Pydantic model,
    False otherwise.
  """
  if isinstance(obj, dict):
    # Dicts are handled separately by callers via `isinstance(obj, dict)`.
    return False
  if not hasattr(cls, 'model_fields'):
    # `cls` is not a Pydantic model class; nothing to duck-type against.
    return False

  # Forward check: does the object expose every field declared on `cls`?
  if all(hasattr(obj, name) for name in cls.model_fields):
    return True

  if not isinstance(obj, pydantic.BaseModel):
    return False

  # Reverse check for Pydantic instances: does a default-constructed `cls`
  # expose every field declared on the object's own model class?
  try:
    probe = cls()
  except ValueError:
    # `cls` has required fields and cannot be default-constructed.
    # NOTE(review): assumes pydantic.ValidationError subclasses ValueError
    # (true in pydantic v1 and v2).
    return False
  return all(hasattr(probe, name) for name in type(obj).model_fields)
74 changes: 22 additions & 52 deletions google/genai/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import Any, GenericAlias, List, Optional, Sequence, Union # type: ignore[attr-defined]
from ._mcp_utils import mcp_to_gemini_tool
from ._common import get_value_by_path as getv
from ._common import is_duck_type_of

if typing.TYPE_CHECKING:
import PIL.Image
Expand Down Expand Up @@ -72,37 +73,6 @@
metric_name_api_sdk_map = {v: k for k, v in metric_name_sdk_api_map.items()}


def _is_duck_type_of(obj: Any, cls: type[pydantic.BaseModel]) -> bool:
  """Checks if an object has all of the fields of a Pydantic model.

  This is a duck-typing alternative to `isinstance` to solve dual-import
  problems. It returns False for dictionaries, which should be handled by
  `isinstance(obj, dict)`.

  Args:
    obj: The object to check.
    cls: The Pydantic model class to duck-type against.

  Returns:
    True if the object has all the fields defined in the Pydantic model, False
    otherwise.
  """
  # Dicts are deliberately excluded; a `cls` without `model_fields` is not a
  # Pydantic model class, so there is nothing to duck-type against.
  if isinstance(obj, dict) or not hasattr(cls, 'model_fields'):
    return False

  # Check if the object has all of the Pydantic model's defined fields.
  all_matched = all(hasattr(obj, field) for field in cls.model_fields)
  if not all_matched and isinstance(obj, pydantic.BaseModel):
    # Check the other way around if obj is a Pydantic model.
    # Check if the Pydantic model has all of the object's defined fields.
    try:
      # Default construction fails when `cls` has required fields.
      # NOTE(review): assumes pydantic.ValidationError subclasses ValueError
      # (true in pydantic v1 and v2).
      obj_private = cls()
    except ValueError:
      return False
    all_matched = all(hasattr(obj_private, f) for f in type(obj).model_fields)
  return all_matched


def _resource_name(
client: _api_client.BaseApiClient,
resource_name: str,
Expand Down Expand Up @@ -311,7 +281,7 @@ def t_function_response(
raise ValueError('function_response is required.')
if isinstance(function_response, dict):
return types.FunctionResponse.model_validate(function_response)
elif _is_duck_type_of(function_response, types.FunctionResponse):
elif is_duck_type_of(function_response, types.FunctionResponse):
return function_response
else:
raise TypeError(
Expand Down Expand Up @@ -347,7 +317,7 @@ def t_blob(blob: types.BlobImageUnionDict) -> types.Blob:
if not blob:
raise ValueError('blob is required.')

if _is_duck_type_of(blob, types.Blob):
if is_duck_type_of(blob, types.Blob):
return blob # type: ignore[return-value]

if isinstance(blob, dict):
Expand Down Expand Up @@ -388,7 +358,7 @@ def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
raise ValueError('content part is required.')
if isinstance(part, str):
return types.Part(text=part)
if _is_duck_type_of(part, types.File):
if is_duck_type_of(part, types.File):
if not part.uri or not part.mime_type: # type: ignore[union-attr]
raise ValueError('file uri and mime_type are required.')
return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type) # type: ignore[union-attr]
Expand All @@ -397,7 +367,7 @@ def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
return types.Part.model_validate(part)
except pydantic.ValidationError:
return types.Part(file_data=types.FileData.model_validate(part))
if _is_duck_type_of(part, types.Part):
if is_duck_type_of(part, types.Part):
return part # type: ignore[return-value]

if 'image' in part.__class__.__name__.lower():
Expand Down Expand Up @@ -454,7 +424,7 @@ def t_content(
) -> types.Content:
if content is None:
raise ValueError('content is required.')
if _is_duck_type_of(content, types.Content):
if is_duck_type_of(content, types.Content):
return content # type: ignore[return-value]
if isinstance(content, dict):
try:
Expand All @@ -466,9 +436,9 @@ def t_content(
if possible_part.function_call
else types.UserContent(parts=[possible_part])
)
if _is_duck_type_of(content, types.File):
if is_duck_type_of(content, types.File):
return types.UserContent(parts=[t_part(content)]) # type: ignore[arg-type]
if _is_duck_type_of(content, types.Part):
if is_duck_type_of(content, types.Part):
return (
types.ModelContent(parts=[content]) # type: ignore[arg-type]
if content.function_call # type: ignore[union-attr]
Expand Down Expand Up @@ -521,8 +491,8 @@ def _is_part(
) -> TypeGuard[types.PartUnionDict]:
if (
isinstance(part, str)
or _is_duck_type_of(part, types.File)
or _is_duck_type_of(part, types.Part)
or is_duck_type_of(part, types.File)
or is_duck_type_of(part, types.Part)
):
return True

Expand Down Expand Up @@ -592,7 +562,7 @@ def _handle_current_part(
# append to result
# if list, we only accept a list of types.PartUnion
for content in contents:
if _is_duck_type_of(content, types.Content) or isinstance(content, list):
if is_duck_type_of(content, types.Content) or isinstance(content, list):
_append_accumulated_parts_as_content(result, accumulated_parts)
if isinstance(content, list):
result.append(types.UserContent(parts=content)) # type: ignore[arg-type]
Expand Down Expand Up @@ -889,7 +859,7 @@ def t_schema(
return types.Schema.model_validate(origin)
if isinstance(origin, EnumMeta):
return _process_enum(origin, client)
if _is_duck_type_of(origin, types.Schema):
if is_duck_type_of(origin, types.Schema):
if dict(origin) == dict(types.Schema()): # type: ignore [arg-type]
# response_schema value was coerced to an empty Schema instance because
# it did not adhere to the Schema field annotation
Expand Down Expand Up @@ -931,7 +901,7 @@ def t_speech_config(
) -> Optional[types.SpeechConfig]:
if not origin:
return None
if _is_duck_type_of(origin, types.SpeechConfig):
if is_duck_type_of(origin, types.SpeechConfig):
return origin # type: ignore[return-value]
if isinstance(origin, str):
return types.SpeechConfig(
Expand All @@ -948,7 +918,7 @@ def t_speech_config(
def t_live_speech_config(
origin: types.SpeechConfigOrDict,
) -> Optional[types.SpeechConfig]:
if _is_duck_type_of(origin, types.SpeechConfig):
if is_duck_type_of(origin, types.SpeechConfig):
speech_config = origin
if isinstance(origin, dict):
speech_config = types.SpeechConfig.model_validate(origin)
Expand All @@ -974,7 +944,7 @@ def t_tool(
)
]
)
elif McpTool is not None and _is_duck_type_of(origin, McpTool):
elif McpTool is not None and is_duck_type_of(origin, McpTool):
return mcp_to_gemini_tool(origin)
elif isinstance(origin, dict):
return types.Tool.model_validate(origin)
Expand Down Expand Up @@ -1017,7 +987,7 @@ def t_batch_job_source(
) -> types.BatchJobSource:
if isinstance(src, dict):
src = types.BatchJobSource(**src)
if _is_duck_type_of(src, types.BatchJobSource):
if is_duck_type_of(src, types.BatchJobSource):
vertex_sources = sum(
[src.gcs_uri is not None, src.bigquery_uri is not None] # type: ignore[union-attr]
)
Expand Down Expand Up @@ -1068,7 +1038,7 @@ def t_embedding_batch_job_source(
if isinstance(src, dict):
src = types.EmbeddingsBatchJobSource(**src)

if _is_duck_type_of(src, types.EmbeddingsBatchJobSource):
if is_duck_type_of(src, types.EmbeddingsBatchJobSource):
mldev_sources = sum([
src.inlined_requests is not None,
src.file_name is not None,
Expand Down Expand Up @@ -1103,7 +1073,7 @@ def t_batch_job_destination(
)
else:
raise ValueError(f'Unsupported destination: {dest}')
elif _is_duck_type_of(dest, types.BatchJobDestination):
elif is_duck_type_of(dest, types.BatchJobDestination):
return dest
else:
raise ValueError(f'Unsupported destination: {dest}')
Expand Down Expand Up @@ -1203,11 +1173,11 @@ def t_file_name(
name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
) -> str:
# Remove the files/ prefix since it's added to the url path.
if _is_duck_type_of(name, types.File):
if is_duck_type_of(name, types.File):
name = name.name # type: ignore[union-attr]
elif _is_duck_type_of(name, types.Video):
elif is_duck_type_of(name, types.Video):
name = name.uri # type: ignore[union-attr]
elif _is_duck_type_of(name, types.GeneratedVideo):
elif is_duck_type_of(name, types.GeneratedVideo):
if name.video is not None: # type: ignore[union-attr]
name = name.video.uri # type: ignore[union-attr]
else:
Expand Down Expand Up @@ -1252,7 +1222,7 @@ def t_tuning_job_status(status: str) -> Union[types.JobState, str]:
def t_content_strict(content: types.ContentOrDict) -> types.Content:
if isinstance(content, dict):
return types.Content.model_validate(content)
elif _is_duck_type_of(content, types.Content):
elif is_duck_type_of(content, types.Content):
return content
else:
raise ValueError(
Expand Down
36 changes: 36 additions & 0 deletions google/genai/tests/client/test_client_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,42 @@ def test_vertexai_default_location_to_global_with_explicit_project_and_env_apike
assert not client.models._api_client.api_key


def test_vertexai_default_location_to_global_with_vertexai_base_url(
    monkeypatch,
):
  """Location defaults to 'global' when the custom base url is a *.googleapis.com host."""
  # Test case 4: When project and vertex base url are set
  project_id = "env_project_id"

  with monkeypatch.context() as m:
    m.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
    m.setenv("GOOGLE_CLOUD_PROJECT", project_id)
    client = Client(
        vertexai=True,
        http_options={'base_url': 'https://fake-url.googleapis.com'},
    )
    # A *.googleapis.com base url still gets the 'global' location default,
    # and the project is resolved from the environment.
    assert client.models._api_client.location == "global"
    assert client.models._api_client.project == project_id


def test_vertexai_default_location_to_global_with_arbitrary_base_url(
    monkeypatch,
):
  """No location default is applied when the custom base url is not a googleapis.com host."""
  # Test case 5: When project and arbitrary base url (proxy) are set
  project_id = "env_project_id"

  with monkeypatch.context() as m:
    m.delenv("GOOGLE_CLOUD_LOCATION", raising=False)
    m.setenv("GOOGLE_CLOUD_PROJECT", project_id)
    client = Client(
        vertexai=True,
        http_options={'base_url': 'https://fake-url.com'},
    )
    # A non-googleapis.com (proxy) base url suppresses the 'global' location
    # default; neither location nor project is resolved.
    assert not client.models._api_client.location
    assert not client.models._api_client.project


def test_vertexai_default_location_to_global_with_env_project_and_env_apikey(
monkeypatch,
):
Expand Down
Loading