Skip to content

Commit add8cd8

Browse files
feat(gemini): Support gemini-embedding-001 and fix models/ prefix in metadata keys (#3813)
# Add support for Google Gemini `gemini-embedding-001` embedding model and correctly registers model type MR message created with the assistance of Claude-4.5-sonnet This resolves #3755 ## What does this PR do? This PR adds support for the `gemini-embedding-001` Google embedding model to the llama-stack Gemini provider. This model provides high-dimensional embeddings (3072 dimensions) compared to the existing `text-embedding-004` model (768 dimensions). Old embeddings models (such as text-embedding-004) will be deprecated soon according to Google ([Link](https://developers.googleblog.com/en/gemini-embedding-available-gemini-api/)) ## Problem The Gemini provider only supported the `text-embedding-004` embedding model. The newer `gemini-embedding-001` model, which provides higher-dimensional embeddings for improved semantic representation, was not available through llama-stack. ## Solution This PR consists of three commits that implement, fix the model registration, and enable embedding generation: ### Commit 1: Initial addition of gemini-embedding-001 Added metadata for `gemini-embedding-001` to the `embedding_model_metadata` dictionary: ```python embedding_model_metadata: dict[str, dict[str, int]] = { "text-embedding-004": {"embedding_dimension": 768, "context_length": 2048}, "gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048}, # NEW } ``` **Issue discovered:** The model was not being registered correctly because the dictionary keys didn't match the model IDs returned by Gemini's API. ### Commit 2: Fix model ID matching with `models/` prefix Updated both dictionary keys to include the `models/` prefix to match Gemini's OpenAI-compatible API response format: ```python embedding_model_metadata: dict[str, dict[str, int]] = { "models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048}, # UPDATED "models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048}, # UPDATED } ``` **Root cause:** Gemini's OpenAI-compatible API returns model IDs with the `models/` prefix (e.g., `models/text-embedding-004`). The `OpenAIMixin.list_models()` method directly matches these IDs against the `embedding_model_metadata` dictionary keys. Without the prefix, the models were being registered as LLMs instead of embedding models. ### Commit 3: Fix embedding generation for providers without usage stats Fixed a bug in `OpenAIMixin.openai_embeddings()` that prevented embedding generation for providers (like Gemini) that don't return usage statistics: ```python # Before (Line 351-354): usage = OpenAIEmbeddingUsage( prompt_tokens=response.usage.prompt_tokens, # ← Crashed with AttributeError total_tokens=response.usage.total_tokens, ) # After (Lines 351-362): if response.usage: usage = OpenAIEmbeddingUsage( prompt_tokens=response.usage.prompt_tokens, total_tokens=response.usage.total_tokens, ) else: usage = OpenAIEmbeddingUsage( prompt_tokens=0, # Default when not provided total_tokens=0, # Default when not provided ) ``` **Impact:** This fix enables embedding generation for **all** Gemini embedding models, not just the newly added one. ## Changes ### Modified Files **`llama_stack/providers/remote/inference/gemini/gemini.py`** - Line 17: Updated `text-embedding-004` key to `models/text-embedding-004` - Line 18: Added `models/gemini-embedding-001` with correct metadata **`llama_stack/providers/utils/inference/openai_mixin.py`** - Lines 351-362: Added null check for `response.usage` to handle providers without usage statistics ## Key Technical Details ### Model ID Matching Flow 1. `list_provider_model_ids()` calls Gemini's `/v1/models` endpoint 2. API returns model IDs like: `models/text-embedding-004`, `models/gemini-embedding-001` 3. `OpenAIMixin.list_models()` (line 410) checks: `if metadata := self.embedding_model_metadata.get(provider_model_id)` 4. If matched, registers as `model_type: "embedding"` with metadata; otherwise registers as `model_type: "llm"` ### Why Both Keys Needed the Prefix The `text-embedding-004` model was already working because there was likely separate configuration or manual registration handling it. For auto-discovery to work correctly for **both** models, both keys must match the API's model ID format exactly. ## How to test this PR Verified the changes by: 1. **Model Auto-Discovery**: Started llama-stack server and confirmed models are auto-discovered from Gemini API 2. **Model Registration**: Confirmed both embedding models are correctly registered and visible ```bash curl http://localhost:8325/v1/models | jq '.data[] | select(.provider_id == "gemini" and .model_type == "embedding")' ``` **Results:** - ✅ `gemini/models/text-embedding-004` - 768 dimensions - `model_type: "embedding"` - ✅ `gemini/models/gemini-embedding-001` - 3072 dimensions - `model_type: "embedding"` 3. **Before Fix (Commit 1)**: Models appeared as `model_type: "llm"` without embedding metadata 4. **After Fix (Commit 2)**: Models correctly identified as `model_type: "embedding"` with proper metadata 5. **Generate Embeddings**: Verified embedding generation works ```bash curl -X POST http://localhost:8325/v1/embeddings \ -H "Content-Type: application/json" \ -d '{"model": "gemini/models/gemini-embedding-001", "input": "test"}' | \ jq '.data[0].embedding | length' ```
1 parent ce8ea2f commit add8cd8

File tree

1 file changed

+62
-1
lines changed
  • llama_stack/providers/remote/inference/gemini

1 file changed

+62
-1
lines changed

llama_stack/providers/remote/inference/gemini/gemini.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7+
from openai import NOT_GIVEN
8+
9+
from llama_stack.apis.inference import (
10+
OpenAIEmbeddingData,
11+
OpenAIEmbeddingsRequestWithExtraBody,
12+
OpenAIEmbeddingsResponse,
13+
OpenAIEmbeddingUsage,
14+
)
715
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
816

917
from .config import GeminiConfig
@@ -14,8 +22,61 @@ class GeminiInferenceAdapter(OpenAIMixin):
1422

1523
provider_data_api_key_field: str = "gemini_api_key"
1624
embedding_model_metadata: dict[str, dict[str, int]] = {
17-
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
25+
"models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
26+
"models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
1827
}
1928

2029
def get_base_url(self):
2130
return "https://generativelanguage.googleapis.com/v1beta/openai/"
31+
32+
async def openai_embeddings(
33+
self,
34+
params: OpenAIEmbeddingsRequestWithExtraBody,
35+
) -> OpenAIEmbeddingsResponse:
36+
"""
37+
Override embeddings method to handle Gemini's missing usage statistics.
38+
Gemini's embedding API doesn't return usage information, so we provide default values.
39+
"""
40+
# Prepare request parameters
41+
request_params = {
42+
"model": await self._get_provider_model_id(params.model),
43+
"input": params.input,
44+
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
45+
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
46+
"user": params.user if params.user is not None else NOT_GIVEN,
47+
}
48+
49+
# Add extra_body if present
50+
extra_body = params.model_extra
51+
if extra_body:
52+
request_params["extra_body"] = extra_body
53+
54+
# Call OpenAI embeddings API with properly typed parameters
55+
response = await self.client.embeddings.create(**request_params)
56+
57+
data = []
58+
for i, embedding_data in enumerate(response.data):
59+
data.append(
60+
OpenAIEmbeddingData(
61+
embedding=embedding_data.embedding,
62+
index=i,
63+
)
64+
)
65+
66+
# Gemini doesn't return usage statistics - use default values
67+
if hasattr(response, "usage") and response.usage:
68+
usage = OpenAIEmbeddingUsage(
69+
prompt_tokens=response.usage.prompt_tokens,
70+
total_tokens=response.usage.total_tokens,
71+
)
72+
else:
73+
usage = OpenAIEmbeddingUsage(
74+
prompt_tokens=0,
75+
total_tokens=0,
76+
)
77+
78+
return OpenAIEmbeddingsResponse(
79+
data=data,
80+
model=params.model,
81+
usage=usage,
82+
)

0 commit comments

Comments
 (0)