use crate::prelude::*;

use super::LlmEmbeddingClient;
use super::LlmGenerationClient;
use async_openai::{Client as OpenAIClient, config::AzureConfig};
use phf::phf_map;
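/// Default output dimensions of the OpenAI embedding models commonly deployed
/// on Azure, keyed by model name.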
static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
    "text-embedding-3-small" => 1536,
    "text-embedding-3-large" => 3072,
    "text-embedding-ada-002" => 1536,
};
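/// Azure OpenAI client, wrapping an `async_openai` client configured with
/// [`AzureConfig`]. Implements both generation and embedding.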
pub struct Client {
    client: OpenAIClient<AzureConfig>,
}

impl Client {
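    /// Creates a client for an Azure OpenAI endpoint. The API key falls back
    /// to the `AZURE_OPENAI_API_KEY` environment variable when not passed in.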
    pub async fn new_azure_openai(
        address: Option<String>,
        api_key: Option<String>,
        api_config: Option<super::LlmApiConfig>,
    ) -> anyhow::Result<Self> {
        let config = match api_config {
            Some(super::LlmApiConfig::AzureOpenAi(config)) => config,
            Some(_) => anyhow::bail!("unexpected config type, expected AzureOpenAiConfig"),
            None => anyhow::bail!("AzureOpenAiConfig is required for Azure OpenAI"),
        };

        let api_base =
            address.ok_or_else(|| anyhow::anyhow!("address is required for Azure OpenAI"))?;

        // Fall back to a stable GA API version when none is specified.
        let api_version = config
            .api_version
            .unwrap_or_else(|| "2024-02-01".to_string());

        let api_key = api_key
            .or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "API key must be provided via the `api_key` parameter or the `AZURE_OPENAI_API_KEY` environment variable"
                )
            })?;
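        // async_openai routes Azure requests to
        // `{api_base}/openai/deployments/{deployment_id}`, passing the API
        // version as an `api-version` query parameter.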
        let azure_config = AzureConfig::new()
            .with_api_base(api_base)
            .with_api_version(api_version)
            .with_deployment_id(config.deployment_id)
            .with_api_key(api_key);

        Ok(Self {
            client: OpenAIClient::with_config(azure_config),
        })
    }
}
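// A minimal usage sketch (hypothetical endpoint and deployment names; the
// exact `AzureOpenAiConfig` field set is assumed from the constructor above):
//
//     let client = Client::new_azure_openai(
//         Some("https://my-resource.openai.azure.com".to_string()),
//         None, // falls back to AZURE_OPENAI_API_KEY
//         Some(super::LlmApiConfig::AzureOpenAi(super::AzureOpenAiConfig {
//             deployment_id: "my-gpt-4o-deployment".to_string(),
//             api_version: None, // defaults to "2024-02-01"
//         })),
//     )
//     .await?;
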
#[async_trait]
impl LlmGenerationClient for Client {
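    /// Issues a chat-completion request, retrying transient failures, and
    /// returns the content of the first choice.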
    async fn generate<'req>(
        &self,
        request: super::LlmGenerateRequest<'req>,
    ) -> Result<super::LlmGenerateResponse> {
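        // Re-borrow so the retry closure below can capture the request by
        // shared reference and be called more than once.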
        let request = &request;
        let response = retryable::run(
            || async {
                let req = super::openai::create_llm_generation_request(request)?;
                let response = self.client.chat().create(req).await?;
                retryable::Ok(response)
            },
            &retryable::RetryOptions::default(),
        )
        .await?;

        // Extract the response text from the first choice.
        let text = response
            .choices
            .into_iter()
            .next()
            .and_then(|choice| choice.message.content)
            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;

        Ok(super::LlmGenerateResponse { text })
    }
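    /// Schema constraints for structured output on Azure OpenAI: every field
    /// is marked required, `format` is not supported, and the top-level
    /// schema must be an object.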
    fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
        super::ToJsonSchemaOptions {
            fields_always_required: true,
            supports_format: false,
            extract_descriptions: false,
            top_level_must_be_object: true,
            supports_additional_properties: true,
        }
    }
}
#[async_trait]
impl LlmEmbeddingClient for Client {
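    /// Embeds a batch of texts, retrying transient failures.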
    async fn embed_text<'req>(
        &self,
        request: super::LlmEmbeddingRequest<'req>,
    ) -> Result<super::LlmEmbeddingResponse> {
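        // The request is rebuilt inside the closure because `create` consumes
        // it, and each retry attempt needs a fresh value.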
        let response = retryable::run(
            || async {
                let texts: Vec<String> = request.texts.iter().map(|t| t.to_string()).collect();
                self.client
                    .embeddings()
                    .create(async_openai::types::CreateEmbeddingRequest {
                        model: request.model.to_string(),
                        input: async_openai::types::EmbeddingInput::StringArray(texts),
                        dimensions: request.output_dimension,
                        ..Default::default()
                    })
                    .await
            },
            &retryable::RetryOptions::default(),
        )
        .await?;
        Ok(super::LlmEmbeddingResponse {
            embeddings: response.data.into_iter().map(|e| e.embedding).collect(),
        })
    }
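    /// Looks up the default output dimension for a known embedding model.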
    fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
        DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied()
    }
}