Commit b91c846

initial commit add support for azureopenai and pass precommit

1 parent a42f711

5 files changed: +214 additions, −2 deletions

docs/docs/ai/llm.mdx

Lines changed: 62 additions & 0 deletions

@@ -20,6 +20,7 @@ We support the following types of LLM APIs:
 | API Name | `LlmApiType` enum | Text Generation | Text Embedding |
 |----------|---------------------|--------------------|--------------------|
 | [OpenAI](#openai) | `LlmApiType.OPENAI` | ✅ | ✅ |
+| [Azure OpenAI](#azure-openai) | `LlmApiType.AZURE_OPENAI` | ✅ | ✅ |
 | [Ollama](#ollama) | `LlmApiType.OLLAMA` | ✅ | ✅ |
 | [Google Gemini](#google-gemini) | `LlmApiType.GEMINI` | ✅ | ✅ |
 | [Vertex AI](#vertex-ai) | `LlmApiType.VERTEX_AI` | ✅ | ✅ |
@@ -116,6 +117,67 @@ cocoindex.functions.EmbedText(
 </TabItem>
 </Tabs>
 
+### Azure OpenAI
+
+[Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service) is Microsoft's cloud service offering OpenAI models through Azure.
+
+To use the Azure OpenAI API:
+
+1. Create an Azure account and set up an Azure OpenAI resource in the [Azure Portal](https://portal.azure.com/).
+2. Deploy a model (e.g., GPT-4, text-embedding-ada-002) to your Azure OpenAI resource.
+3. Get your API key from the Azure Portal under your Azure OpenAI resource.
+4. Set the environment variable `AZURE_OPENAI_API_KEY` to your API key.
+
+The spec for Azure OpenAI requires:
+- `address` (type: `str`, required): the base URL of your Azure OpenAI resource, e.g. `https://your-resource-name.openai.azure.com`.
+- `api_config` (type: `cocoindex.llm.AzureOpenAiConfig`, required): a configuration with the following fields:
+  - `deployment_id` (type: `str`, required): the deployment name/ID you created in Azure OpenAI Studio.
+  - `api_version` (type: `str`, optional): the API version to use. Defaults to `2024-02-01` if not specified.
+
+For text generation, a spec for Azure OpenAI looks like this:
+
+<Tabs>
+<TabItem value="python" label="Python" default>
+
+```python
+cocoindex.LlmSpec(
+    api_type=cocoindex.LlmApiType.AZURE_OPENAI,
+    model="gpt-4o",  # the base model name
+    address="https://your-resource-name.openai.azure.com",
+    api_config=cocoindex.llm.AzureOpenAiConfig(
+        deployment_id="your-deployment-name",
+        api_version="2024-02-01",  # optional
+    ),
+)
+```
+
+</TabItem>
+</Tabs>
+
+For text embedding, a spec for Azure OpenAI looks like this:
+
+<Tabs>
+<TabItem value="python" label="Python" default>
+
+```python
+cocoindex.functions.EmbedText(
+    api_type=cocoindex.LlmApiType.AZURE_OPENAI,
+    model="text-embedding-3-small",
+    address="https://your-resource-name.openai.azure.com",
+    output_dimension=1536,  # optional; the model's default output dimension is used if not specified
+    api_config=cocoindex.llm.AzureOpenAiConfig(
+        deployment_id="your-embedding-deployment-name",
+    ),
+)
+```
+
+</TabItem>
+</Tabs>
+
+:::note
+Azure OpenAI uses deployment names instead of direct model names in API calls. The `deployment_id` in the config should match the deployment you created in Azure OpenAI Studio.
+:::
+
 ### Ollama
 
 [Ollama](https://ollama.com/) allows you to run LLM models on your local machine easily. To get started:
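The note above deserves one concrete illustration: with deployment-based routing, the deployment name is part of the request path, and the base model name never appears in it. Below is a minimal standalone Rust sketch of that URL shape; it is our description of Azure's REST layout for intuition only, since `async_openai`'s `AzureConfig` assembles these URLs internally and nothing in this commit builds them by hand:

```rust
// Illustration only: the deployment-scoped URL shape that AzureConfig builds internally.
// The resource name, deployment, and API version below are placeholders.
fn azure_chat_completions_url(api_base: &str, deployment_id: &str, api_version: &str) -> String {
    format!("{api_base}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
}

fn main() {
    let url = azure_chat_completions_url(
        "https://your-resource-name.openai.azure.com",
        "your-deployment-name",
        "2024-02-01",
    );
    // Note: the base model name (e.g. "gpt-4o") appears nowhere in the path.
    println!("{url}");
}
```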

python/cocoindex/llm.py

Lines changed: 12 additions & 1 deletion

@@ -17,6 +17,7 @@ class LlmApiType(Enum):
     VOYAGE = "Voyage"
     VLLM = "Vllm"
     BEDROCK = "Bedrock"
+    AZURE_OPENAI = "AzureOpenAi"


 @dataclass
@@ -39,6 +40,16 @@ class OpenAiConfig:
     project_id: str | None = None


+@dataclass
+class AzureOpenAiConfig:
+    """A specification for an Azure OpenAI LLM."""
+
+    kind = "AzureOpenAi"
+
+    deployment_id: str
+    api_version: str | None = None
+
+
 @dataclass
 class LlmSpec:
     """A specification for a LLM."""
@@ -47,4 +58,4 @@ class LlmSpec:
     model: str
     address: str | None = None
     api_key: TransientAuthEntryReference[str] | None = None
-    api_config: VertexAiConfig | OpenAiConfig | None = None
+    api_config: VertexAiConfig | OpenAiConfig | AzureOpenAiConfig | None = None

rust/cocoindex/src/llm/azureopenai.rs

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+use crate::prelude::*;
+
+use super::LlmEmbeddingClient;
+use super::LlmGenerationClient;
+use async_openai::{Client as OpenAIClient, config::AzureConfig};
+use phf::phf_map;
+
+static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
+    "text-embedding-3-small" => 1536,
+    "text-embedding-3-large" => 3072,
+    "text-embedding-ada-002" => 1536,
+};
+
+pub struct Client {
+    client: async_openai::Client<AzureConfig>,
+}
+
+impl Client {
+    pub async fn new_azure_openai(
+        address: Option<String>,
+        api_key: Option<String>,
+        api_config: Option<super::LlmApiConfig>,
+    ) -> anyhow::Result<Self> {
+        let config = match api_config {
+            Some(super::LlmApiConfig::AzureOpenAi(config)) => config,
+            Some(_) => anyhow::bail!("unexpected config type, expected AzureOpenAiConfig"),
+            None => anyhow::bail!("AzureOpenAiConfig is required for Azure OpenAI"),
+        };
+
+        let api_base =
+            address.ok_or_else(|| anyhow::anyhow!("address is required for Azure OpenAI"))?;
+
+        // Default to latest stable API version if not specified
+        let api_version = config
+            .api_version
+            .unwrap_or_else(|| "2024-02-01".to_string());
+
+        let api_key = api_key.or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
+            .ok_or_else(|| anyhow::anyhow!("AZURE_OPENAI_API_KEY must be set either via api_key parameter or environment variable"))?;
+
+        let azure_config = AzureConfig::new()
+            .with_api_base(api_base)
+            .with_api_version(api_version)
+            .with_deployment_id(config.deployment_id)
+            .with_api_key(api_key);
+
+        Ok(Self {
+            client: OpenAIClient::with_config(azure_config),
+        })
+    }
+}
+
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
+        &self,
+        request: super::LlmGenerateRequest<'req>,
+    ) -> Result<super::LlmGenerateResponse> {
+        let request = &request;
+        let response = retryable::run(
+            || async {
+                let req = super::openai::create_llm_generation_request(request)?;
+                let response = self.client.chat().create(req).await?;
+                retryable::Ok(response)
+            },
+            &retryable::RetryOptions::default(),
+        )
+        .await?;
+
+        // Extract the response text from the first choice
+        let text = response
+            .choices
+            .into_iter()
+            .next()
+            .and_then(|choice| choice.message.content)
+            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
+
+        Ok(super::LlmGenerateResponse { text })
+    }
+
+    fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
+        super::ToJsonSchemaOptions {
+            fields_always_required: true,
+            supports_format: false,
+            extract_descriptions: false,
+            top_level_must_be_object: true,
+            supports_additional_properties: true,
+        }
+    }
+}
+
+#[async_trait]
+impl LlmEmbeddingClient for Client {
+    async fn embed_text<'req>(
+        &self,
+        request: super::LlmEmbeddingRequest<'req>,
+    ) -> Result<super::LlmEmbeddingResponse> {
+        let response = retryable::run(
+            || async {
+                let texts: Vec<String> = request.texts.iter().map(|t| t.to_string()).collect();
+                self.client
+                    .embeddings()
+                    .create(async_openai::types::CreateEmbeddingRequest {
+                        model: request.model.to_string(),
+                        input: async_openai::types::EmbeddingInput::StringArray(texts),
+                        dimensions: request.output_dimension,
+                        ..Default::default()
+                    })
+                    .await
+            },
+            &retryable::RetryOptions::default(),
+        )
+        .await?;
+        Ok(super::LlmEmbeddingResponse {
+            embeddings: response.data.into_iter().map(|e| e.embedding).collect(),
+        })
+    }
+
+    fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
+        DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied()
+    }
+}
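To see the moving parts of this client in isolation, here is a standalone sketch (not part of the commit) that drives `async_openai` with an `AzureConfig` the same way `new_azure_openai` does, then makes one embedding call. The resource URL and deployment name are placeholders, and the `tokio` runtime is an assumed dependency:

```rust
use async_openai::{
    Client,
    config::AzureConfig,
    types::{CreateEmbeddingRequest, EmbeddingInput},
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Mirrors what new_azure_openai builds internally; values are placeholders.
    let config = AzureConfig::new()
        .with_api_base("https://your-resource-name.openai.azure.com")
        .with_api_version("2024-02-01")
        .with_deployment_id("your-embedding-deployment-name")
        .with_api_key(std::env::var("AZURE_OPENAI_API_KEY")?);
    let client = Client::with_config(config);

    // One embedding call against the deployment. Azure routes by deployment,
    // but the request type still carries the base model name.
    let response = client
        .embeddings()
        .create(CreateEmbeddingRequest {
            model: "text-embedding-3-small".to_string(),
            input: EmbeddingInput::StringArray(vec!["hello from cocoindex".to_string()]),
            ..Default::default()
        })
        .await?;
    println!("dimension: {}", response.data[0].embedding.len());
    Ok(())
}
```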

rust/cocoindex/src/llm/mod.rs

Lines changed: 17 additions & 0 deletions

@@ -19,6 +19,7 @@ pub enum LlmApiType {
     Vllm,
     VertexAi,
     Bedrock,
+    AzureOpenAi,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -33,11 +34,18 @@ pub struct OpenAiConfig {
     pub project_id: Option<String>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AzureOpenAiConfig {
+    pub deployment_id: String,
+    pub api_version: Option<String>,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 #[serde(tag = "kind")]
 pub enum LlmApiConfig {
     VertexAi(VertexAiConfig),
     OpenAi(OpenAiConfig),
+    AzureOpenAi(AzureOpenAiConfig),
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
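The `#[serde(tag = "kind")]` attribute on `LlmApiConfig` is what the Python side's `kind = "AzureOpenAi"` class attribute lines up with. A standalone round-trip sketch (mirroring the types above rather than importing them; assumes `serde` and `serde_json` as dependencies) of how such a tagged payload deserializes:

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct AzureOpenAiConfig {
    deployment_id: String,
    api_version: Option<String>,
}

// Internally tagged: the "kind" field selects the variant,
// and the remaining fields fill the wrapped struct.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "kind")]
enum LlmApiConfig {
    AzureOpenAi(AzureOpenAiConfig),
}

fn main() -> serde_json::Result<()> {
    let json =
        r#"{"kind": "AzureOpenAi", "deployment_id": "your-deployment-name", "api_version": "2024-02-01"}"#;
    let config: LlmApiConfig = serde_json::from_str(json)?;
    // Prints: AzureOpenAi(AzureOpenAiConfig { deployment_id: "your-deployment-name", api_version: Some("2024-02-01") })
    println!("{config:?}");
    Ok(())
}
```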
@@ -108,6 +116,7 @@ pub trait LlmEmbeddingClient: Send + Sync {
 }
 
 mod anthropic;
+mod azureopenai;
 mod bedrock;
 mod gemini;
 mod litellm;
@@ -147,6 +156,10 @@ pub async fn new_llm_generation_client(
             Box::new(openrouter::Client::new_openrouter(address, api_key).await?)
                 as Box<dyn LlmGenerationClient>
         }
+        LlmApiType::AzureOpenAi => {
+            Box::new(azureopenai::Client::new_azure_openai(address, api_key, api_config).await?)
+                as Box<dyn LlmGenerationClient>
+        }
         LlmApiType::Voyage => {
             api_bail!("Voyage is not supported for generation")
         }
@@ -182,6 +195,10 @@ pub async fn new_llm_embedding_client(
             Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?)
                 as Box<dyn LlmEmbeddingClient>
         }
+        LlmApiType::AzureOpenAi => {
+            Box::new(azureopenai::Client::new_azure_openai(address, api_key, api_config).await?)
+                as Box<dyn LlmEmbeddingClient>
+        }
         LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic | LlmApiType::Bedrock => {
             api_bail!("Embedding is not supported for API type {:?}", api_type)
         }

rust/cocoindex/src/llm/openai.rs

Lines changed: 1 addition & 1 deletion

@@ -67,7 +67,7 @@ impl Client {
     }
 }
 
-fn create_llm_generation_request(
+pub(super) fn create_llm_generation_request(
    request: &super::LlmGenerateRequest,
 ) -> Result<CreateChatCompletionRequest> {
     let mut messages = Vec::new();

The widened visibility is what lets the new `azureopenai` module reuse OpenAI's request builder via `super::openai::create_llm_generation_request`.
