Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self) -> None:
class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
_azure_endpoint: httpx.URL | None
_azure_deployment: str | None
_is_v1_api: bool

@override
def _build_request(
Expand All @@ -60,10 +61,12 @@ def _build_request(
*,
retries_taken: int = 0,
) -> httpx.Request:
if options.url in _deployments_endpoints and is_mapping(options.json_data):
model = options.json_data.get("model")
if model is not None and "/deployments" not in str(self.base_url.path):
options.url = f"/deployments/{model}{options.url}"
# v1 API doesn't use /deployments/{model}/ path - model is passed in body
if not getattr(self, '_is_v1_api', False):
if options.url in _deployments_endpoints and is_mapping(options.json_data):
model = options.json_data.get("model")
if model is not None and "/deployments" not in str(self.base_url.path):
options.url = f"/deployments/{model}{options.url}"

return super()._build_request(options, retries_taken=retries_taken)

Expand All @@ -73,6 +76,10 @@ def _prepare_url(self, url: str) -> httpx.URL:
and the API feature being called is **not** a deployments-based endpoint
(i.e. requires /deployments/deployment-name in the URL path).
"""
# v1 API doesn't need URL rewriting - base_url already has /openai/v1/
if getattr(self, '_is_v1_api', False):
return super()._prepare_url(url)

if self._azure_deployment and self._azure_endpoint and url not in _deployments_endpoints:
merge_url = httpx.URL(url)
if merge_url.is_relative_url:
Expand Down Expand Up @@ -208,6 +215,9 @@ def __init__(
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
)

# Check if using v1 API format (new Azure OpenAI API)
_is_v1_api = api_version in ("v1", "latest", "preview")

if default_query is None:
default_query = {"api-version": api_version}
else:
Expand All @@ -222,7 +232,10 @@ def __init__(
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
)

if azure_deployment is not None:
if _is_v1_api:
# v1 API uses /openai/v1/ path without /deployments/
base_url = f"{azure_endpoint.rstrip('/')}/openai/v1"
elif azure_deployment is not None:
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
else:
base_url = f"{azure_endpoint.rstrip('/')}/openai"
Expand Down Expand Up @@ -253,6 +266,7 @@ def __init__(
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
self._is_v1_api = _is_v1_api

@override
def copy(
Expand Down Expand Up @@ -489,6 +503,9 @@ def __init__(
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
)

# Check if using v1 API format (new Azure OpenAI API)
_is_v1_api = api_version in ("v1", "latest", "preview")

if default_query is None:
default_query = {"api-version": api_version}
else:
Expand All @@ -503,7 +520,10 @@ def __init__(
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
)

if azure_deployment is not None:
if _is_v1_api:
# v1 API uses /openai/v1/ path without /deployments/
base_url = f"{azure_endpoint.rstrip('/')}/openai/v1"
elif azure_deployment is not None:
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
else:
base_url = f"{azure_endpoint.rstrip('/')}/openai"
Expand Down Expand Up @@ -534,6 +554,7 @@ def __init__(
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
self._is_v1_api = _is_v1_api

@override
def copy(
Expand Down
106 changes: 106 additions & 0 deletions tests/lib/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,3 +802,109 @@ def test_client_sets_base_url(client: Client) -> None:
)
)
assert req.url == "https://example-resource.azure.openai.com/openai/models?api-version=2024-02-01"


# Tests for v1 API support
class TestAzureV1API:
"""Tests for Azure OpenAI v1/latest/preview API support."""

@pytest.mark.parametrize("api_version", ["v1", "latest", "preview"])
@pytest.mark.parametrize("client_cls", [AzureOpenAI, AsyncAzureOpenAI])
def test_v1_api_base_url(self, api_version: str, client_cls: type[Client]) -> None:
"""v1 API should use /openai/v1/ base URL."""
client = client_cls(
api_version=api_version,
api_key="test",
azure_endpoint="https://example.azure.openai.com",
)
assert "/openai/v1" in str(client.base_url)
assert "/deployments/" not in str(client.base_url)

@pytest.mark.parametrize("api_version", ["v1", "latest", "preview"])
@pytest.mark.parametrize("client_cls", [AzureOpenAI, AsyncAzureOpenAI])
def test_v1_api_no_deployments_path(self, api_version: str, client_cls: type[Client]) -> None:
"""v1 API should NOT add /deployments/{model}/ to the path."""
client = client_cls(
api_version=api_version,
api_key="test",
azure_endpoint="https://example.azure.openai.com",
)
req = client._build_request(
FinalRequestOptions.construct(
method="post",
url="/chat/completions",
json_data={"model": "gpt-4o"},
)
)
assert "/deployments/" not in str(req.url)
assert "/openai/v1/chat/completions" in str(req.url)

@pytest.mark.parametrize("api_version", ["v1", "latest", "preview"])
@pytest.mark.parametrize("client_cls", [AzureOpenAI, AsyncAzureOpenAI])
def test_v1_api_has_query_param(self, api_version: str, client_cls: type[Client]) -> None:
"""v1 API should still include ?api-version= query param."""
client = client_cls(
api_version=api_version,
api_key="test",
azure_endpoint="https://example.azure.openai.com",
)
req = client._build_request(
FinalRequestOptions.construct(
method="post",
url="/chat/completions",
json_data={"model": "gpt-4o"},
)
)
assert f"api-version={api_version}" in str(req.url)

@pytest.mark.parametrize("client_cls", [AzureOpenAI, AsyncAzureOpenAI])
def test_traditional_api_still_works(self, client_cls: type[Client]) -> None:
"""Traditional API should still use /deployments/ path."""
client = client_cls(
api_version="2024-10-21",
api_key="test",
azure_endpoint="https://example.azure.openai.com",
)
req = client._build_request(
FinalRequestOptions.construct(
method="post",
url="/chat/completions",
json_data={"model": "gpt-4o"},
)
)
assert "/deployments/gpt-4o/" in str(req.url)
assert "api-version=2024-10-21" in str(req.url)

@pytest.mark.parametrize("api_version", ["v1", "latest", "preview"])
def test_v1_api_ignores_azure_deployment_param(self, api_version: str) -> None:
"""v1 API should ignore azure_deployment parameter since model is in body."""
client = AzureOpenAI(
api_version=api_version,
api_key="test",
azure_endpoint="https://example.azure.openai.com",
azure_deployment="ignored-deployment",
)
# base_url should still be /openai/v1, not /openai/deployments/ignored-deployment
assert "/openai/v1" in str(client.base_url)
assert "/deployments/" not in str(client.base_url)

@pytest.mark.parametrize("api_version", ["v1", "latest", "preview"])
@pytest.mark.parametrize("client_cls", [AzureOpenAI, AsyncAzureOpenAI])
def test_v1_api_non_deployment_endpoints_keep_v1_path(self, api_version: str, client_cls: type[Client]) -> None:
"""v1 API should keep /v1/ path for non-deployment endpoints like /responses."""
client = client_cls(
api_version=api_version,
api_key="test",
azure_endpoint="https://example.azure.openai.com",
azure_deployment="some-deployment", # Even with deployment param
)
req = client._build_request(
FinalRequestOptions.construct(
method="post",
url="/responses",
json_data={"model": "gpt-4o", "input": "hi"},
)
)
# Should be /openai/v1/responses, NOT /openai/responses
assert "/openai/v1/responses" in str(req.url)
assert "/deployments/" not in str(req.url)