From 30f0d387763d338301adf9328eb26256044ddb34 Mon Sep 17 00:00:00 2001 From: Alvin Date: Tue, 17 Feb 2026 03:11:57 +0800 Subject: [PATCH 01/16] feat(clients): add native Gemini client --- src/clients.rs | 3 + src/clients/gemini.rs | 563 ++++++++++++++++++++++++++++++++++++++++++ src/prelude.rs | 2 + 3 files changed, 568 insertions(+) create mode 100644 src/clients/gemini.rs diff --git a/src/clients.rs b/src/clients.rs index 185e8ce..e8f6182 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -7,6 +7,9 @@ pub mod openai_image; #[cfg(feature = "api-clients")] pub mod openai_stt; +#[cfg(feature = "api-clients")] +pub mod gemini; + #[cfg(feature = "realtime-clients")] pub mod openai_realtime; diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs new file mode 100644 index 0000000..e54de30 --- /dev/null +++ b/src/clients/gemini.rs @@ -0,0 +1,563 @@ +use crate::protocol::Tool; +use crate::protocol::*; +use crate::utils::asynchronous::{BoxPlatformSendFuture, BoxPlatformSendStream}; +use crate::utils::sse::parse_sse; +use async_stream::stream; +use reqwest::header::{HeaderMap, HeaderName}; +use serde::{Deserialize, Serialize}; +use std::{ + str::FromStr, + sync::{Arc, RwLock}, +}; +use url::Url; + +#[derive(Clone, Debug)] +struct GeminiClientInner { + url: String, + headers: HeaderMap, + api_key: Option, + client: reqwest::Client, +} + +/// A native Gemini API client using `/models` and `:streamGenerateContent`. +#[derive(Debug)] +pub struct GeminiClient(Arc>); + +impl Clone for GeminiClient { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl GeminiClient { + /// Creates a new Gemini client for the given API base URL. + pub fn new(url: String) -> Self { + let inner = GeminiClientInner { + url, + headers: HeaderMap::new(), + api_key: None, + client: crate::utils::http::default_client(), + }; + Self(Arc::new(RwLock::new(inner))) + } + + /// Sets a custom HTTP header used in all Gemini requests. + pub fn set_header(&mut self, key: &str, value: &str) -> Result<(), &'static str> { + let header_name = HeaderName::from_str(key).map_err(|_| "Invalid header name")?; + let header_value = value.parse().map_err(|_| "Invalid header value")?; + self.0 + .write() + .unwrap() + .headers + .insert(header_name, header_value); + Ok(()) + } + + /// Sets the Gemini API key used for request authentication. + pub fn set_key(&mut self, key: &str) -> Result<(), &'static str> { + self.0.write().unwrap().api_key = Some(key.to_string()); + self.set_header("x-goog-api-key", key) + } +} + +#[derive(Debug, Deserialize)] +struct GeminiModelsResponse { + #[serde(default)] + models: Vec, +} + +#[derive(Debug, Deserialize)] +struct GeminiModel { + name: String, + #[serde(rename = "displayName")] + display_name: Option, + #[serde(rename = "supportedGenerationMethods")] + #[serde(default)] + supported_generation_methods: Vec, +} + +#[derive(Debug, Serialize)] +struct GeminiGenerateRequest { + contents: Vec, + #[serde(rename = "system_instruction")] + #[serde(skip_serializing_if = "Option::is_none")] + system_instruction: Option, +} + +#[derive(Debug, Serialize)] +struct GeminiSystemInstruction { + parts: Vec, +} + +#[derive(Debug, Serialize)] +struct GeminiContent { + role: String, + parts: Vec, +} + +#[derive(Debug, Serialize)] +struct GeminiTextPart { + text: String, +} + +#[derive(Debug, Deserialize)] +struct GeminiStreamEvent { + #[serde(default)] + candidates: Vec, +} + +#[derive(Debug, Deserialize)] +struct GeminiCandidate { + content: Option, +} + +#[derive(Debug, Deserialize)] +struct GeminiCandidateContent { + #[serde(default)] + parts: Vec, +} + +#[derive(Debug, Deserialize)] +struct GeminiStreamPart { + #[serde(default)] + text: String, +} + +fn normalize_model_id(id: &str) -> &str { + id.trim_start_matches("models/") +} + +fn build_endpoint_url( + base_url: &str, + suffix: &str, + api_key: Option<&str>, + extra_query: &[(&str, &str)], +) -> Result { + let mut url = Url::parse(base_url).map_err(|error| { + ClientError::new_with_source( + ClientErrorKind::Format, + format!("Invalid Gemini base URL: {base_url}"), + Some(error), + ) + })?; + + let base_path = url.path().trim_end_matches('/'); + let suffix = suffix.trim_start_matches('/'); + let path = format!("{base_path}/{suffix}"); + url.set_path(&path); + + { + let mut query = url.query_pairs_mut(); + for (key, value) in extra_query { + query.append_pair(key, value); + } + if let Some(api_key) = api_key { + query.append_pair("key", api_key); + } + } + + Ok(url.to_string()) +} + +fn build_models_url(base_url: &str, api_key: Option<&str>) -> Result { + build_endpoint_url(base_url, "models", api_key, &[]) +} + +fn build_stream_url( + base_url: &str, + bot_id: &BotId, + api_key: Option<&str>, +) -> Result { + let model_id = normalize_model_id(bot_id.id()); + let suffix = format!("models/{model_id}:streamGenerateContent"); + build_endpoint_url(base_url, &suffix, api_key, &[("alt", "sse")]) +} + +fn supports_generate_content(model: &GeminiModel) -> bool { + model.supported_generation_methods.is_empty() + || model + .supported_generation_methods + .iter() + .any(|method| method == "generateContent") +} + +fn parse_models_response(payload: &str) -> Result, ClientError> { + let response: GeminiModelsResponse = serde_json::from_str(payload).map_err(|error| { + ClientError::new_with_source( + ClientErrorKind::Format, + "Could not parse Gemini models response.".to_string(), + Some(error), + ) + })?; + + let bots = response + .models + .iter() + .filter(|model| supports_generate_content(model)) + .map(|model| { + let normalized_id = normalize_model_id(&model.name); + let name = model + .display_name + .clone() + .unwrap_or_else(|| normalized_id.to_string()); + + Bot { + id: BotId::new(normalized_id), + name, + avatar: EntityAvatar::from_first_grapheme(&model.name.to_uppercase()) + .unwrap_or_else(|| EntityAvatar::Text("?".into())), + capabilities: BotCapabilities::new().with_capabilities([BotCapability::TextInput]), + } + }) + .collect(); + + Ok(bots) +} + +fn message_text(message: &Message) -> String { + if !message.content.text.is_empty() { + return message.content.text.clone(); + } + + if message.content.tool_results.is_empty() { + return String::new(); + } + + message + .content + .tool_results + .iter() + .map(|result| result.content.clone()) + .collect::>() + .join("\n") +} + +fn build_generate_request(messages: &[Message]) -> Result { + let mut contents = Vec::with_capacity(messages.len()); + let mut system_blocks: Vec = Vec::new(); + + for message in messages { + let text = message_text(message); + if text.is_empty() { + continue; + } + + match &message.from { + EntityId::User | EntityId::Tool => contents.push(GeminiContent { + role: "user".to_string(), + parts: vec![GeminiTextPart { text }], + }), + EntityId::System => system_blocks.push(text), + EntityId::Bot(_) => contents.push(GeminiContent { + role: "model".to_string(), + parts: vec![GeminiTextPart { text }], + }), + EntityId::App => { + return Err(ClientError::new( + ClientErrorKind::Format, + "App messages cannot be sent to Gemini.".to_string(), + )); + } + } + } + + if contents.is_empty() { + return Err(ClientError::new( + ClientErrorKind::Format, + "Gemini request has no conversation content.".to_string(), + )); + } + + let system_instruction = if system_blocks.is_empty() { + None + } else { + Some(GeminiSystemInstruction { + parts: vec![GeminiTextPart { + text: system_blocks.join("\n\n"), + }], + }) + }; + + Ok(GeminiGenerateRequest { + contents, + system_instruction, + }) +} + +fn parse_stream_text(payload: &str) -> Result { + let event: GeminiStreamEvent = serde_json::from_str(payload).map_err(|error| { + ClientError::new_with_source( + ClientErrorKind::Format, + "Could not parse Gemini stream event.".to_string(), + Some(error), + ) + })?; + + let text = event + .candidates + .iter() + .filter_map(|candidate| candidate.content.as_ref()) + .flat_map(|content| content.parts.iter()) + .map(|part| part.text.as_str()) + .collect::>() + .join(""); + + Ok(text) +} + +impl BotClient for GeminiClient { + fn bots(&mut self) -> BoxPlatformSendFuture<'static, ClientResult>> { + let inner = self.0.read().unwrap().clone(); + + Box::pin(async move { + let url = match build_models_url(&inner.url, inner.api_key.as_deref()) { + Ok(url) => url, + Err(error) => return error.into(), + }; + + let response = match inner.client.get(&url).headers(inner.headers).send().await { + Ok(response) => response, + Err(error) => { + return ClientError::new_with_source( + ClientErrorKind::Network, + format!( + "Could not send request to {url}. Verify your connection and key." + ), + Some(error), + ) + .into(); + } + }; + + if !response.status().is_success() { + let status = response.status(); + let details = response.text().await.unwrap_or_default(); + return ClientError::new( + ClientErrorKind::Response, + format!("Gemini models request failed with status {status}."), + ) + .with_details(details) + .into(); + } + + let payload = match response.text().await { + Ok(text) => text, + Err(error) => { + return ClientError::new_with_source( + ClientErrorKind::Format, + format!("Could not read Gemini models response from {url}."), + Some(error), + ) + .into(); + } + }; + + parse_models_response(&payload).into() + }) + } + + fn send( + &mut self, + bot_id: &BotId, + messages: &[Message], + _tools: &[Tool], + ) -> BoxPlatformSendStream<'static, ClientResult> { + let inner = self.0.read().unwrap().clone(); + let bot_id = bot_id.clone(); + let messages = messages.to_vec(); + + let stream = stream! { + let url = match build_stream_url(&inner.url, &bot_id, inner.api_key.as_deref()) { + Ok(url) => url, + Err(error) => { + yield error.into(); + return; + } + }; + + let request = match build_generate_request(&messages) { + Ok(request) => request, + Err(error) => { + yield error.into(); + return; + } + }; + + let response = match inner + .client + .post(&url) + .headers(inner.headers) + .json(&request) + .send() + .await + { + Ok(response) => response, + Err(error) => { + yield ClientError::new_with_source( + ClientErrorKind::Network, + format!( + "Could not send request to {url}. Verify your connection and key." + ), + Some(error), + ).into(); + return; + } + }; + + if !response.status().is_success() { + let status = response.status(); + let details = response.text().await.unwrap_or_default(); + yield ClientError::new( + ClientErrorKind::Response, + format!("Gemini streaming request failed with status {status}."), + ).with_details(details).into(); + return; + } + + let mut content = MessageContent::default(); + let mut full_text = String::new(); + let events = parse_sse(response.bytes_stream()); + + for await event in events { + let event = match event { + Ok(event) => event, + Err(error) => { + yield ClientError::new_with_source( + ClientErrorKind::Network, + format!("Gemini response stream from {url} was interrupted."), + Some(error), + ).into(); + return; + } + }; + + let chunk = match parse_stream_text(&event) { + Ok(chunk) => chunk, + Err(error) => { + yield error.into(); + return; + } + }; + + if chunk.is_empty() { + continue; + } + + full_text.push_str(&chunk); + content.text = full_text.clone(); + yield ClientResult::new_ok(content.clone()); + } + }; + + Box::pin(stream) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_models_response_prefers_display_name() { + let payload = r#" + { + "models": [ + { + "name": "models/gemini-2.0-flash", + "displayName": "Gemini 2.0 Flash", + "supportedGenerationMethods": ["generateContent"] + } + ] + }"#; + + let bots = parse_models_response(payload).expect("failed to parse models response"); + let bot = bots.first().expect("expected one bot"); + + assert_eq!(bot.id.id(), "gemini-2.0-flash"); + assert_eq!(bot.name, "Gemini 2.0 Flash"); + } + + #[test] + fn models_url_appends_key_and_preserves_existing_query() { + let url = build_models_url( + "https://generativelanguage.googleapis.com/v1beta?alt=sse", + Some("test-key"), + ) + .expect("failed to build models url"); + + assert!(url.contains("/models?")); + assert!(url.contains("alt=sse")); + assert!(url.contains("key=test-key")); + } + + #[test] + fn stream_url_uses_stream_generate_content() { + let url = build_stream_url( + "https://generativelanguage.googleapis.com/v1beta", + &BotId::new("models/gemini-2.0-flash"), + Some("test-key"), + ) + .expect("failed to build stream url"); + + assert!(url.contains("/models/gemini-2.0-flash:streamGenerateContent")); + assert!(url.contains("alt=sse")); + assert!(url.contains("key=test-key")); + } + + #[test] + fn build_generate_request_maps_system_user_and_model_roles() { + let messages = vec![ + Message { + from: EntityId::System, + content: MessageContent { + text: "You are helpful.".to_string(), + ..Default::default() + }, + ..Default::default() + }, + Message { + from: EntityId::User, + content: MessageContent { + text: "Hi".to_string(), + ..Default::default() + }, + ..Default::default() + }, + Message { + from: EntityId::Bot(BotId::new("gemini-2.0-flash")), + content: MessageContent { + text: "Hello".to_string(), + ..Default::default() + }, + ..Default::default() + }, + ]; + + let request = build_generate_request(&messages).expect("failed to build request"); + + assert_eq!(request.contents.len(), 2); + assert_eq!(request.contents[0].role, "user"); + assert_eq!(request.contents[1].role, "model"); + assert_eq!( + request.system_instruction.expect("missing system instruction").parts[0].text, + "You are helpful." + ); + } + + #[test] + fn parse_stream_text_collects_all_candidate_parts() { + let payload = r#" + { + "candidates": [ + { "content": { "parts": [{"text":"Hello "}, {"text":"Gemini"}] } } + ] + }"#; + + let text = parse_stream_text(payload).expect("failed to parse stream payload"); + assert_eq!(text, "Hello Gemini"); + } +} diff --git a/src/prelude.rs b/src/prelude.rs index b53b5f4..4cbc028 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -6,6 +6,8 @@ pub use crate::protocol::*; // These are the clients that are most commonly used. #[cfg(feature = "api-clients")] pub use crate::clients::openai::OpenAiClient; +#[cfg(feature = "api-clients")] +pub use crate::clients::gemini::GeminiClient; pub use crate::clients::router::RouterClient; // These other clients are less commonly used. From c45756e23efaab2e53943306cd80c507ae3f1f58 Mon Sep 17 00:00:00 2001 From: Alvin Date: Tue, 17 Feb 2026 03:32:03 +0800 Subject: [PATCH 02/16] fix(gemini): remove duplicate Tool import --- src/clients/gemini.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index e54de30..9d72bbb 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -1,4 +1,3 @@ -use crate::protocol::Tool; use crate::protocol::*; use crate::utils::asynchronous::{BoxPlatformSendFuture, BoxPlatformSendStream}; use crate::utils::sse::parse_sse; From df68b6c3f9a207d5c5c43e345ec293c250f62b55 Mon Sep 17 00:00:00 2001 From: Alvin Date: Tue, 17 Feb 2026 03:34:19 +0800 Subject: [PATCH 03/16] fix(gemini): authenticate via header only, remove key from query params --- src/clients/gemini.rs | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index 9d72bbb..c72cecc 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -14,7 +14,6 @@ use url::Url; struct GeminiClientInner { url: String, headers: HeaderMap, - api_key: Option, client: reqwest::Client, } @@ -34,7 +33,6 @@ impl GeminiClient { let inner = GeminiClientInner { url, headers: HeaderMap::new(), - api_key: None, client: crate::utils::http::default_client(), }; Self(Arc::new(RwLock::new(inner))) @@ -54,7 +52,6 @@ impl GeminiClient { /// Sets the Gemini API key used for request authentication. pub fn set_key(&mut self, key: &str) -> Result<(), &'static str> { - self.0.write().unwrap().api_key = Some(key.to_string()); self.set_header("x-goog-api-key", key) } } @@ -129,7 +126,6 @@ fn normalize_model_id(id: &str) -> &str { fn build_endpoint_url( base_url: &str, suffix: &str, - api_key: Option<&str>, extra_query: &[(&str, &str)], ) -> Result { let mut url = Url::parse(base_url).map_err(|error| { @@ -150,26 +146,22 @@ fn build_endpoint_url( for (key, value) in extra_query { query.append_pair(key, value); } - if let Some(api_key) = api_key { - query.append_pair("key", api_key); - } } Ok(url.to_string()) } -fn build_models_url(base_url: &str, api_key: Option<&str>) -> Result { - build_endpoint_url(base_url, "models", api_key, &[]) +fn build_models_url(base_url: &str) -> Result { + build_endpoint_url(base_url, "models", &[]) } fn build_stream_url( base_url: &str, bot_id: &BotId, - api_key: Option<&str>, ) -> Result { let model_id = normalize_model_id(bot_id.id()); let suffix = format!("models/{model_id}:streamGenerateContent"); - build_endpoint_url(base_url, &suffix, api_key, &[("alt", "sse")]) + build_endpoint_url(base_url, &suffix, &[("alt", "sse")]) } fn supports_generate_content(model: &GeminiModel) -> bool { @@ -309,7 +301,7 @@ impl BotClient for GeminiClient { let inner = self.0.read().unwrap().clone(); Box::pin(async move { - let url = match build_models_url(&inner.url, inner.api_key.as_deref()) { + let url = match build_models_url(&inner.url) { Ok(url) => url, Err(error) => return error.into(), }; @@ -366,7 +358,7 @@ impl BotClient for GeminiClient { let messages = messages.to_vec(); let stream = stream! { - let url = match build_stream_url(&inner.url, &bot_id, inner.api_key.as_deref()) { + let url = match build_stream_url(&inner.url, &bot_id) { Ok(url) => url, Err(error) => { yield error.into(); @@ -481,16 +473,14 @@ mod tests { } #[test] - fn models_url_appends_key_and_preserves_existing_query() { + fn models_url_preserves_existing_query() { let url = build_models_url( "https://generativelanguage.googleapis.com/v1beta?alt=sse", - Some("test-key"), ) .expect("failed to build models url"); assert!(url.contains("/models?")); assert!(url.contains("alt=sse")); - assert!(url.contains("key=test-key")); } #[test] @@ -498,13 +488,11 @@ mod tests { let url = build_stream_url( "https://generativelanguage.googleapis.com/v1beta", &BotId::new("models/gemini-2.0-flash"), - Some("test-key"), ) .expect("failed to build stream url"); assert!(url.contains("/models/gemini-2.0-flash:streamGenerateContent")); assert!(url.contains("alt=sse")); - assert!(url.contains("key=test-key")); } #[test] From b4288fe8a14d9bc884ac2dda8ee5c7c9acbb74a5 Mon Sep 17 00:00:00 2001 From: Alvin Date: Tue, 17 Feb 2026 03:36:21 +0800 Subject: [PATCH 04/16] feat(gemini): derive model capabilities from supportedGenerationMethods --- src/clients/gemini.rs | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index c72cecc..e71edb1 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -172,6 +172,18 @@ fn supports_generate_content(model: &GeminiModel) -> bool { .any(|method| method == "generateContent") } +fn derive_capabilities(model: &GeminiModel) -> BotCapabilities { + let mut caps = vec![BotCapability::TextInput]; + if model + .supported_generation_methods + .iter() + .any(|m| m == "generateContent") + { + caps.push(BotCapability::ToolInput); + } + BotCapabilities::new().with_capabilities(caps) +} + fn parse_models_response(payload: &str) -> Result, ClientError> { let response: GeminiModelsResponse = serde_json::from_str(payload).map_err(|error| { ClientError::new_with_source( @@ -197,7 +209,7 @@ fn parse_models_response(payload: &str) -> Result, ClientError> { name, avatar: EntityAvatar::from_first_grapheme(&model.name.to_uppercase()) .unwrap_or_else(|| EntityAvatar::Text("?".into())), - capabilities: BotCapabilities::new().with_capabilities([BotCapability::TextInput]), + capabilities: derive_capabilities(model), } }) .collect(); @@ -535,6 +547,30 @@ mod tests { ); } + #[test] + fn parse_models_response_maps_capabilities_from_generation_methods() { + let payload = r#" + { + "models": [ + { + "name": "models/gemini-2.0-flash", + "supportedGenerationMethods": ["generateContent"] + }, + { + "name": "models/text-embedding-004", + "supportedGenerationMethods": ["embedContent"] + } + ] + }"#; + + let bots = parse_models_response(payload).expect("failed to parse"); + assert_eq!(bots.len(), 1, "embedding model should be filtered out"); + + let bot = &bots[0]; + assert!(bot.capabilities.has_capability(&BotCapability::TextInput)); + assert!(bot.capabilities.has_capability(&BotCapability::ToolInput)); + } + #[test] fn parse_stream_text_collects_all_candidate_parts() { let payload = r#" From 8caca514414f447834381df1de46bb7c1dfb19c9 Mon Sep 17 00:00:00 2001 From: Alvin Date: Tue, 17 Feb 2026 03:39:09 +0800 Subject: [PATCH 05/16] feat(gemini): handle model list pagination via nextPageToken --- src/clients/gemini.rs | 168 +++++++++++++++++++++++++++++++++--------- 1 file changed, 132 insertions(+), 36 deletions(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index e71edb1..22db19e 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -60,6 +60,8 @@ impl GeminiClient { struct GeminiModelsResponse { #[serde(default)] models: Vec, + #[serde(rename = "nextPageToken")] + next_page_token: Option, } #[derive(Debug, Deserialize)] @@ -151,8 +153,16 @@ fn build_endpoint_url( Ok(url.to_string()) } -fn build_models_url(base_url: &str) -> Result { - build_endpoint_url(base_url, "models", &[]) +fn build_models_url( + base_url: &str, + page_token: Option<&str>, +) -> Result { + match page_token { + Some(token) => { + build_endpoint_url(base_url, "models", &[("pageToken", token)]) + } + None => build_endpoint_url(base_url, "models", &[]), + } } fn build_stream_url( @@ -310,52 +320,126 @@ fn parse_stream_text(payload: &str) -> Result { impl BotClient for GeminiClient { fn bots(&mut self) -> BoxPlatformSendFuture<'static, ClientResult>> { - let inner = self.0.read().unwrap().clone(); + let inner = self + .0 + .read() + .expect("gemini client lock poisoned") + .clone(); Box::pin(async move { - let url = match build_models_url(&inner.url) { - Ok(url) => url, - Err(error) => return error.into(), - }; + let mut all_bots = Vec::new(); + let mut page_token: Option = None; + + loop { + let url = match build_models_url( + &inner.url, + page_token.as_deref(), + ) { + Ok(url) => url, + Err(error) => return error.into(), + }; - let response = match inner.client.get(&url).headers(inner.headers).send().await { - Ok(response) => response, - Err(error) => { - return ClientError::new_with_source( - ClientErrorKind::Network, + let response = match inner + .client + .get(&url) + .headers(inner.headers.clone()) + .send() + .await + { + Ok(response) => response, + Err(error) => { + return ClientError::new_with_source( + ClientErrorKind::Network, + format!( + "Could not send request to {url}. \ + Verify your connection and key." + ), + Some(error), + ) + .into(); + } + }; + + if !response.status().is_success() { + let status = response.status(); + let details = + response.text().await.unwrap_or_default(); + return ClientError::new( + ClientErrorKind::Response, format!( - "Could not send request to {url}. Verify your connection and key." + "Gemini models request failed \ + with status {status}." ), - Some(error), ) + .with_details(details) .into(); } - }; - if !response.status().is_success() { - let status = response.status(); - let details = response.text().await.unwrap_or_default(); - return ClientError::new( - ClientErrorKind::Response, - format!("Gemini models request failed with status {status}."), - ) - .with_details(details) - .into(); - } + let payload = match response.text().await { + Ok(text) => text, + Err(error) => { + return ClientError::new_with_source( + ClientErrorKind::Format, + format!( + "Could not read Gemini models \ + response from {url}." + ), + Some(error), + ) + .into(); + } + }; - let payload = match response.text().await { - Ok(text) => text, - Err(error) => { - return ClientError::new_with_source( - ClientErrorKind::Format, - format!("Could not read Gemini models response from {url}."), - Some(error), - ) - .into(); + let parsed: GeminiModelsResponse = + match serde_json::from_str(&payload) { + Ok(r) => r, + Err(error) => { + return ClientError::new_with_source( + ClientErrorKind::Format, + "Could not parse Gemini models \ + response." + .to_string(), + Some(error), + ) + .into(); + } + }; + + let bots = parsed + .models + .iter() + .filter(|m| supports_generate_content(m)) + .map(|model| { + let id = normalize_model_id(&model.name); + let name = model + .display_name + .clone() + .unwrap_or_else(|| id.to_string()); + Bot { + id: BotId::new(id), + name, + avatar: EntityAvatar::from_first_grapheme( + &model.name.to_uppercase(), + ) + .unwrap_or_else(|| { + EntityAvatar::Text("?".into()) + }), + capabilities: derive_capabilities(model), + } + }) + .collect::>(); + + all_bots.extend(bots); + + match parsed.next_page_token { + Some(token) if !token.is_empty() => { + page_token = Some(token); + } + _ => break, } - }; + } - parse_models_response(&payload).into() + ClientResult::new_ok(all_bots) }) } @@ -488,6 +572,7 @@ mod tests { fn models_url_preserves_existing_query() { let url = build_models_url( "https://generativelanguage.googleapis.com/v1beta?alt=sse", + None, ) .expect("failed to build models url"); @@ -495,6 +580,17 @@ mod tests { assert!(url.contains("alt=sse")); } + #[test] + fn models_url_includes_page_token() { + let url = build_models_url( + "https://generativelanguage.googleapis.com/v1beta", + Some("abc123"), + ) + .expect("failed to build models url"); + + assert!(url.contains("pageToken=abc123")); + } + #[test] fn stream_url_uses_stream_generate_content() { let url = build_stream_url( From e1a35e489f9f9c7b279a98506c8333497a498d84 Mon Sep 17 00:00:00 2001 From: Alvin Date: Tue, 17 Feb 2026 03:42:03 +0800 Subject: [PATCH 06/16] refactor(gemini): streaming perf, unwrap cleanup, tools TODO - Avoid cloning MessageContent on each stream chunk (create fresh per yield) - Replace .unwrap() with .expect() on RwLock operations - Add TODO for Gemini function calling / tools support - Mark parse_models_response as #[cfg(test)] (only used in tests now) --- src/clients/gemini.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index 22db19e..dc27ba9 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -44,7 +44,7 @@ impl GeminiClient { let header_value = value.parse().map_err(|_| "Invalid header value")?; self.0 .write() - .unwrap() + .expect("gemini client lock poisoned") .headers .insert(header_name, header_value); Ok(()) @@ -194,6 +194,7 @@ fn derive_capabilities(model: &GeminiModel) -> BotCapabilities { BotCapabilities::new().with_capabilities(caps) } +#[cfg(test)] fn parse_models_response(payload: &str) -> Result, ClientError> { let response: GeminiModelsResponse = serde_json::from_str(payload).map_err(|error| { ClientError::new_with_source( @@ -449,7 +450,9 @@ impl BotClient for GeminiClient { messages: &[Message], _tools: &[Tool], ) -> BoxPlatformSendStream<'static, ClientResult> { - let inner = self.0.read().unwrap().clone(); + // TODO: Gemini supports function calling — convert `_tools` to + // Gemini `tools` / `function_declarations` and include in request. + let inner = self.0.read().expect("gemini client lock poisoned").clone(); let bot_id = bot_id.clone(); let messages = messages.to_vec(); @@ -501,7 +504,6 @@ impl BotClient for GeminiClient { return; } - let mut content = MessageContent::default(); let mut full_text = String::new(); let events = parse_sse(response.bytes_stream()); @@ -531,8 +533,9 @@ impl BotClient for GeminiClient { } full_text.push_str(&chunk); + let mut content = MessageContent::default(); content.text = full_text.clone(); - yield ClientResult::new_ok(content.clone()); + yield ClientResult::new_ok(content); } }; From 7d14028deac91b7672012b7737d1e097927c78f5 Mon Sep 17 00:00:00 2001 From: Alvin Date: Tue, 17 Feb 2026 11:47:48 +0800 Subject: [PATCH 07/16] fix(gemini): align capability signal and resource URL path --- src/clients/gemini.rs | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index dc27ba9..3b92cc5 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -169,8 +169,13 @@ fn build_stream_url( base_url: &str, bot_id: &BotId, ) -> Result { - let model_id = normalize_model_id(bot_id.id()); - let suffix = format!("models/{model_id}:streamGenerateContent"); + let model_id = bot_id.id(); + let model_path = if model_id.contains('/') { + model_id.to_string() + } else { + format!("models/{}", normalize_model_id(model_id)) + }; + let suffix = format!("{model_path}:streamGenerateContent"); build_endpoint_url(base_url, &suffix, &[("alt", "sse")]) } @@ -182,16 +187,8 @@ fn supports_generate_content(model: &GeminiModel) -> bool { .any(|method| method == "generateContent") } -fn derive_capabilities(model: &GeminiModel) -> BotCapabilities { - let mut caps = vec![BotCapability::TextInput]; - if model - .supported_generation_methods - .iter() - .any(|m| m == "generateContent") - { - caps.push(BotCapability::ToolInput); - } - BotCapabilities::new().with_capabilities(caps) +fn derive_capabilities(_model: &GeminiModel) -> BotCapabilities { + BotCapabilities::new().with_capabilities([BotCapability::TextInput]) } #[cfg(test)] @@ -606,6 +603,18 @@ mod tests { assert!(url.contains("alt=sse")); } + #[test] + fn stream_url_keeps_qualified_resource_path() { + let url = build_stream_url( + "https://generativelanguage.googleapis.com/v1beta", + &BotId::new("tunedModels/my-tuned-model"), + ) + .expect("failed to build stream url"); + + assert!(url.contains("/tunedModels/my-tuned-model:streamGenerateContent")); + assert!(!url.contains("/models/tunedModels/my-tuned-model:streamGenerateContent")); + } + #[test] fn build_generate_request_maps_system_user_and_model_roles() { let messages = vec![ @@ -667,7 +676,7 @@ mod tests { let bot = &bots[0]; assert!(bot.capabilities.has_capability(&BotCapability::TextInput)); - assert!(bot.capabilities.has_capability(&BotCapability::ToolInput)); + assert!(!bot.capabilities.has_capability(&BotCapability::ToolInput)); } #[test] From 1f81dcaa7265f17b2f92052be0a32151373306a0 Mon Sep 17 00:00:00 2001 From: Alvin Date: Tue, 17 Feb 2026 12:37:44 +0800 Subject: [PATCH 08/16] refactor(gemini): deduplicate mapping and tighten client ergonomics --- src/clients/gemini.rs | 72 +++++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 43 deletions(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index 3b92cc5..3d9894e 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -1,3 +1,5 @@ +//! Native Gemini API client implementation. + use crate::protocol::*; use crate::utils::asynchronous::{BoxPlatformSendFuture, BoxPlatformSendStream}; use crate::utils::sse::parse_sse; @@ -187,10 +189,30 @@ fn supports_generate_content(model: &GeminiModel) -> bool { .any(|method| method == "generateContent") } -fn derive_capabilities(_model: &GeminiModel) -> BotCapabilities { +fn derive_capabilities() -> BotCapabilities { BotCapabilities::new().with_capabilities([BotCapability::TextInput]) } +fn gemini_model_to_bot(model: &GeminiModel) -> Option { + if !supports_generate_content(model) { + return None; + } + + let normalized_id = normalize_model_id(&model.name); + let name = model + .display_name + .clone() + .unwrap_or_else(|| normalized_id.to_string()); + + Some(Bot { + id: BotId::new(normalized_id), + name, + avatar: EntityAvatar::from_first_grapheme(&model.name.to_uppercase()) + .unwrap_or_else(|| EntityAvatar::Text("?".into())), + capabilities: derive_capabilities(), + }) +} + #[cfg(test)] fn parse_models_response(payload: &str) -> Result, ClientError> { let response: GeminiModelsResponse = serde_json::from_str(payload).map_err(|error| { @@ -201,27 +223,7 @@ fn parse_models_response(payload: &str) -> Result, ClientError> { ) })?; - let bots = response - .models - .iter() - .filter(|model| supports_generate_content(model)) - .map(|model| { - let normalized_id = normalize_model_id(&model.name); - let name = model - .display_name - .clone() - .unwrap_or_else(|| normalized_id.to_string()); - - Bot { - id: BotId::new(normalized_id), - name, - avatar: EntityAvatar::from_first_grapheme(&model.name.to_uppercase()) - .unwrap_or_else(|| EntityAvatar::Text("?".into())), - capabilities: derive_capabilities(model), - } - }) - .collect(); - + let bots = response.models.iter().filter_map(gemini_model_to_bot).collect(); Ok(bots) } @@ -406,25 +408,7 @@ impl BotClient for GeminiClient { let bots = parsed .models .iter() - .filter(|m| supports_generate_content(m)) - .map(|model| { - let id = normalize_model_id(&model.name); - let name = model - .display_name - .clone() - .unwrap_or_else(|| id.to_string()); - Bot { - id: BotId::new(id), - name, - avatar: EntityAvatar::from_first_grapheme( - &model.name.to_uppercase(), - ) - .unwrap_or_else(|| { - EntityAvatar::Text("?".into()) - }), - capabilities: derive_capabilities(model), - } - }) + .filter_map(gemini_model_to_bot) .collect::>(); all_bots.extend(bots); @@ -530,8 +514,10 @@ impl BotClient for GeminiClient { } full_text.push_str(&chunk); - let mut content = MessageContent::default(); - content.text = full_text.clone(); + let content = MessageContent { + text: full_text.clone(), + ..Default::default() + }; yield ClientResult::new_ok(content); } }; From 3d5fdd89ffd3f24ed563edfde9adc0b412b5c7d4 Mon Sep 17 00:00:00 2001 From: Alvin Date: Wed, 4 Mar 2026 11:37:03 +0800 Subject: [PATCH 09/16] feat(gemini): add native tool-calling request and stream parsing --- src/clients/gemini.rs | 495 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 410 insertions(+), 85 deletions(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index 3d9894e..c68c3b9 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -6,7 +6,9 @@ use crate::utils::sse::parse_sse; use async_stream::stream; use reqwest::header::{HeaderMap, HeaderName}; use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; use std::{ + collections::{BTreeMap, HashMap}, str::FromStr, sync::{Arc, RwLock}, }; @@ -82,6 +84,22 @@ struct GeminiGenerateRequest { #[serde(rename = "system_instruction")] #[serde(skip_serializing_if = "Option::is_none")] system_instruction: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, +} + +#[derive(Debug, Serialize)] +struct GeminiToolDeclarations { + #[serde(rename = "function_declarations")] + function_declarations: Vec, +} + +#[derive(Debug, Serialize)] +struct GeminiFunctionDeclaration { + name: String, + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + parameters: Value, } #[derive(Debug, Serialize)] @@ -92,7 +110,7 @@ struct GeminiSystemInstruction { #[derive(Debug, Serialize)] struct GeminiContent { role: String, - parts: Vec, + parts: Vec, } #[derive(Debug, Serialize)] @@ -100,6 +118,34 @@ struct GeminiTextPart { text: String, } +#[derive(Debug, Serialize)] +#[serde(untagged)] +enum GeminiOutgoingPart { + Text(GeminiTextPart), + FunctionCall { + #[serde(rename = "functionCall")] + function_call: GeminiFunctionCall, + }, + FunctionResponse { + #[serde(rename = "functionResponse")] + function_response: GeminiFunctionResponse, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct GeminiFunctionCall { + #[serde(default)] + name: String, + #[serde(default)] + args: Value, +} + +#[derive(Debug, Serialize)] +struct GeminiFunctionResponse { + name: String, + response: Value, +} + #[derive(Debug, Deserialize)] struct GeminiStreamEvent { #[serde(default)] @@ -121,6 +167,14 @@ struct GeminiCandidateContent { struct GeminiStreamPart { #[serde(default)] text: String, + #[serde(rename = "functionCall")] + function_call: Option, +} + +#[derive(Debug, Default)] +struct GeminiStreamDelta { + text: String, + function_calls: Vec, } fn normalize_model_id(id: &str) -> &str { @@ -155,22 +209,14 @@ fn build_endpoint_url( Ok(url.to_string()) } -fn build_models_url( - base_url: &str, - page_token: Option<&str>, -) -> Result { +fn build_models_url(base_url: &str, page_token: Option<&str>) -> Result { match page_token { - Some(token) => { - build_endpoint_url(base_url, "models", &[("pageToken", token)]) - } + Some(token) => build_endpoint_url(base_url, "models", &[("pageToken", token)]), None => build_endpoint_url(base_url, "models", &[]), } } -fn build_stream_url( - base_url: &str, - bot_id: &BotId, -) -> Result { +fn build_stream_url(base_url: &str, bot_id: &BotId) -> Result { let model_id = bot_id.id(); let model_path = if model_id.contains('/') { model_id.to_string() @@ -190,7 +236,7 @@ fn supports_generate_content(model: &GeminiModel) -> bool { } fn derive_capabilities() -> BotCapabilities { - BotCapabilities::new().with_capabilities([BotCapability::TextInput]) + BotCapabilities::new().with_capabilities([BotCapability::TextInput, BotCapability::ToolInput]) } fn gemini_model_to_bot(model: &GeminiModel) -> Option { @@ -223,48 +269,166 @@ fn parse_models_response(payload: &str) -> Result, ClientError> { ) })?; - let bots = response.models.iter().filter_map(gemini_model_to_bot).collect(); + let bots = response + .models + .iter() + .filter_map(gemini_model_to_bot) + .collect(); Ok(bots) } -fn message_text(message: &Message) -> String { - if !message.content.text.is_empty() { - return message.content.text.clone(); +fn as_tool_parameters(schema: &Map) -> Value { + if schema.is_empty() { + return serde_json::json!({ + "type": "object", + "properties": {} + }); } + Value::Object(schema.clone()) +} - if message.content.tool_results.is_empty() { - return String::new(); +fn as_gemini_tools(tools: &[Tool]) -> Option> { + if tools.is_empty() { + return None; } - message - .content - .tool_results + let function_declarations = tools .iter() - .map(|result| result.content.clone()) - .collect::>() - .join("\n") + .map(|tool| GeminiFunctionDeclaration { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: as_tool_parameters(&tool.input_schema), + }) + .collect::>(); + + Some(vec![GeminiToolDeclarations { + function_declarations, + }]) } -fn build_generate_request(messages: &[Message]) -> Result { +fn collect_tool_call_names(messages: &[Message]) -> HashMap { + let mut names = HashMap::new(); + for message in messages { + for call in &message.content.tool_calls { + names.insert(call.id.clone(), call.name.clone()); + } + } + names +} + +fn parse_tool_result_payload(result: &ToolResult) -> Value { + match serde_json::from_str::(&result.content) { + Ok(Value::Object(mut object)) => { + if result.is_error && !object.contains_key("is_error") { + object.insert("is_error".to_string(), Value::Bool(true)); + } + Value::Object(object) + } + Ok(other) => serde_json::json!({ + "content": other, + "is_error": result.is_error, + }), + Err(_) => serde_json::json!({ + "content": result.content, + "is_error": result.is_error, + }), + } +} + +fn as_bot_parts(message: &Message) -> Vec { + let mut parts = Vec::new(); + + if !message.content.text.is_empty() { + parts.push(GeminiOutgoingPart::Text(GeminiTextPart { + text: message.content.text.clone(), + })); + } + + for call in &message.content.tool_calls { + parts.push(GeminiOutgoingPart::FunctionCall { + function_call: GeminiFunctionCall { + name: call.name.clone(), + args: Value::Object(call.arguments.clone()), + }, + }); + } + + parts +} + +fn as_tool_parts( + message: &Message, + tool_call_names: &HashMap, +) -> Vec { + let mut parts = Vec::new(); + + for result in &message.content.tool_results { + if let Some(name) = tool_call_names.get(&result.tool_call_id) { + parts.push(GeminiOutgoingPart::FunctionResponse { + function_response: GeminiFunctionResponse { + name: name.clone(), + response: parse_tool_result_payload(result), + }, + }); + } else if !result.content.is_empty() { + parts.push(GeminiOutgoingPart::Text(GeminiTextPart { + text: result.content.clone(), + })); + } + } + + if parts.is_empty() && !message.content.text.is_empty() { + parts.push(GeminiOutgoingPart::Text(GeminiTextPart { + text: message.content.text.clone(), + })); + } + + parts +} + +fn build_generate_request( + messages: &[Message], + tools: &[Tool], +) -> Result { let mut contents = Vec::with_capacity(messages.len()); let mut system_blocks: Vec = Vec::new(); + let tool_call_names = collect_tool_call_names(messages); for message in messages { - let text = message_text(message); - if text.is_empty() { - continue; - } - match &message.from { - EntityId::User | EntityId::Tool => contents.push(GeminiContent { - role: "user".to_string(), - parts: vec![GeminiTextPart { text }], - }), - EntityId::System => system_blocks.push(text), - EntityId::Bot(_) => contents.push(GeminiContent { - role: "model".to_string(), - parts: vec![GeminiTextPart { text }], - }), + EntityId::User => { + if !message.content.text.is_empty() { + contents.push(GeminiContent { + role: "user".to_string(), + parts: vec![GeminiOutgoingPart::Text(GeminiTextPart { + text: message.content.text.clone(), + })], + }); + } + } + EntityId::Tool => { + let parts = as_tool_parts(message, &tool_call_names); + if !parts.is_empty() { + contents.push(GeminiContent { + role: "user".to_string(), + parts, + }); + } + } + EntityId::System => { + if !message.content.text.is_empty() { + system_blocks.push(message.content.text.clone()); + } + } + EntityId::Bot(_) => { + let parts = as_bot_parts(message); + if !parts.is_empty() { + contents.push(GeminiContent { + role: "model".to_string(), + parts, + }); + } + } EntityId::App => { return Err(ClientError::new( ClientErrorKind::Format, @@ -294,10 +458,11 @@ fn build_generate_request(messages: &[Message]) -> Result Result { +fn parse_stream_delta(payload: &str) -> Result { let event: GeminiStreamEvent = serde_json::from_str(payload).map_err(|error| { ClientError::new_with_source( ClientErrorKind::Format, @@ -306,35 +471,53 @@ fn parse_stream_text(payload: &str) -> Result { ) })?; - let text = event - .candidates - .iter() - .filter_map(|candidate| candidate.content.as_ref()) - .flat_map(|content| content.parts.iter()) - .map(|part| part.text.as_str()) - .collect::>() - .join(""); + let mut delta = GeminiStreamDelta::default(); + + for candidate in event.candidates { + if let Some(content) = candidate.content { + for part in content.parts { + if !part.text.is_empty() { + delta.text.push_str(&part.text); + } + if let Some(function_call) = part.function_call { + if !function_call.name.is_empty() { + delta.function_calls.push(function_call); + } + } + } + } + } - Ok(text) + Ok(delta) +} + +#[cfg(test)] +fn parse_stream_text(payload: &str) -> Result { + Ok(parse_stream_delta(payload)?.text) +} + +fn function_call_args_to_map(args: Value) -> Map { + match args { + Value::Object(args) => args, + Value::Null => Map::new(), + other => { + let mut arguments = Map::new(); + arguments.insert("value".to_string(), other); + arguments + } + } } impl BotClient for GeminiClient { fn bots(&mut self) -> BoxPlatformSendFuture<'static, ClientResult>> { - let inner = self - .0 - .read() - .expect("gemini client lock poisoned") - .clone(); + let inner = self.0.read().expect("gemini client lock poisoned").clone(); Box::pin(async move { let mut all_bots = Vec::new(); let mut page_token: Option = None; loop { - let url = match build_models_url( - &inner.url, - page_token.as_deref(), - ) { + let url = match build_models_url(&inner.url, page_token.as_deref()) { Ok(url) => url, Err(error) => return error.into(), }; @@ -362,8 +545,7 @@ impl BotClient for GeminiClient { if !response.status().is_success() { let status = response.status(); - let details = - response.text().await.unwrap_or_default(); + let details = response.text().await.unwrap_or_default(); return ClientError::new( ClientErrorKind::Response, format!( @@ -390,20 +572,19 @@ impl BotClient for GeminiClient { } }; - let parsed: GeminiModelsResponse = - match serde_json::from_str(&payload) { - Ok(r) => r, - Err(error) => { - return ClientError::new_with_source( - ClientErrorKind::Format, - "Could not parse Gemini models \ + let parsed: GeminiModelsResponse = match serde_json::from_str(&payload) { + Ok(r) => r, + Err(error) => { + return ClientError::new_with_source( + ClientErrorKind::Format, + "Could not parse Gemini models \ response." - .to_string(), - Some(error), - ) - .into(); - } - }; + .to_string(), + Some(error), + ) + .into(); + } + }; let bots = parsed .models @@ -429,13 +610,12 @@ impl BotClient for GeminiClient { &mut self, bot_id: &BotId, messages: &[Message], - _tools: &[Tool], + tools: &[Tool], ) -> BoxPlatformSendStream<'static, ClientResult> { - // TODO: Gemini supports function calling — convert `_tools` to - // Gemini `tools` / `function_declarations` and include in request. let inner = self.0.read().expect("gemini client lock poisoned").clone(); let bot_id = bot_id.clone(); let messages = messages.to_vec(); + let tools = tools.to_vec(); let stream = stream! { let url = match build_stream_url(&inner.url, &bot_id) { @@ -446,7 +626,7 @@ impl BotClient for GeminiClient { } }; - let request = match build_generate_request(&messages) { + let request = match build_generate_request(&messages, &tools) { Ok(request) => request, Err(error) => { yield error.into(); @@ -486,6 +666,9 @@ impl BotClient for GeminiClient { } let mut full_text = String::new(); + let mut tool_call_ids_by_index: HashMap = HashMap::new(); + let mut tool_calls_by_index: BTreeMap = BTreeMap::new(); + let mut next_tool_call_id = 0usize; let events = parse_sse(response.bytes_stream()); for await event in events { @@ -501,21 +684,46 @@ impl BotClient for GeminiClient { } }; - let chunk = match parse_stream_text(&event) { - Ok(chunk) => chunk, + let delta = match parse_stream_delta(&event) { + Ok(delta) => delta, Err(error) => { yield error.into(); return; } }; - if chunk.is_empty() { + if delta.text.is_empty() && delta.function_calls.is_empty() { continue; } - full_text.push_str(&chunk); + if !delta.text.is_empty() { + full_text.push_str(&delta.text); + } + + for (index, function_call) in delta.function_calls.into_iter().enumerate() { + let call_id = tool_call_ids_by_index + .entry(index) + .or_insert_with(|| { + let call_id = format!("gemini-call-{next_tool_call_id}"); + next_tool_call_id += 1; + call_id + }) + .clone(); + + tool_calls_by_index.insert( + index, + ToolCall { + id: call_id, + name: function_call.name, + arguments: function_call_args_to_map(function_call.args), + ..Default::default() + }, + ); + } + let content = MessageContent { text: full_text.clone(), + tool_calls: tool_calls_by_index.values().cloned().collect(), ..Default::default() }; yield ClientResult::new_ok(content); @@ -630,13 +838,17 @@ mod tests { }, ]; - let request = build_generate_request(&messages).expect("failed to build request"); + let request = build_generate_request(&messages, &[]).expect("failed to build request"); assert_eq!(request.contents.len(), 2); assert_eq!(request.contents[0].role, "user"); assert_eq!(request.contents[1].role, "model"); assert_eq!( - request.system_instruction.expect("missing system instruction").parts[0].text, + request + .system_instruction + .expect("missing system instruction") + .parts[0] + .text, "You are helpful." ); } @@ -662,7 +874,7 @@ mod tests { let bot = &bots[0]; assert!(bot.capabilities.has_capability(&BotCapability::TextInput)); - assert!(!bot.capabilities.has_capability(&BotCapability::ToolInput)); + assert!(bot.capabilities.has_capability(&BotCapability::ToolInput)); } #[test] @@ -677,4 +889,117 @@ mod tests { let text = parse_stream_text(payload).expect("failed to parse stream payload"); assert_eq!(text, "Hello Gemini"); } + + #[test] + fn build_generate_request_includes_tool_declarations() { + let messages = vec![Message { + from: EntityId::User, + content: MessageContent { + text: "What's the weather in Tokyo?".to_string(), + ..Default::default() + }, + ..Default::default() + }]; + + let tools = vec![Tool { + name: "get_weather".to_string(), + description: Some("Get weather for a city.".to_string()), + input_schema: std::sync::Arc::new( + serde_json::from_str( + r#"{ + "type": "object", + "properties": { + "city": { "type": "string" } + }, + "required": ["city"] + }"#, + ) + .expect("invalid schema json"), + ), + }]; + + let request = build_generate_request(&messages, &tools).expect("failed to build request"); + let value = serde_json::to_value(request).expect("failed to serialize request"); + let declarations = value["tools"][0]["function_declarations"] + .as_array() + .expect("missing function_declarations"); + + assert_eq!(declarations.len(), 1); + assert_eq!(declarations[0]["name"], "get_weather"); + assert_eq!(declarations[0]["parameters"]["type"], "object"); + } + + #[test] + fn build_generate_request_maps_tool_results_to_function_response_parts() { + let tool_call_id = "call-1".to_string(); + let messages = vec![ + Message { + from: EntityId::Bot(BotId::new("gemini-2.0-flash")), + content: MessageContent { + tool_calls: vec![ToolCall { + id: tool_call_id.clone(), + name: "filesystem__read_file".to_string(), + arguments: serde_json::Map::new(), + ..Default::default() + }], + ..Default::default() + }, + ..Default::default() + }, + Message { + from: EntityId::Tool, + content: MessageContent { + tool_results: vec![ToolResult { + tool_call_id, + content: r#"{"content":"hello"}"#.to_string(), + is_error: false, + }], + ..Default::default() + }, + ..Default::default() + }, + ]; + + let request = build_generate_request(&messages, &[]).expect("failed to build request"); + let value = serde_json::to_value(request).expect("failed to serialize request"); + + let model_parts = value["contents"][0]["parts"] + .as_array() + .expect("missing model parts"); + let tool_parts = value["contents"][1]["parts"] + .as_array() + .expect("missing tool parts"); + + assert_eq!( + model_parts[0]["functionCall"]["name"], + "filesystem__read_file" + ); + assert_eq!( + tool_parts[0]["functionResponse"]["name"], + "filesystem__read_file" + ); + } + + #[test] + fn parse_stream_delta_extracts_text_and_function_calls() { + let payload = r#" + { + "candidates": [ + { + "content": { + "parts": [ + {"text":"Checking..."}, + {"functionCall":{"name":"get_weather","args":{"city":"Tokyo"}}} + ] + } + } + ] + }"#; + + let delta = parse_stream_delta(payload).expect("failed to parse stream payload"); + assert_eq!(delta.text, "Checking..."); + assert_eq!(delta.function_calls.len(), 1); + assert_eq!(delta.function_calls[0].name, "get_weather"); + assert_eq!(delta.function_calls[0].args["city"], "Tokyo"); + } } From e94d05b7dc4025bb7679fcd506263834b4223a2f Mon Sep 17 00:00:00 2001 From: Alvin Date: Wed, 4 Mar 2026 13:31:52 +0800 Subject: [PATCH 10/16] fix(sse): parse CRLF events and skip non-data frames safely --- src/utils/sse.rs | 134 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 117 insertions(+), 17 deletions(-) diff --git a/src/utils/sse.rs b/src/utils/sse.rs index e7efbb0..9661df5 100644 --- a/src/utils/sse.rs +++ b/src/utils/sse.rs @@ -4,6 +4,7 @@ use async_stream::stream; use futures::Stream; pub(crate) const EVENT_TERMINATOR: &'static [u8] = b"\n\n"; +pub(crate) const EVENT_TERMINATOR_CRLF: &'static [u8] = b"\r\n\r\n"; /// Split from the last SSE event terminator. /// @@ -16,16 +17,65 @@ pub(crate) const EVENT_TERMINATOR: &'static [u8] = b"\n\n"; /// /// Returns `None` if no terminator is found. pub(crate) fn rsplit_once_terminator(buffer: &[u8]) -> Option<(&[u8], &[u8])> { - buffer - .windows(2) - .enumerate() - .rev() - .find(|(_, w)| w == &EVENT_TERMINATOR) - .map(|(pos, _)| { - let (before, after_with_terminator) = buffer.split_at(pos); - let after = &after_with_terminator[2..]; - (before, after) - }) + fn find_last(buffer: &[u8], term: &[u8]) -> Option { + buffer + .windows(term.len()) + .enumerate() + .rev() + .find(|(_, w)| *w == term) + .map(|(pos, _)| pos) + } + + let lf = find_last(buffer, EVENT_TERMINATOR).map(|pos| (pos, EVENT_TERMINATOR.len())); + let crlf = + find_last(buffer, EVENT_TERMINATOR_CRLF).map(|pos| (pos, EVENT_TERMINATOR_CRLF.len())); + + let (pos, len) = match (lf, crlf) { + (Some(lf), Some(crlf)) => { + if lf.0 >= crlf.0 { + lf + } else { + crlf + } + } + (Some(lf), None) => lf, + (None, Some(crlf)) => crlf, + (None, None) => return None, + }; + + let (before, after_with_terminator) = buffer.split_at(pos); + let after = &after_with_terminator[len..]; + Some((before, after)) +} + +fn extract_sse_data(message: &str) -> Option { + let mut data_lines = Vec::new(); + + for line in message.lines() { + if line.starts_with(':') { + continue; + } + + let Some((field, value)) = line.split_once(':') else { + continue; + }; + + if field.trim() == "data" { + let value = value.strip_prefix(' ').unwrap_or(value); + data_lines.push(value); + } + } + + if data_lines.is_empty() { + return None; + } + + let data = data_lines.join("\n"); + if data.trim() == "[DONE]" { + return None; + } + + Some(data) } /// Convert a stream of bytes into a stream of SSE messages. @@ -58,15 +108,12 @@ where }; // Silently drop any invalid utf8 bytes from the completed messages. - let completed_messages = String::from_utf8_lossy(completed_messages); + let completed_messages = String::from_utf8_lossy(completed_messages) + .replace("\r\n", "\n"); - let messages = - completed_messages + let messages = completed_messages .split(event_terminator_str) - .filter(|m| !m.starts_with(":")) - // TODO: Return a format error instead of unwraping. - .map(|m| m.trim_start().split("data:").nth(1).unwrap()) - .filter(|m| m.trim() != "[DONE]"); + .filter_map(extract_sse_data); for m in messages { yield Ok(m.to_string()); @@ -80,6 +127,7 @@ where #[cfg(test)] mod tests { use super::*; + use futures::{StreamExt, executor::block_on}; #[test] fn test_rsplit_once_terminator() { @@ -88,4 +136,56 @@ mod tests { assert_eq!(completed, b"data: 1\n\ndata: 2"); assert_eq!(incomplete, b"data: incomplete mes"); } + + #[test] + fn test_rsplit_once_terminator_crlf() { + let buffer = b"data: 1\r\n\r\ndata: 2\r\n\r\ndata: incomplete mes"; + let (completed, incomplete) = rsplit_once_terminator(buffer).unwrap(); + assert_eq!(completed, b"data: 1\r\n\r\ndata: 2"); + assert_eq!(incomplete, b"data: incomplete mes"); + } + + #[test] + fn test_extract_sse_data_ignores_non_data_event() { + let message = "event: ping\nid: 1"; + assert_eq!(extract_sse_data(message), None); + } + + #[test] + fn test_extract_sse_data_with_data_field() { + let message = "event: message\ndata: {\"ok\":true}"; + assert_eq!(extract_sse_data(message), Some("{\"ok\":true}".to_string())); + } + + #[test] + fn test_parse_sse_skips_non_data_event() { + let input = futures::stream::iter(vec![Ok::<_, ()>( + b"event: ping\n\n\ + data: hello\n\n" + .to_vec(), + )]); + + let mut output = std::pin::pin!(parse_sse(input)); + let first = block_on(output.next()); + let second = block_on(output.next()); + + assert_eq!(first, Some(Ok("hello".to_string()))); + assert_eq!(second, None); + } + + #[test] + fn test_parse_sse_with_crlf_terminator() { + let input = futures::stream::iter(vec![Ok::<_, ()>( + b"data: first\r\n\r\ndata: second\r\n\r\n".to_vec(), + )]); + + let mut output = std::pin::pin!(parse_sse(input)); + let first = block_on(output.next()); + let second = block_on(output.next()); + let third = block_on(output.next()); + + assert_eq!(first, Some(Ok("first".to_string()))); + assert_eq!(second, Some(Ok("second".to_string()))); + assert_eq!(third, None); + } } From 76c9e54855b20cf03d8bba9faca5b74d7b854340 Mon Sep 17 00:00:00 2001 From: Alvin Date: Wed, 4 Mar 2026 13:31:57 +0800 Subject: [PATCH 11/16] feat(gemini): implement native tool-calling roundtrip --- src/clients/gemini.rs | 202 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 195 insertions(+), 7 deletions(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index c68c3b9..6289064 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -81,16 +81,19 @@ struct GeminiModel { #[derive(Debug, Serialize)] struct GeminiGenerateRequest { contents: Vec, - #[serde(rename = "system_instruction")] + #[serde(rename = "systemInstruction")] #[serde(skip_serializing_if = "Option::is_none")] system_instruction: Option, #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, + #[serde(rename = "toolConfig")] + #[serde(skip_serializing_if = "Option::is_none")] + tool_config: Option, } #[derive(Debug, Serialize)] struct GeminiToolDeclarations { - #[serde(rename = "function_declarations")] + #[serde(rename = "functionDeclarations")] function_declarations: Vec, } @@ -102,6 +105,20 @@ struct GeminiFunctionDeclaration { parameters: Value, } +#[derive(Debug, Serialize)] +struct GeminiToolConfig { + #[serde(rename = "functionCallingConfig")] + function_calling_config: GeminiFunctionCallingConfig, +} + +#[derive(Debug, Serialize)] +struct GeminiFunctionCallingConfig { + mode: String, + #[serde(rename = "allowedFunctionNames")] + #[serde(skip_serializing_if = "Vec::is_empty")] + allowed_function_names: Vec, +} + #[derive(Debug, Serialize)] struct GeminiSystemInstruction { parts: Vec, @@ -125,6 +142,9 @@ enum GeminiOutgoingPart { FunctionCall { #[serde(rename = "functionCall")] function_call: GeminiFunctionCall, + #[serde(rename = "thoughtSignature")] + #[serde(skip_serializing_if = "Option::is_none")] + thought_signature: Option, }, FunctionResponse { #[serde(rename = "functionResponse")] @@ -169,12 +189,23 @@ struct GeminiStreamPart { text: String, #[serde(rename = "functionCall")] function_call: Option, + #[serde(rename = "thoughtSignature")] + thought_signature: Option, } #[derive(Debug, Default)] struct GeminiStreamDelta { text: String, - function_calls: Vec, + function_calls: Vec, +} + +const TOOL_CALL_SIGNATURES_KEY: &str = "gemini_tool_call_thought_signatures"; + +#[derive(Debug)] +struct GeminiFunctionCallDelta { + name: String, + args: Value, + thought_signature: Option, } fn normalize_model_id(id: &str) -> &str { @@ -306,6 +337,27 @@ fn as_gemini_tools(tools: &[Tool]) -> Option> { }]) } +fn as_gemini_tool_config(tools: &[Tool], messages: &[Message]) -> Option { + if tools.is_empty() { + return None; + } + + let has_tool_results = messages.iter().any(|message| { + matches!(message.from, EntityId::Tool) && !message.content.tool_results.is_empty() + }); + + Some(GeminiToolConfig { + function_calling_config: GeminiFunctionCallingConfig { + mode: if has_tool_results { "AUTO" } else { "ANY" }.to_string(), + allowed_function_names: if has_tool_results { + Vec::new() + } else { + tools.iter().map(|tool| tool.name.clone()).collect() + }, + }, + }) +} + fn collect_tool_call_names(messages: &[Message]) -> HashMap { let mut names = HashMap::new(); for message in messages { @@ -337,6 +389,7 @@ fn parse_tool_result_payload(result: &ToolResult) -> Value { fn as_bot_parts(message: &Message) -> Vec { let mut parts = Vec::new(); + let thought_signatures = parse_tool_call_thought_signatures(message.content.data.as_deref()); if !message.content.text.is_empty() { parts.push(GeminiOutgoingPart::Text(GeminiTextPart { @@ -350,6 +403,7 @@ fn as_bot_parts(message: &Message) -> Vec { name: call.name.clone(), args: Value::Object(call.arguments.clone()), }, + thought_signature: thought_signatures.get(&call.id).cloned(), }); } @@ -459,6 +513,7 @@ fn build_generate_request( contents, system_instruction, tools: as_gemini_tools(tools), + tool_config: as_gemini_tool_config(tools, messages), }) } @@ -481,7 +536,11 @@ fn parse_stream_delta(payload: &str) -> Result { } if let Some(function_call) = part.function_call { if !function_call.name.is_empty() { - delta.function_calls.push(function_call); + delta.function_calls.push(GeminiFunctionCallDelta { + name: function_call.name, + args: function_call.args, + thought_signature: part.thought_signature, + }); } } } @@ -508,6 +567,52 @@ fn function_call_args_to_map(args: Value) -> Map { } } +fn encode_tool_call_thought_signatures(signatures: &HashMap) -> Option { + if signatures.is_empty() { + return None; + } + + let signatures_object = signatures + .iter() + .map(|(k, v)| (k.clone(), Value::String(v.clone()))) + .collect::>(); + + let mut root = Map::new(); + root.insert( + TOOL_CALL_SIGNATURES_KEY.to_string(), + Value::Object(signatures_object), + ); + + serde_json::to_string(&Value::Object(root)).ok() +} + +fn parse_tool_call_thought_signatures(data: Option<&str>) -> HashMap { + let Some(data) = data else { + return HashMap::new(); + }; + + let Ok(value) = serde_json::from_str::(data) else { + return HashMap::new(); + }; + + let Some(signatures) = value + .as_object() + .and_then(|root| root.get(TOOL_CALL_SIGNATURES_KEY)) + .and_then(Value::as_object) + else { + return HashMap::new(); + }; + + signatures + .iter() + .filter_map(|(id, signature)| { + signature + .as_str() + .map(|signature| (id.clone(), signature.to_string())) + }) + .collect() +} + impl BotClient for GeminiClient { fn bots(&mut self) -> BoxPlatformSendFuture<'static, ClientResult>> { let inner = self.0.read().expect("gemini client lock poisoned").clone(); @@ -668,6 +773,7 @@ impl BotClient for GeminiClient { let mut full_text = String::new(); let mut tool_call_ids_by_index: HashMap = HashMap::new(); let mut tool_calls_by_index: BTreeMap = BTreeMap::new(); + let mut tool_call_signatures_by_id: HashMap = HashMap::new(); let mut next_tool_call_id = 0usize; let events = parse_sse(response.bytes_stream()); @@ -713,17 +819,22 @@ impl BotClient for GeminiClient { tool_calls_by_index.insert( index, ToolCall { - id: call_id, + id: call_id.clone(), name: function_call.name, arguments: function_call_args_to_map(function_call.args), ..Default::default() }, ); + + if let Some(signature) = function_call.thought_signature { + tool_call_signatures_by_id.insert(call_id, signature); + } } let content = MessageContent { text: full_text.clone(), tool_calls: tool_calls_by_index.values().cloned().collect(), + data: encode_tool_call_thought_signatures(&tool_call_signatures_by_id), ..Default::default() }; yield ClientResult::new_ok(content); @@ -846,11 +957,22 @@ mod tests { assert_eq!( request .system_instruction + .as_ref() .expect("missing system instruction") .parts[0] .text, "You are helpful." ); + + let value = serde_json::to_value(request).expect("failed to serialize request"); + assert_eq!( + value["systemInstruction"]["parts"][0]["text"], + "You are helpful." + ); + assert!( + value["system_instruction"].is_null(), + "snake_case field should not be present" + ); } #[test] @@ -920,13 +1042,22 @@ mod tests { let request = build_generate_request(&messages, &tools).expect("failed to build request"); let value = serde_json::to_value(request).expect("failed to serialize request"); - let declarations = value["tools"][0]["function_declarations"] + let declarations = value["tools"][0]["functionDeclarations"] .as_array() .expect("missing function_declarations"); assert_eq!(declarations.len(), 1); assert_eq!(declarations[0]["name"], "get_weather"); assert_eq!(declarations[0]["parameters"]["type"], "object"); + assert!( + value["tools"][0]["function_declarations"].is_null(), + "snake_case field should not be present" + ); + assert_eq!(value["toolConfig"]["functionCallingConfig"]["mode"], "ANY"); + assert_eq!( + value["toolConfig"]["functionCallingConfig"]["allowedFunctionNames"][0], + "get_weather" + ); } #[test] @@ -980,6 +1111,34 @@ mod tests { ); } + #[test] + fn build_generate_request_sets_tool_mode_auto_after_tool_results() { + let messages = vec![Message { + from: EntityId::Tool, + content: MessageContent { + tool_results: vec![ToolResult { + tool_call_id: "call-1".to_string(), + content: r#"{"ok":true}"#.to_string(), + is_error: false, + }], + ..Default::default() + }, + ..Default::default() + }]; + + let tools = vec![Tool { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + input_schema: std::sync::Arc::new( + serde_json::from_str(r#"{"type":"object"}"#).expect("invalid schema"), + ), + }]; + + let request = build_generate_request(&messages, &tools).expect("failed to build request"); + let value = serde_json::to_value(request).expect("failed to serialize request"); + assert_eq!(value["toolConfig"]["functionCallingConfig"]["mode"], "AUTO"); + } + #[test] fn parse_stream_delta_extracts_text_and_function_calls() { let payload = r#" @@ -989,7 +1148,7 @@ mod tests { "content": { "parts": [ {"text":"Checking..."}, - {"functionCall":{"name":"get_weather","args":{"city":"Tokyo"}}} + {"functionCall":{"name":"get_weather","args":{"city":"Tokyo"}},"thoughtSignature":"sig-123"} ] } } @@ -1001,5 +1160,34 @@ mod tests { assert_eq!(delta.function_calls.len(), 1); assert_eq!(delta.function_calls[0].name, "get_weather"); assert_eq!(delta.function_calls[0].args["city"], "Tokyo"); + assert_eq!( + delta.function_calls[0].thought_signature.as_deref(), + Some("sig-123") + ); + } + + #[test] + fn as_bot_parts_includes_thought_signature_from_data() { + let message = Message { + from: EntityId::Bot(BotId::new("gemini-3-flash-preview")), + content: MessageContent { + tool_calls: vec![ToolCall { + id: "call-1".to_string(), + name: "get_weather".to_string(), + arguments: serde_json::from_str(r#"{"location":"Montevideo"}"#) + .expect("invalid args"), + ..Default::default() + }], + data: Some( + r#"{"gemini_tool_call_thought_signatures":{"call-1":"sig-abc"}}"#.to_string(), + ), + ..Default::default() + }, + ..Default::default() + }; + + let parts = as_bot_parts(&message); + let value = serde_json::to_value(parts).expect("failed to serialize parts"); + assert_eq!(value[0]["thoughtSignature"], "sig-abc"); } } From 49d71cc2fd6a970e3d06986fc824c47eefbb1d57 Mon Sep 17 00:00:00 2001 From: Alvin Date: Wed, 4 Mar 2026 13:32:04 +0800 Subject: [PATCH 12/16] feat(example): add gemini native tool-calls demo --- examples/gemini-tool-calls/Cargo.toml | 10 ++ examples/gemini-tool-calls/README.md | 19 ++++ examples/gemini-tool-calls/src/main.rs | 132 +++++++++++++++++++++++++ 3 files changed, 161 insertions(+) create mode 100644 examples/gemini-tool-calls/Cargo.toml create mode 100644 examples/gemini-tool-calls/README.md create mode 100644 examples/gemini-tool-calls/src/main.rs diff --git a/examples/gemini-tool-calls/Cargo.toml b/examples/gemini-tool-calls/Cargo.toml new file mode 100644 index 0000000..d112a71 --- /dev/null +++ b/examples/gemini-tool-calls/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "gemini-tool-calls-example" +version = "0.1.0" +edition = "2024" + +[dependencies] +aitk = { path = "../..", features = ["api-clients"] } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +futures = "0.3" +serde_json = "1.0" diff --git a/examples/gemini-tool-calls/README.md b/examples/gemini-tool-calls/README.md new file mode 100644 index 0000000..db3c079 --- /dev/null +++ b/examples/gemini-tool-calls/README.md @@ -0,0 +1,19 @@ +## Gemini Native Tool Calls + +Demonstrates native Gemini tool-calling with `GeminiClient`: +- Send tool declarations (`function_declarations`) +- Receive model tool calls +- Execute tools in Rust +- Send `ToolResult` back to Gemini +- Receive final answer + +### Requirements + +Set env variables and run: + +```shell +export API_URL="https://generativelanguage.googleapis.com/v1beta" +export API_KEY="your-gemini-key" +export MODEL_ID="gemini-2.0-flash" +cargo run +``` diff --git a/examples/gemini-tool-calls/src/main.rs b/examples/gemini-tool-calls/src/main.rs new file mode 100644 index 0000000..6ef1e74 --- /dev/null +++ b/examples/gemini-tool-calls/src/main.rs @@ -0,0 +1,132 @@ +use aitk::prelude::*; +use futures::StreamExt; +use std::sync::Arc; + +#[tokio::main] +async fn main() { + let url = std::env::var("API_URL").expect("API_URL must be set"); + let key = std::env::var("API_KEY").expect("API_KEY must be set"); + let model = std::env::var("MODEL_ID").expect("MODEL_ID must be set"); + + let mut client = GeminiClient::new(url); + client.set_key(&key).expect("Invalid API key"); + + let bot_id = BotId::new(&model); + + let weather_tool = Tool { + name: "get_weather".into(), + description: Some("Get current weather for a location".into()), + input_schema: Arc::new( + serde_json::from_str( + r#"{ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name, e.g. 'Tokyo'" + } + }, + "required": ["location"] + }"#, + ) + .expect("Invalid JSON schema"), + ), + }; + + let tools = [weather_tool]; + + let mut messages = vec![Message { + from: EntityId::User, + content: MessageContent { + text: "What's the weather like in Montevideo?".into(), + ..Default::default() + }, + ..Default::default() + }]; + + for turn in 0..5 { + let assistant_content = + match send_and_collect(&mut client, &bot_id, &messages, &tools).await { + Ok(content) => content, + Err(()) => return, + }; + + if assistant_content.tool_calls.is_empty() { + println!("\nFinal answer:\n{}", assistant_content.text); + return; + } + + println!("\nTurn {} tool calls:", turn + 1); + for tc in &assistant_content.tool_calls { + println!("Tool call: {} with args {:?}", tc.name, tc.arguments); + } + + messages.push(Message { + from: EntityId::Bot(bot_id.clone()), + content: assistant_content.clone(), + ..Default::default() + }); + + for tc in &assistant_content.tool_calls { + let result = execute_tool(tc); + println!("Tool result for {}: {}", tc.name, result); + + messages.push(Message { + from: EntityId::Tool, + content: MessageContent { + tool_results: vec![ToolResult { + tool_call_id: tc.id.clone(), + content: result, + is_error: false, + }], + ..Default::default() + }, + ..Default::default() + }); + } + } + + println!("\nReached max turns without final text."); +} + +async fn send_and_collect( + client: &mut GeminiClient, + bot_id: &BotId, + messages: &[Message], + tools: &[Tool], +) -> Result { + let mut last_content = MessageContent::default(); + let mut stream = client.send(bot_id, messages, tools); + + while let Some(result) = stream.next().await { + match result.into_result() { + Ok(content) => last_content = content, + Err(errors) => { + for e in errors { + eprintln!("Error: {e}"); + if let Some(details) = e.details() { + eprintln!("Details: {details}"); + } + } + return Err(()); + } + } + } + + Ok(last_content) +} + +fn execute_tool(tool_call: &ToolCall) -> String { + match tool_call.name.as_str() { + "get_weather" => { + let location = tool_call + .arguments + .get("location") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + + format!(r#"{{"location": "{location}", "temp": "22C", "condition": "sunny"}}"#) + } + other => format!(r#"{{"error": "Unknown tool: {other}"}}"#), + } +} From 5d43cd2e0ab93bb7b11458494be51db40483f705 Mon Sep 17 00:00:00 2001 From: Alvin Date: Thu, 5 Mar 2026 11:19:26 +0800 Subject: [PATCH 13/16] fix(gemini): protocol-first tool call ids and minimal SSE CRLF normalization --- src/clients/gemini.rs | 257 ++++++++++++++++++++++++++++++++++-------- src/utils/sse.rs | 52 +++------ 2 files changed, 223 insertions(+), 86 deletions(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index 6289064..486b2f4 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -8,7 +8,7 @@ use reqwest::header::{HeaderMap, HeaderName}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use std::{ - collections::{BTreeMap, HashMap}, + collections::HashMap, str::FromStr, sync::{Arc, RwLock}, }; @@ -154,6 +154,10 @@ enum GeminiOutgoingPart { #[derive(Debug, Clone, Serialize, Deserialize)] struct GeminiFunctionCall { + // Gemini may provide a stable server-side call id for function call/result correlation. + // We preserve it when present and use it as the primary identity key in stream assembly. + #[serde(default)] + id: Option, #[serde(default)] name: String, #[serde(default)] @@ -203,6 +207,7 @@ const TOOL_CALL_SIGNATURES_KEY: &str = "gemini_tool_call_thought_signatures"; #[derive(Debug)] struct GeminiFunctionCallDelta { + id: Option, name: String, args: Value, thought_signature: Option, @@ -337,23 +342,15 @@ fn as_gemini_tools(tools: &[Tool]) -> Option> { }]) } -fn as_gemini_tool_config(tools: &[Tool], messages: &[Message]) -> Option { +fn as_gemini_tool_config(tools: &[Tool]) -> Option { if tools.is_empty() { return None; } - let has_tool_results = messages.iter().any(|message| { - matches!(message.from, EntityId::Tool) && !message.content.tool_results.is_empty() - }); - Some(GeminiToolConfig { function_calling_config: GeminiFunctionCallingConfig { - mode: if has_tool_results { "AUTO" } else { "ANY" }.to_string(), - allowed_function_names: if has_tool_results { - Vec::new() - } else { - tools.iter().map(|tool| tool.name.clone()).collect() - }, + mode: "AUTO".to_string(), + allowed_function_names: Vec::new(), }, }) } @@ -400,6 +397,9 @@ fn as_bot_parts(message: &Message) -> Vec { for call in &message.content.tool_calls { parts.push(GeminiOutgoingPart::FunctionCall { function_call: GeminiFunctionCall { + // Keep the call id when replaying model tool calls back to Gemini. + // This preserves protocol-level correlation with later function responses. + id: Some(call.id.clone()), name: call.name.clone(), args: Value::Object(call.arguments.clone()), }, @@ -513,7 +513,7 @@ fn build_generate_request( contents, system_instruction, tools: as_gemini_tools(tools), - tool_config: as_gemini_tool_config(tools, messages), + tool_config: as_gemini_tool_config(tools), }) } @@ -537,6 +537,7 @@ fn parse_stream_delta(payload: &str) -> Result { if let Some(function_call) = part.function_call { if !function_call.name.is_empty() { delta.function_calls.push(GeminiFunctionCallDelta { + id: function_call.id, name: function_call.name, args: function_call.args, thought_signature: part.thought_signature, @@ -613,6 +614,107 @@ fn parse_tool_call_thought_signatures(data: Option<&str>) -> HashMap, + order: Vec, + calls_by_id: HashMap, + thought_signatures_by_id: HashMap, + next_id: usize, +} + +struct StreamToolCallSlot { + // Fallback signature used only when protocol id is absent. + signature: String, + id: String, +} + +impl GeminiStreamToolCallState { + fn apply_delta(&mut self, function_calls: Vec) { + for (stream_index, function_call) in function_calls.into_iter().enumerate() { + let signature = stream_tool_call_signature(&function_call.name, &function_call.args); + // Design decision: + // 1) Protocol ID first: if Gemini returns `functionCall.id`, we must preserve it + // end-to-end so follow-up `functionResponse` can correlate with the exact server call. + // 2) Fallback only when `id` is absent: some responses may omit it, so we keep a local + // stable key based on stream position + call signature to avoid ID collisions. + let call_id = if let Some(protocol_id) = function_call.id.clone() { + self.by_stream_index.insert( + stream_index, + StreamToolCallSlot { + signature, + id: protocol_id.clone(), + }, + ); + self.ensure_ordered_id(&protocol_id); + protocol_id + } else { + self.call_id_from_fallback(stream_index, signature) + }; + + self.calls_by_id.insert( + call_id.clone(), + ToolCall { + id: call_id.clone(), + name: function_call.name, + arguments: function_call_args_to_map(function_call.args), + ..Default::default() + }, + ); + + if let Some(thought_signature) = function_call.thought_signature { + self.thought_signatures_by_id + .insert(call_id, thought_signature); + } + } + } + + fn call_id_from_fallback(&mut self, stream_index: usize, signature: String) -> String { + // Fallback policy: + // - same stream index + same signature => same logical call (continue updating), + // - otherwise allocate a new local id to prevent cross-call collisions. + match self.by_stream_index.get(&stream_index) { + Some(slot) if slot.signature == signature => slot.id.clone(), + _ => { + let id = format!("gemini-call-{}", self.next_id); + self.next_id += 1; + self.by_stream_index.insert( + stream_index, + StreamToolCallSlot { + signature, + id: id.clone(), + }, + ); + self.order.push(id.clone()); + id + } + } + } + + fn ensure_ordered_id(&mut self, id: &str) { + if self.calls_by_id.contains_key(id) { + return; + } + self.order.push(id.to_string()); + } + + fn tool_calls(&self) -> Vec { + self.order + .iter() + .filter_map(|id| self.calls_by_id.get(id).cloned()) + .collect() + } + + fn encoded_thought_signatures(&self) -> Option { + encode_tool_call_thought_signatures(&self.thought_signatures_by_id) + } +} + +fn stream_tool_call_signature(name: &str, args: &Value) -> String { + let serialized_args = serde_json::to_string(args).unwrap_or_default(); + format!("{name}:{serialized_args}") +} + impl BotClient for GeminiClient { fn bots(&mut self) -> BoxPlatformSendFuture<'static, ClientResult>> { let inner = self.0.read().expect("gemini client lock poisoned").clone(); @@ -771,10 +873,7 @@ impl BotClient for GeminiClient { } let mut full_text = String::new(); - let mut tool_call_ids_by_index: HashMap = HashMap::new(); - let mut tool_calls_by_index: BTreeMap = BTreeMap::new(); - let mut tool_call_signatures_by_id: HashMap = HashMap::new(); - let mut next_tool_call_id = 0usize; + let mut stream_tool_call_state = GeminiStreamToolCallState::default(); let events = parse_sse(response.bytes_stream()); for await event in events { @@ -806,35 +905,12 @@ impl BotClient for GeminiClient { full_text.push_str(&delta.text); } - for (index, function_call) in delta.function_calls.into_iter().enumerate() { - let call_id = tool_call_ids_by_index - .entry(index) - .or_insert_with(|| { - let call_id = format!("gemini-call-{next_tool_call_id}"); - next_tool_call_id += 1; - call_id - }) - .clone(); - - tool_calls_by_index.insert( - index, - ToolCall { - id: call_id.clone(), - name: function_call.name, - arguments: function_call_args_to_map(function_call.args), - ..Default::default() - }, - ); - - if let Some(signature) = function_call.thought_signature { - tool_call_signatures_by_id.insert(call_id, signature); - } - } + stream_tool_call_state.apply_delta(delta.function_calls); let content = MessageContent { text: full_text.clone(), - tool_calls: tool_calls_by_index.values().cloned().collect(), - data: encode_tool_call_thought_signatures(&tool_call_signatures_by_id), + tool_calls: stream_tool_call_state.tool_calls(), + data: stream_tool_call_state.encoded_thought_signatures(), ..Default::default() }; yield ClientResult::new_ok(content); @@ -1053,10 +1129,10 @@ mod tests { value["tools"][0]["function_declarations"].is_null(), "snake_case field should not be present" ); - assert_eq!(value["toolConfig"]["functionCallingConfig"]["mode"], "ANY"); - assert_eq!( - value["toolConfig"]["functionCallingConfig"]["allowedFunctionNames"][0], - "get_weather" + assert_eq!(value["toolConfig"]["functionCallingConfig"]["mode"], "AUTO"); + assert!( + value["toolConfig"]["functionCallingConfig"]["allowedFunctionNames"].is_null(), + "allowedFunctionNames should be omitted in AUTO mode" ); } @@ -1158,6 +1234,7 @@ mod tests { let delta = parse_stream_delta(payload).expect("failed to parse stream payload"); assert_eq!(delta.text, "Checking..."); assert_eq!(delta.function_calls.len(), 1); + assert_eq!(delta.function_calls[0].id, None); assert_eq!(delta.function_calls[0].name, "get_weather"); assert_eq!(delta.function_calls[0].args["city"], "Tokyo"); assert_eq!( @@ -1166,6 +1243,92 @@ mod tests { ); } + #[test] + fn parse_stream_delta_extracts_protocol_function_call_id() { + let payload = r#" + { + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "id": "protocol-call-42", + "name": "get_weather", + "args": {"city":"Tokyo"} + } + } + ] + } + } + ] + }"#; + + let delta = parse_stream_delta(payload).expect("failed to parse stream payload"); + assert_eq!(delta.function_calls.len(), 1); + assert_eq!( + delta.function_calls[0].id.as_deref(), + Some("protocol-call-42") + ); + } + + #[test] + fn stream_tool_call_state_preserves_distinct_calls_across_chunk_index_restarts() { + let mut state = GeminiStreamToolCallState::default(); + + state.apply_delta(vec![GeminiFunctionCallDelta { + id: None, + name: "first_call".to_string(), + args: serde_json::json!({"city":"Tokyo"}), + thought_signature: None, + }]); + + state.apply_delta(vec![GeminiFunctionCallDelta { + id: None, + name: "second_call".to_string(), + args: serde_json::json!({"city":"Seoul"}), + thought_signature: None, + }]); + + let calls = state.tool_calls(); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].name, "first_call"); + assert_eq!(calls[1].name, "second_call"); + assert_ne!(calls[0].id, calls[1].id); + } + + #[test] + fn stream_tool_call_state_prefers_protocol_function_call_id() { + let payload = r#" + { + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "id": "protocol-call-42", + "name": "get_weather", + "args": {"city":"Tokyo"} + } + } + ] + } + } + ] + }"#; + + let delta = parse_stream_delta(payload).expect("failed to parse stream payload"); + + let mut state = GeminiStreamToolCallState::default(); + state.apply_delta(delta.function_calls); + + let calls = state.tool_calls(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].id, "protocol-call-42"); + assert_eq!(calls[0].name, "get_weather"); + } + #[test] fn as_bot_parts_includes_thought_signature_from_data() { let message = Message { diff --git a/src/utils/sse.rs b/src/utils/sse.rs index 9661df5..76635ce 100644 --- a/src/utils/sse.rs +++ b/src/utils/sse.rs @@ -3,8 +3,7 @@ use async_stream::stream; use futures::Stream; -pub(crate) const EVENT_TERMINATOR: &'static [u8] = b"\n\n"; -pub(crate) const EVENT_TERMINATOR_CRLF: &'static [u8] = b"\r\n\r\n"; +const EVENT_TERMINATOR: &[u8] = b"\n\n"; /// Split from the last SSE event terminator. /// @@ -16,35 +15,18 @@ pub(crate) const EVENT_TERMINATOR_CRLF: &'static [u8] = b"\r\n\r\n"; /// so you should keep this on the buffer until completed. /// /// Returns `None` if no terminator is found. -pub(crate) fn rsplit_once_terminator(buffer: &[u8]) -> Option<(&[u8], &[u8])> { - fn find_last(buffer: &[u8], term: &[u8]) -> Option { - buffer - .windows(term.len()) - .enumerate() - .rev() - .find(|(_, w)| *w == term) - .map(|(pos, _)| pos) - } - - let lf = find_last(buffer, EVENT_TERMINATOR).map(|pos| (pos, EVENT_TERMINATOR.len())); - let crlf = - find_last(buffer, EVENT_TERMINATOR_CRLF).map(|pos| (pos, EVENT_TERMINATOR_CRLF.len())); - - let (pos, len) = match (lf, crlf) { - (Some(lf), Some(crlf)) => { - if lf.0 >= crlf.0 { - lf - } else { - crlf - } - } - (Some(lf), None) => lf, - (None, Some(crlf)) => crlf, - (None, None) => return None, - }; +/// +/// This splitter handles LF-only delimiters. Normalize CRLF before calling it. +fn rsplit_once_terminator(buffer: &[u8]) -> Option<(&[u8], &[u8])> { + let pos = buffer + .windows(EVENT_TERMINATOR.len()) + .enumerate() + .rev() + .find(|(_, w)| *w == EVENT_TERMINATOR) + .map(|(pos, _)| pos)?; let (before, after_with_terminator) = buffer.split_at(pos); - let after = &after_with_terminator[len..]; + let after = &after_with_terminator[EVENT_TERMINATOR.len()..]; Some((before, after)) } @@ -100,6 +82,7 @@ where let chunk = chunk.as_ref(); buffer.extend_from_slice(chunk); + buffer.retain(|&b| b != b'\r'); let Some((completed_messages, incomplete_message)) = rsplit_once_terminator(&buffer) @@ -108,8 +91,7 @@ where }; // Silently drop any invalid utf8 bytes from the completed messages. - let completed_messages = String::from_utf8_lossy(completed_messages) - .replace("\r\n", "\n"); + let completed_messages = String::from_utf8_lossy(completed_messages); let messages = completed_messages .split(event_terminator_str) @@ -137,14 +119,6 @@ mod tests { assert_eq!(incomplete, b"data: incomplete mes"); } - #[test] - fn test_rsplit_once_terminator_crlf() { - let buffer = b"data: 1\r\n\r\ndata: 2\r\n\r\ndata: incomplete mes"; - let (completed, incomplete) = rsplit_once_terminator(buffer).unwrap(); - assert_eq!(completed, b"data: 1\r\n\r\ndata: 2"); - assert_eq!(incomplete, b"data: incomplete mes"); - } - #[test] fn test_extract_sse_data_ignores_non_data_event() { let message = "event: ping\nid: 1"; From 56c122ef5bd9a31f92072432e48592e03cd2abc2 Mon Sep 17 00:00:00 2001 From: Alvin Date: Sun, 15 Mar 2026 00:56:13 +0800 Subject: [PATCH 14/16] chore: sort imports and reformat lines --- src/clients/openai.rs | 2 +- src/controllers/chat.rs | 10 ++-------- src/prelude.rs | 4 ++-- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/clients/openai.rs b/src/clients/openai.rs index be0f4ff..aaf3178 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -8,9 +8,9 @@ use std::{ sync::{Arc, RwLock}, }; +use crate::protocol::*; use crate::utils::asynchronous::{BoxPlatformSendFuture, BoxPlatformSendStream}; use crate::utils::{serde::deserialize_null_default, sse::parse_sse}; -use crate::protocol::*; /// The content of a [`ContentPart::ImageUrl`]. #[derive(Serialize, Deserialize, Debug, Clone)] diff --git a/src/controllers/chat.rs b/src/controllers/chat.rs index ea8632d..98d369d 100644 --- a/src/controllers/chat.rs +++ b/src/controllers/chat.rs @@ -448,10 +448,7 @@ impl ChatController { c.dispatch_mutation(VecMutation::Set(bots.unwrap_or_default())); - let messages: Vec<_> = errors - .into_iter() - .map(Message::from_client_error) - .collect(); + let messages: Vec<_> = errors.into_iter().map(Message::from_client_error).collect(); c.dispatch_mutation(VecMutation::Extend(messages)); }); })); @@ -509,10 +506,7 @@ impl ChatController { false } Err(errors) => { - let messages: Vec<_> = errors - .into_iter() - .map(Message::from_client_error) - .collect(); + let messages: Vec<_> = errors.into_iter().map(Message::from_client_error).collect(); self.dispatch_mutation(VecMutation::Extend(messages)); true diff --git a/src/prelude.rs b/src/prelude.rs index 4cbc028..be94bd8 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -5,9 +5,9 @@ pub use crate::protocol::*; // These are the clients that are most commonly used. #[cfg(feature = "api-clients")] -pub use crate::clients::openai::OpenAiClient; -#[cfg(feature = "api-clients")] pub use crate::clients::gemini::GeminiClient; +#[cfg(feature = "api-clients")] +pub use crate::clients::openai::OpenAiClient; pub use crate::clients::router::RouterClient; // These other clients are less commonly used. From c5c73554d277bc2f592cd7b4d68afe6d71214d34 Mon Sep 17 00:00:00 2001 From: Alvin Date: Sun, 15 Mar 2026 00:56:40 +0800 Subject: [PATCH 15/16] fix(gemini): harden stream tool-call id assignment and error on unknown tool results - Promote fallback ids to protocol ids when they arrive later in the stream, preferring same-index matches to avoid misassignment when chunk indices shift between SSE events. - Return an error instead of silently degrading when a tool result references an unknown tool_call_id. - Fix ensure_ordered_id to check the order list instead of calls_by_id. - Collapse nested if-let per clippy, tighten comments, shorten test names. --- src/clients/gemini.rs | 315 ++++++++++++++++++++++++++++++++---------- 1 file changed, 245 insertions(+), 70 deletions(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index 486b2f4..903436e 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -154,8 +154,7 @@ enum GeminiOutgoingPart { #[derive(Debug, Clone, Serialize, Deserialize)] struct GeminiFunctionCall { - // Gemini may provide a stable server-side call id for function call/result correlation. - // We preserve it when present and use it as the primary identity key in stream assembly. + // Server-assigned call id used for function call/result correlation. #[serde(default)] id: Option, #[serde(default)] @@ -397,8 +396,6 @@ fn as_bot_parts(message: &Message) -> Vec { for call in &message.content.tool_calls { parts.push(GeminiOutgoingPart::FunctionCall { function_call: GeminiFunctionCall { - // Keep the call id when replaying model tool calls back to Gemini. - // This preserves protocol-level correlation with later function responses. id: Some(call.id.clone()), name: call.name.clone(), args: Value::Object(call.arguments.clone()), @@ -413,7 +410,7 @@ fn as_bot_parts(message: &Message) -> Vec { fn as_tool_parts( message: &Message, tool_call_names: &HashMap, -) -> Vec { +) -> Result, ClientError> { let mut parts = Vec::new(); for result in &message.content.tool_results { @@ -424,10 +421,14 @@ fn as_tool_parts( response: parse_tool_result_payload(result), }, }); - } else if !result.content.is_empty() { - parts.push(GeminiOutgoingPart::Text(GeminiTextPart { - text: result.content.clone(), - })); + } else { + return Err(ClientError::new( + ClientErrorKind::Format, + format!( + "Gemini tool result references unknown tool call id '{}'.", + result.tool_call_id + ), + )); } } @@ -437,7 +438,7 @@ fn as_tool_parts( })); } - parts + Ok(parts) } fn build_generate_request( @@ -461,7 +462,7 @@ fn build_generate_request( } } EntityId::Tool => { - let parts = as_tool_parts(message, &tool_call_names); + let parts = as_tool_parts(message, &tool_call_names)?; if !parts.is_empty() { contents.push(GeminiContent { role: "user".to_string(), @@ -534,15 +535,13 @@ fn parse_stream_delta(payload: &str) -> Result { if !part.text.is_empty() { delta.text.push_str(&part.text); } - if let Some(function_call) = part.function_call { - if !function_call.name.is_empty() { - delta.function_calls.push(GeminiFunctionCallDelta { - id: function_call.id, - name: function_call.name, - args: function_call.args, - thought_signature: part.thought_signature, - }); - } + if let Some(function_call) = part.function_call.filter(|c| !c.name.is_empty()) { + delta.function_calls.push(GeminiFunctionCallDelta { + id: function_call.id, + name: function_call.name, + args: function_call.args, + thought_signature: part.thought_signature, + }); } } } @@ -624,7 +623,6 @@ struct GeminiStreamToolCallState { } struct StreamToolCallSlot { - // Fallback signature used only when protocol id is absent. signature: String, id: String, } @@ -633,12 +631,30 @@ impl GeminiStreamToolCallState { fn apply_delta(&mut self, function_calls: Vec) { for (stream_index, function_call) in function_calls.into_iter().enumerate() { let signature = stream_tool_call_signature(&function_call.name, &function_call.args); - // Design decision: - // 1) Protocol ID first: if Gemini returns `functionCall.id`, we must preserve it - // end-to-end so follow-up `functionResponse` can correlate with the exact server call. - // 2) Fallback only when `id` is absent: some responses may omit it, so we keep a local - // stable key based on stream position + call signature to avoid ID collisions. + + // Use protocol id when available; otherwise fall back to a local id + // keyed by (stream_index, signature). + // + // When a protocol id arrives for a call that was previously tracked + // by a fallback id, we promote the old id: first check the same + // stream_index (best match), then search all slots by signature + // (handles index shifts between chunks). Identical-signature calls + // without protocol ids are inherently ambiguous. let call_id = if let Some(protocol_id) = function_call.id.clone() { + let matching_fallback = self + .by_stream_index + .get(&stream_index) + .filter(|slot| slot.signature == signature) + .map(|slot| (stream_index, slot.id.clone())) + .or_else(|| { + self.by_stream_index + .iter() + .find(|&(&idx, slot)| { + idx != stream_index && slot.signature == signature + }) + .map(|(&idx, slot)| (idx, slot.id.clone())) + }); + self.by_stream_index.insert( stream_index, StreamToolCallSlot { @@ -646,6 +662,14 @@ impl GeminiStreamToolCallState { id: protocol_id.clone(), }, ); + + if let Some((old_index, fallback_id)) = matching_fallback { + self.promote_call_id(&fallback_id, &protocol_id); + if old_index != stream_index { + self.by_stream_index.remove(&old_index); + } + } + self.ensure_ordered_id(&protocol_id); protocol_id } else { @@ -670,9 +694,6 @@ impl GeminiStreamToolCallState { } fn call_id_from_fallback(&mut self, stream_index: usize, signature: String) -> String { - // Fallback policy: - // - same stream index + same signature => same logical call (continue updating), - // - otherwise allocate a new local id to prevent cross-call collisions. match self.by_stream_index.get(&stream_index) { Some(slot) if slot.signature == signature => slot.id.clone(), _ => { @@ -691,8 +712,31 @@ impl GeminiStreamToolCallState { } } + fn promote_call_id(&mut self, previous_id: &str, protocol_id: &str) { + if previous_id == protocol_id { + return; + } + + if let Some(mut call) = self.calls_by_id.remove(previous_id) { + call.id = protocol_id.to_string(); + self.calls_by_id + .entry(protocol_id.to_string()) + .or_insert(call); + } + + if let Some(signature) = self.thought_signatures_by_id.remove(previous_id) { + self.thought_signatures_by_id + .entry(protocol_id.to_string()) + .or_insert(signature); + } + + if let Some(pos) = self.order.iter().position(|id| id == previous_id) { + self.order[pos] = protocol_id.to_string(); + } + } + fn ensure_ordered_id(&mut self, id: &str) { - if self.calls_by_id.contains_key(id) { + if self.order.iter().any(|existing| existing == id) { return; } self.order.push(id.to_string()); @@ -711,7 +755,8 @@ impl GeminiStreamToolCallState { } fn stream_tool_call_signature(name: &str, args: &Value) -> String { - let serialized_args = serde_json::to_string(args).unwrap_or_default(); + let serialized_args = serde_json::to_string(args) + .expect("serializing Gemini tool call arguments should not fail"); format!("{name}:{serialized_args}") } @@ -741,8 +786,7 @@ impl BotClient for GeminiClient { return ClientError::new_with_source( ClientErrorKind::Network, format!( - "Could not send request to {url}. \ - Verify your connection and key." + "Could not send request to {url}. Verify your connection and key." ), Some(error), ) @@ -755,10 +799,7 @@ impl BotClient for GeminiClient { let details = response.text().await.unwrap_or_default(); return ClientError::new( ClientErrorKind::Response, - format!( - "Gemini models request failed \ - with status {status}." - ), + format!("Gemini models request failed with status {status}."), ) .with_details(details) .into(); @@ -769,10 +810,7 @@ impl BotClient for GeminiClient { Err(error) => { return ClientError::new_with_source( ClientErrorKind::Format, - format!( - "Could not read Gemini models \ - response from {url}." - ), + format!("Could not read Gemini models response from {url}."), Some(error), ) .into(); @@ -784,9 +822,7 @@ impl BotClient for GeminiClient { Err(error) => { return ClientError::new_with_source( ClientErrorKind::Format, - "Could not parse Gemini models \ - response." - .to_string(), + "Could not parse Gemini models response.".to_string(), Some(error), ) .into(); @@ -930,7 +966,7 @@ mod tests { use super::*; #[test] - fn parse_models_response_prefers_display_name() { + fn models_use_display_name() { let payload = r#" { "models": [ @@ -950,7 +986,7 @@ mod tests { } #[test] - fn models_url_preserves_existing_query() { + fn models_url_keeps_query() { let url = build_models_url( "https://generativelanguage.googleapis.com/v1beta?alt=sse", None, @@ -962,7 +998,7 @@ mod tests { } #[test] - fn models_url_includes_page_token() { + fn models_url_pagination() { let url = build_models_url( "https://generativelanguage.googleapis.com/v1beta", Some("abc123"), @@ -973,7 +1009,7 @@ mod tests { } #[test] - fn stream_url_uses_stream_generate_content() { + fn stream_url_format() { let url = build_stream_url( "https://generativelanguage.googleapis.com/v1beta", &BotId::new("models/gemini-2.0-flash"), @@ -985,7 +1021,7 @@ mod tests { } #[test] - fn stream_url_keeps_qualified_resource_path() { + fn stream_url_qualified_path() { let url = build_stream_url( "https://generativelanguage.googleapis.com/v1beta", &BotId::new("tunedModels/my-tuned-model"), @@ -997,7 +1033,7 @@ mod tests { } #[test] - fn build_generate_request_maps_system_user_and_model_roles() { + fn request_maps_roles() { let messages = vec![ Message { from: EntityId::System, @@ -1052,7 +1088,7 @@ mod tests { } #[test] - fn parse_models_response_maps_capabilities_from_generation_methods() { + fn models_filter_by_capability() { let payload = r#" { "models": [ @@ -1076,7 +1112,7 @@ mod tests { } #[test] - fn parse_stream_text_collects_all_candidate_parts() { + fn stream_text_merges_parts() { let payload = r#" { "candidates": [ @@ -1089,7 +1125,7 @@ mod tests { } #[test] - fn build_generate_request_includes_tool_declarations() { + fn request_includes_tools() { let messages = vec![Message { from: EntityId::User, content: MessageContent { @@ -1137,7 +1173,7 @@ mod tests { } #[test] - fn build_generate_request_maps_tool_results_to_function_response_parts() { + fn request_maps_tool_results() { let tool_call_id = "call-1".to_string(); let messages = vec![ Message { @@ -1188,19 +1224,68 @@ mod tests { } #[test] - fn build_generate_request_sets_tool_mode_auto_after_tool_results() { - let messages = vec![Message { - from: EntityId::Tool, - content: MessageContent { - tool_results: vec![ToolResult { - tool_call_id: "call-1".to_string(), - content: r#"{"ok":true}"#.to_string(), - is_error: false, - }], + fn request_rejects_unknown_tool_id() { + let messages = vec![ + Message { + from: EntityId::User, + content: MessageContent { + text: "Use the tool".to_string(), + ..Default::default() + }, ..Default::default() }, - ..Default::default() - }]; + Message { + from: EntityId::Tool, + content: MessageContent { + tool_results: vec![ToolResult { + tool_call_id: "missing-call".to_string(), + content: r#"{"ok":true}"#.to_string(), + is_error: false, + }], + ..Default::default() + }, + ..Default::default() + }, + ]; + + let error = build_generate_request(&messages, &[]) + .expect_err("unknown tool result ids should fail request building"); + assert_eq!(error.kind(), ClientErrorKind::Format); + assert!( + error.to_string().contains("missing-call"), + "error should mention the missing tool call id" + ); + } + + #[test] + fn request_uses_auto_mode() { + let messages = vec![ + Message { + from: EntityId::Bot(BotId::new("gemini-2.0-flash")), + content: MessageContent { + tool_calls: vec![ToolCall { + id: "call-1".to_string(), + name: "get_weather".to_string(), + arguments: serde_json::Map::new(), + ..Default::default() + }], + ..Default::default() + }, + ..Default::default() + }, + Message { + from: EntityId::Tool, + content: MessageContent { + tool_results: vec![ToolResult { + tool_call_id: "call-1".to_string(), + content: r#"{"ok":true}"#.to_string(), + is_error: false, + }], + ..Default::default() + }, + ..Default::default() + }, + ]; let tools = vec![Tool { name: "get_weather".to_string(), @@ -1216,7 +1301,7 @@ mod tests { } #[test] - fn parse_stream_delta_extracts_text_and_function_calls() { + fn delta_extracts_text_and_calls() { let payload = r#" { "candidates": [ @@ -1244,7 +1329,7 @@ mod tests { } #[test] - fn parse_stream_delta_extracts_protocol_function_call_id() { + fn delta_keeps_protocol_id() { let payload = r#" { "candidates": [ @@ -1273,7 +1358,7 @@ mod tests { } #[test] - fn stream_tool_call_state_preserves_distinct_calls_across_chunk_index_restarts() { + fn tool_calls_distinct_across_chunks() { let mut state = GeminiStreamToolCallState::default(); state.apply_delta(vec![GeminiFunctionCallDelta { @@ -1298,7 +1383,7 @@ mod tests { } #[test] - fn stream_tool_call_state_prefers_protocol_function_call_id() { + fn tool_calls_prefer_protocol_id() { let payload = r#" { "candidates": [ @@ -1330,7 +1415,97 @@ mod tests { } #[test] - fn as_bot_parts_includes_thought_signature_from_data() { + fn tool_calls_upgrade_fallback_id() { + let mut state = GeminiStreamToolCallState::default(); + + state.apply_delta(vec![GeminiFunctionCallDelta { + id: None, + name: "get_weather".to_string(), + args: serde_json::json!({"city":"Tokyo"}), + thought_signature: Some("sig-local".to_string()), + }]); + + state.apply_delta(vec![GeminiFunctionCallDelta { + id: Some("protocol-call-42".to_string()), + name: "get_weather".to_string(), + args: serde_json::json!({"city":"Tokyo"}), + thought_signature: Some("sig-protocol".to_string()), + }]); + + let calls = state.tool_calls(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].id, "protocol-call-42"); + assert_eq!(calls[0].name, "get_weather"); + + let signatures = + parse_tool_call_thought_signatures(state.encoded_thought_signatures().as_deref()); + assert_eq!(signatures.len(), 1); + assert_eq!( + signatures.get("protocol-call-42").map(String::as_str), + Some("sig-protocol") + ); + } + + #[test] + fn tool_calls_survive_index_shift() { + let mut state = GeminiStreamToolCallState::default(); + + // Chunk 1: two calls without protocol ids at index 0 and 1. + state.apply_delta(vec![ + GeminiFunctionCallDelta { + id: None, + name: "call_a".to_string(), + args: serde_json::json!({"x": 1}), + thought_signature: None, + }, + GeminiFunctionCallDelta { + id: None, + name: "call_b".to_string(), + args: serde_json::json!({"y": 2}), + thought_signature: None, + }, + ]); + + let calls = state.tool_calls(); + assert_eq!(calls.len(), 2); + let a_fallback_id = calls[0].id.clone(); + let b_fallback_id = calls[1].id.clone(); + assert_eq!(calls[0].name, "call_a"); + assert_eq!(calls[1].name, "call_b"); + + // Chunk 2: only call_b is resent, but now at index 0 with a protocol id. + // The old index 0 held call_a — signatures differ, so call_a must NOT + // be promoted to call_b's protocol id. + state.apply_delta(vec![GeminiFunctionCallDelta { + id: Some("proto-b".to_string()), + name: "call_b".to_string(), + args: serde_json::json!({"y": 2}), + thought_signature: None, + }]); + + let calls = state.tool_calls(); + assert_eq!(calls.len(), 2, "both calls must survive"); + + let a = calls + .iter() + .find(|c| c.name == "call_a") + .expect("call_a missing"); + let b = calls + .iter() + .find(|c| c.name == "call_b") + .expect("call_b missing"); + + assert_eq!( + a.id, a_fallback_id, + "call_a must keep its original fallback id" + ); + assert_ne!(a.id, "proto-b", "call_a must NOT get call_b's protocol id"); + assert_eq!(b.id, "proto-b", "call_b must use its protocol id"); + assert_ne!(b.id, b_fallback_id, "call_b should have been upgraded"); + } + + #[test] + fn bot_parts_roundtrip_thought_signature() { let message = Message { from: EntityId::Bot(BotId::new("gemini-3-flash-preview")), content: MessageContent { From 3e0378b408ae9455da772a9ff5469c7d461270d8 Mon Sep 17 00:00:00 2001 From: Alvin Date: Sun, 15 Mar 2026 00:59:36 +0800 Subject: [PATCH 16/16] fix(gemini): trim redundant tests and revert unrelated cosmetic changes --- src/clients/gemini.rs | 244 +++++++++++----------------------------- src/clients/openai.rs | 2 +- src/controllers/chat.rs | 10 +- src/prelude.rs | 4 +- 4 files changed, 75 insertions(+), 185 deletions(-) diff --git a/src/clients/gemini.rs b/src/clients/gemini.rs index 903436e..7dd9d44 100644 --- a/src/clients/gemini.rs +++ b/src/clients/gemini.rs @@ -550,11 +550,6 @@ fn parse_stream_delta(payload: &str) -> Result { Ok(delta) } -#[cfg(test)] -fn parse_stream_text(payload: &str) -> Result { - Ok(parse_stream_delta(payload)?.text) -} - fn function_call_args_to_map(args: Value) -> Map { match args { Value::Object(args) => args, @@ -632,14 +627,11 @@ impl GeminiStreamToolCallState { for (stream_index, function_call) in function_calls.into_iter().enumerate() { let signature = stream_tool_call_signature(&function_call.name, &function_call.args); - // Use protocol id when available; otherwise fall back to a local id - // keyed by (stream_index, signature). - // - // When a protocol id arrives for a call that was previously tracked - // by a fallback id, we promote the old id: first check the same - // stream_index (best match), then search all slots by signature - // (handles index shifts between chunks). Identical-signature calls - // without protocol ids are inherently ambiguous. + // Promotion policy for protocol ids: + // 1. Same stream_index + same signature → promote (strongest match) + // 2. Global signature search → promote only if exactly one candidate + // 3. Multiple candidates → don't promote (ambiguous, prefer duplicate + // over wrong-assignment) let call_id = if let Some(protocol_id) = function_call.id.clone() { let matching_fallback = self .by_stream_index @@ -647,12 +639,19 @@ impl GeminiStreamToolCallState { .filter(|slot| slot.signature == signature) .map(|slot| (stream_index, slot.id.clone())) .or_else(|| { - self.by_stream_index + let candidates: Vec<_> = self + .by_stream_index .iter() - .find(|&(&idx, slot)| { + .filter(|&(&idx, slot)| { idx != stream_index && slot.signature == signature }) - .map(|(&idx, slot)| (idx, slot.id.clone())) + .collect(); + if candidates.len() == 1 { + let (&idx, slot) = candidates[0]; + Some((idx, slot.id.clone())) + } else { + None + } }); self.by_stream_index.insert( @@ -985,41 +984,6 @@ mod tests { assert_eq!(bot.name, "Gemini 2.0 Flash"); } - #[test] - fn models_url_keeps_query() { - let url = build_models_url( - "https://generativelanguage.googleapis.com/v1beta?alt=sse", - None, - ) - .expect("failed to build models url"); - - assert!(url.contains("/models?")); - assert!(url.contains("alt=sse")); - } - - #[test] - fn models_url_pagination() { - let url = build_models_url( - "https://generativelanguage.googleapis.com/v1beta", - Some("abc123"), - ) - .expect("failed to build models url"); - - assert!(url.contains("pageToken=abc123")); - } - - #[test] - fn stream_url_format() { - let url = build_stream_url( - "https://generativelanguage.googleapis.com/v1beta", - &BotId::new("models/gemini-2.0-flash"), - ) - .expect("failed to build stream url"); - - assert!(url.contains("/models/gemini-2.0-flash:streamGenerateContent")); - assert!(url.contains("alt=sse")); - } - #[test] fn stream_url_qualified_path() { let url = build_stream_url( @@ -1087,43 +1051,6 @@ mod tests { ); } - #[test] - fn models_filter_by_capability() { - let payload = r#" - { - "models": [ - { - "name": "models/gemini-2.0-flash", - "supportedGenerationMethods": ["generateContent"] - }, - { - "name": "models/text-embedding-004", - "supportedGenerationMethods": ["embedContent"] - } - ] - }"#; - - let bots = parse_models_response(payload).expect("failed to parse"); - assert_eq!(bots.len(), 1, "embedding model should be filtered out"); - - let bot = &bots[0]; - assert!(bot.capabilities.has_capability(&BotCapability::TextInput)); - assert!(bot.capabilities.has_capability(&BotCapability::ToolInput)); - } - - #[test] - fn stream_text_merges_parts() { - let payload = r#" - { - "candidates": [ - { "content": { "parts": [{"text":"Hello "}, {"text":"Gemini"}] } } - ] - }"#; - - let text = parse_stream_text(payload).expect("failed to parse stream payload"); - assert_eq!(text, "Hello Gemini"); - } - #[test] fn request_includes_tools() { let messages = vec![Message { @@ -1257,49 +1184,6 @@ mod tests { ); } - #[test] - fn request_uses_auto_mode() { - let messages = vec![ - Message { - from: EntityId::Bot(BotId::new("gemini-2.0-flash")), - content: MessageContent { - tool_calls: vec![ToolCall { - id: "call-1".to_string(), - name: "get_weather".to_string(), - arguments: serde_json::Map::new(), - ..Default::default() - }], - ..Default::default() - }, - ..Default::default() - }, - Message { - from: EntityId::Tool, - content: MessageContent { - tool_results: vec![ToolResult { - tool_call_id: "call-1".to_string(), - content: r#"{"ok":true}"#.to_string(), - is_error: false, - }], - ..Default::default() - }, - ..Default::default() - }, - ]; - - let tools = vec![Tool { - name: "get_weather".to_string(), - description: Some("Get weather".to_string()), - input_schema: std::sync::Arc::new( - serde_json::from_str(r#"{"type":"object"}"#).expect("invalid schema"), - ), - }]; - - let request = build_generate_request(&messages, &tools).expect("failed to build request"); - let value = serde_json::to_value(request).expect("failed to serialize request"); - assert_eq!(value["toolConfig"]["functionCallingConfig"]["mode"], "AUTO"); - } - #[test] fn delta_extracts_text_and_calls() { let payload = r#" @@ -1328,35 +1212,6 @@ mod tests { ); } - #[test] - fn delta_keeps_protocol_id() { - let payload = r#" - { - "candidates": [ - { - "content": { - "parts": [ - { - "functionCall": { - "id": "protocol-call-42", - "name": "get_weather", - "args": {"city":"Tokyo"} - } - } - ] - } - } - ] - }"#; - - let delta = parse_stream_delta(payload).expect("failed to parse stream payload"); - assert_eq!(delta.function_calls.len(), 1); - assert_eq!( - delta.function_calls[0].id.as_deref(), - Some("protocol-call-42") - ); - } - #[test] fn tool_calls_distinct_across_chunks() { let mut state = GeminiStreamToolCallState::default(); @@ -1505,27 +1360,56 @@ mod tests { } #[test] - fn bot_parts_roundtrip_thought_signature() { - let message = Message { - from: EntityId::Bot(BotId::new("gemini-3-flash-preview")), - content: MessageContent { - tool_calls: vec![ToolCall { - id: "call-1".to_string(), - name: "get_weather".to_string(), - arguments: serde_json::from_str(r#"{"location":"Montevideo"}"#) - .expect("invalid args"), - ..Default::default() - }], - data: Some( - r#"{"gemini_tool_call_thought_signatures":{"call-1":"sig-abc"}}"#.to_string(), - ), - ..Default::default() + fn tool_calls_no_promote_on_ambiguous_signature() { + let mut state = GeminiStreamToolCallState::default(); + + // Two calls with identical signature but different fallback ids. + state.apply_delta(vec![ + GeminiFunctionCallDelta { + id: None, + name: "do_thing".to_string(), + args: serde_json::json!({"x": 1}), + thought_signature: None, }, - ..Default::default() - }; + GeminiFunctionCallDelta { + id: None, + name: "do_thing".to_string(), + args: serde_json::json!({"x": 1}), + thought_signature: None, + }, + ]); + assert_eq!(state.tool_calls().len(), 2); - let parts = as_bot_parts(&message); - let value = serde_json::to_value(parts).expect("failed to serialize parts"); - assert_eq!(value[0]["thoughtSignature"], "sig-abc"); + // Protocol id arrives at a NEW index — two candidates match by signature, + // so promotion must be skipped (ambiguous). + state.apply_delta(vec![ + GeminiFunctionCallDelta { + id: None, + name: "other".to_string(), + args: serde_json::json!({}), + thought_signature: None, + }, + GeminiFunctionCallDelta { + id: None, + name: "other".to_string(), + args: serde_json::json!({}), + thought_signature: None, + }, + GeminiFunctionCallDelta { + id: Some("proto-1".to_string()), + name: "do_thing".to_string(), + args: serde_json::json!({"x": 1}), + thought_signature: None, + }, + ]); + + let calls = state.tool_calls(); + // Both original fallback calls must survive untouched. + let do_things: Vec<_> = calls.iter().filter(|c| c.name == "do_thing").collect(); + assert!(do_things.len() >= 2, "original fallback calls must not be consumed"); + assert!( + do_things.iter().all(|c| c.id != "proto-1" || c.name == "do_thing"), + "no fallback call should have been wrongly renamed" + ); } } diff --git a/src/clients/openai.rs b/src/clients/openai.rs index aaf3178..be0f4ff 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -8,9 +8,9 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::protocol::*; use crate::utils::asynchronous::{BoxPlatformSendFuture, BoxPlatformSendStream}; use crate::utils::{serde::deserialize_null_default, sse::parse_sse}; +use crate::protocol::*; /// The content of a [`ContentPart::ImageUrl`]. #[derive(Serialize, Deserialize, Debug, Clone)] diff --git a/src/controllers/chat.rs b/src/controllers/chat.rs index 98d369d..ea8632d 100644 --- a/src/controllers/chat.rs +++ b/src/controllers/chat.rs @@ -448,7 +448,10 @@ impl ChatController { c.dispatch_mutation(VecMutation::Set(bots.unwrap_or_default())); - let messages: Vec<_> = errors.into_iter().map(Message::from_client_error).collect(); + let messages: Vec<_> = errors + .into_iter() + .map(Message::from_client_error) + .collect(); c.dispatch_mutation(VecMutation::Extend(messages)); }); })); @@ -506,7 +509,10 @@ impl ChatController { false } Err(errors) => { - let messages: Vec<_> = errors.into_iter().map(Message::from_client_error).collect(); + let messages: Vec<_> = errors + .into_iter() + .map(Message::from_client_error) + .collect(); self.dispatch_mutation(VecMutation::Extend(messages)); true diff --git a/src/prelude.rs b/src/prelude.rs index be94bd8..4cbc028 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -5,9 +5,9 @@ pub use crate::protocol::*; // These are the clients that are most commonly used. #[cfg(feature = "api-clients")] -pub use crate::clients::gemini::GeminiClient; -#[cfg(feature = "api-clients")] pub use crate::clients::openai::OpenAiClient; +#[cfg(feature = "api-clients")] +pub use crate::clients::gemini::GeminiClient; pub use crate::clients::router::RouterClient; // These other clients are less commonly used.