diff --git a/litellm/litellm_core_utils/get_provider_specific_headers.py b/litellm/litellm_core_utils/get_provider_specific_headers.py
index cf9165cfda92..69a7ec720735 100644
--- a/litellm/litellm_core_utils/get_provider_specific_headers.py
+++ b/litellm/litellm_core_utils/get_provider_specific_headers.py
@@ -10,14 +10,20 @@ def get_provider_specific_headers(
     custom_llm_provider: Optional[str],
 ) -> Dict:
     """
-    Get the provider specific headers for the given custom llm provider
+    Get the provider specific headers for the given custom llm provider.
+
+    Supports comma-separated provider lists for headers that work across multiple providers.
 
     Returns:
-        Optional[Dict]: The provider specific headers for the given custom llm provider
+        Dict: The provider specific headers for the given custom llm provider
     """
-    if (
-        provider_specific_header is not None
-        and provider_specific_header.get("custom_llm_provider") == custom_llm_provider
-    ):
+    if provider_specific_header is None or custom_llm_provider is None:
+        return {}
+
+    stored_providers = provider_specific_header.get("custom_llm_provider", "")
+    provider_list = [p.strip() for p in stored_providers.split(",")]
+
+    if custom_llm_provider in provider_list:
         return provider_specific_header.get("extra_headers", {})
-    return {}
\ No newline at end of file
+
+    return {}
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 6bdc0e55c613..8f745f244633 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -1252,8 +1252,10 @@ def add_provider_specific_headers_to_request(
             added_header = True
 
     if added_header is True:
+        # Anthropic headers work across multiple providers
+        # Store as comma-separated list so retrieval can match any of them
         data["provider_specific_header"] = ProviderSpecificHeader(
-            custom_llm_provider="anthropic",
+            custom_llm_provider="anthropic,bedrock,bedrock_converse,vertex_ai",
             extra_headers=anthropic_headers,
         )
 
diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py
index 34a8a9daf86c..eebf306d92bd 100644
--- a/tests/proxy_unit_tests/test_proxy_utils.py
+++ b/tests/proxy_unit_tests/test_proxy_utils.py
@@ -1,19 +1,21 @@
 import asyncio
+import json
 import os
 import sys
-from typing import Any, Dict, Optional, List
+from typing import Any, Dict, List, Optional
 from unittest.mock import Mock
-from litellm.proxy.utils import _get_redoc_url, _get_docs_url
-import json
+
 import pytest
 from fastapi import Request
 
+from litellm.proxy.utils import _get_docs_url, _get_redoc_url
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import litellm
-from unittest.mock import MagicMock, patch, AsyncMock
+from unittest.mock import AsyncMock, MagicMock, patch
 
+import litellm
 from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
 from litellm.proxy.auth.auth_utils import is_request_body_safe
 from litellm.proxy.litellm_pre_call_utils import (
@@ -490,8 +492,9 @@ def test_add_litellm_data_for_backend_llm_call(
     headers, general_settings, expected_data
 ):
     import json
-    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
+
     from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
 
     user_api_key_dict = UserAPIKeyAuth(
         api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
     )
@@ -510,8 +513,8 @@ def test_foward_litellm_user_info_to_backend_llm_call():
     litellm.add_user_information_to_llm_headers = True
 
-    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
     from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
 
     user_api_key_dict = UserAPIKeyAuth(
         api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
@@ -533,10 +536,10 @@ def test_foward_litellm_user_info_to_backend_llm_call():
 
 
 def test_update_internal_user_params():
+    from litellm.proxy._types import NewUserRequest
     from litellm.proxy.management_endpoints.internal_user_endpoints import (
         _update_internal_new_user_params,
     )
-    from litellm.proxy._types import NewUserRequest
 
     litellm.default_internal_user_params = {
         "max_budget": 100,
@@ -559,10 +562,10 @@ def test_update_internal_user_params():
 
 
 def test_update_internal_new_user_params_with_no_initial_role_set():
+    from litellm.proxy._types import NewUserRequest
     from litellm.proxy.management_endpoints.internal_user_endpoints import (
         _update_internal_new_user_params,
     )
-    from litellm.proxy._types import NewUserRequest
 
     litellm.default_internal_user_params = {
         "max_budget": 100,
@@ -585,10 +588,10 @@ def test_update_internal_new_user_params_with_no_initial_role_set():
 
 
 def test_update_internal_new_user_params_with_user_defined_values():
+    from litellm.proxy._types import NewUserRequest
     from litellm.proxy.management_endpoints.internal_user_endpoints import (
         _update_internal_new_user_params,
     )
-    from litellm.proxy._types import NewUserRequest
 
     litellm.default_internal_user_params = {
         "max_budget": 100,
@@ -610,9 +613,10 @@ def test_update_internal_new_user_params_with_user_defined_values():
 
 @pytest.mark.asyncio
 async def test_proxy_config_update_from_db():
-    from litellm.proxy.proxy_server import ProxyConfig
     from pydantic import BaseModel
 
+    from litellm.proxy.proxy_server import ProxyConfig
+
     proxy_config = ProxyConfig()
 
     pc = AsyncMock()
@@ -655,10 +659,10 @@ class ReturnValue(BaseModel):
 
 @pytest.mark.asyncio
 async def test_prepare_key_update_data():
+    from litellm.proxy._types import UpdateKeyRequest
     from litellm.proxy.management_endpoints.key_management_endpoints import (
         prepare_key_update_data,
     )
-    from litellm.proxy._types import UpdateKeyRequest
 
     existing_key_row = MagicMock()
     data = UpdateKeyRequest(key="test_key", models=["gpt-4"], duration="120s")
@@ -935,9 +939,10 @@ def test_enforced_params_check(
 
 
 def test_get_key_models():
-    from litellm.proxy.auth.model_checks import get_key_models
     from collections import defaultdict
 
+    from litellm.proxy.auth.model_checks import get_key_models
+
     user_api_key_dict = UserAPIKeyAuth(
         api_key="test_api_key",
         user_id="test_user_id",
@@ -959,9 +964,10 @@ def test_get_key_models():
 
 
 def test_get_team_models():
-    from litellm.proxy.auth.model_checks import get_team_models
     from collections import defaultdict
 
+    from litellm.proxy.auth.model_checks import get_team_models
+
     user_api_key_dict = UserAPIKeyAuth(
         api_key="test_api_key",
         user_id="test_user_id",
@@ -1089,8 +1095,8 @@ def test_get_complete_model_list(proxy_model_list, model_list, provider):
     """
     Test that get_complete_model_list correctly expands model groups like 'openai/*' into individual models with provider prefixes
     """
-    from litellm.proxy.auth.model_checks import get_complete_model_list
    from litellm import Router
+    from litellm.proxy.auth.model_checks import get_complete_model_list
 
     llm_router = Router(model_list=model_list)
 
@@ -1202,9 +1208,10 @@ def test_proxy_config_state_get_config_state_error():
     """
     Ensures that get_config_state does not raise an error when the config is not a valid dictionary
     """
-    from litellm.proxy.proxy_server import ProxyConfig
     import threading
 
+    from litellm.proxy.proxy_server import ProxyConfig
+
     test_config = {
         "callback_list": [
             {
@@ -1339,8 +1346,8 @@ def test_is_allowed_to_make_key_request():
 
 
 def test_get_model_group_info():
-    from litellm.proxy.proxy_server import _get_model_group_info
     from litellm import Router
+    from litellm.proxy.proxy_server import _get_model_group_info
 
     router = Router(
         model_list=[
@@ -1368,10 +1375,11 @@ def test_get_model_group_info():
     assert len(model_list) == 1
 
 
-import pytest
 import asyncio
-from unittest.mock import AsyncMock, patch
 import json
+from unittest.mock import AsyncMock, patch
+
+import pytest
 
 
 @pytest.fixture
@@ -1444,10 +1452,12 @@ async def test_get_user_info_for_proxy_admin(mock_team_data, mock_key_data):
 
 
 def test_custom_openid_response():
-    from litellm.proxy.management_endpoints.ui_sso import generic_response_convertor
-    from litellm.proxy.management_endpoints.ui_sso import JWTHandler
-    from litellm.proxy._types import LiteLLM_JWTAuth
     from litellm.caching import DualCache
+    from litellm.proxy._types import LiteLLM_JWTAuth
+    from litellm.proxy.management_endpoints.ui_sso import (
+        JWTHandler,
+        generic_response_convertor,
+    )
 
     jwt_handler = JWTHandler()
     jwt_handler.update_environment(
@@ -1501,10 +1511,11 @@ def test_update_key_request_validation():
 
 
 def test_get_temp_budget_increase():
-    from litellm.proxy.auth.user_api_key_auth import _get_temp_budget_increase
-    from litellm.proxy._types import UserAPIKeyAuth
     from datetime import datetime, timedelta
 
+    from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.proxy.auth.user_api_key_auth import _get_temp_budget_increase
+
     expiry = datetime.now() + timedelta(days=1)
     expiry_in_isoformat = expiry.isoformat()
 
@@ -1520,11 +1531,12 @@ def test_get_temp_budget_increase():
 
 
 def test_update_key_budget_with_temp_budget_increase():
+    from datetime import datetime, timedelta
+
+    from litellm.proxy._types import UserAPIKeyAuth
     from litellm.proxy.auth.user_api_key_auth import (
         _update_key_budget_with_temp_budget_increase,
     )
-    from litellm.proxy._types import UserAPIKeyAuth
-    from datetime import datetime, timedelta
 
     expiry = datetime.now() + timedelta(days=1)
     expiry_in_isoformat = expiry.isoformat()
@@ -1540,7 +1552,7 @@ def test_update_key_budget_with_temp_budget_increase():
     assert _update_key_budget_with_temp_budget_increase(valid_token).max_budget == 200
 
 
-from unittest.mock import MagicMock, AsyncMock
+from unittest.mock import AsyncMock, MagicMock
 
 
 @pytest.mark.asyncio
@@ -1581,17 +1593,18 @@ async def test_health_check_not_called_when_disabled(monkeypatch):
     },
 )
 def test_custom_openapi(mock_get_openapi_schema):
-    from litellm.proxy.proxy_server import custom_openapi
-    from litellm.proxy.proxy_server import app
+    from litellm.proxy.proxy_server import app, custom_openapi
 
     openapi_schema = custom_openapi()
     assert openapi_schema is not None
 
 
-import pytest
-from unittest.mock import MagicMock, AsyncMock
 import asyncio
 from datetime import timedelta
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
 from litellm.proxy.utils import ProxyUpdateSpend
 
 
@@ -1644,6 +1657,7 @@ async def test_spend_logs_cleanup_after_error():
 
 
 def test_provider_specific_header():
+    """Test that provider_specific_header is set correctly for Anthropic headers."""
     from litellm.proxy.litellm_pre_call_utils import (
         add_provider_specific_headers_to_request,
     )
@@ -1700,14 +1714,88 @@ def test_provider_specific_header():
         data=data,
         headers=headers,
     )
 
+    # Verify multi-provider support: anthropic headers work across multiple providers
     assert data["provider_specific_header"] == {
-        "custom_llm_provider": "anthropic",
+        "custom_llm_provider": "anthropic,bedrock,bedrock_converse,vertex_ai",
         "extra_headers": {
             "anthropic-beta": "prompt-caching-2024-07-31",
         },
     }
 
 
+def test_provider_specific_header_multi_provider():
+    """Test that provider_specific_header supports multiple providers for Anthropic headers."""
+    from litellm.proxy.litellm_pre_call_utils import (
+        add_provider_specific_headers_to_request,
+    )
+
+    data = {
+        "model": "gemini-1.5-flash",
+        "messages": [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "Tell me a joke"}],
+            }
+        ],
+        "stream": True,
+        "proxy_server_request": {
+            "url": "http://0.0.0.0:4000/v1/chat/completions",
+            "method": "POST",
+            "headers": {
+                "content-type": "application/json",
+                "anthropic-beta": "context-1m-2025-08-07",
+                "anthropic-version": "2023-06-01",
+                "user-agent": "PostmanRuntime/7.32.3",
+                "accept": "*/*",
+                "postman-token": "81cccd87-c91d-4b2f-b252-c0fe0ca82529",
+                "host": "0.0.0.0:4000",
+                "accept-encoding": "gzip, deflate, br",
+                "connection": "keep-alive",
+                "content-length": "240",
+            },
+            "body": {
+                "model": "gemini-1.5-flash",
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": [{"type": "text", "text": "Tell me a joke"}],
+                    }
+                ],
+                "stream": True,
+            },
+        },
+    }
+
+    headers = {
+        "content-type": "application/json",
+        "anthropic-beta": "context-1m-2025-08-07",
+        "anthropic-version": "2023-06-01",
+        "user-agent": "PostmanRuntime/7.32.3",
+        "accept": "*/*",
+        "postman-token": "81cccd87-c91d-4b2f-b252-c0fe0ca82529",
+        "host": "0.0.0.0:4000",
+        "accept-encoding": "gzip, deflate, br",
+        "connection": "keep-alive",
+        "content-length": "240",
+    }
+
+    add_provider_specific_headers_to_request(
+        data=data,
+        headers=headers,
+    )
+
+    # Verify that provider_specific_header contains comma-separated providers
+    assert "provider_specific_header" in data
+    assert (
+        data["provider_specific_header"]["custom_llm_provider"]
+        == "anthropic,bedrock,bedrock_converse,vertex_ai"
+    )
+    assert data["provider_specific_header"]["extra_headers"] == {
+        "anthropic-beta": "context-1m-2025-08-07",
+        "anthropic-version": "2023-06-01",
+    }
+
+
 from litellm.proxy._types import LiteLLM_UserTable
 
 
@@ -1928,11 +2016,13 @@ async def test_post_call_failure_hook_auth_error_key_info_route():
     Test that post_call_failure_hook does NOT call _handle_logging_proxy_only_error
     when we get an auth error from /key/info route (since it's not an LLM API route).
     """
-    from litellm.proxy.utils import ProxyLogging
-    from litellm.proxy._types import ProxyErrorTypes
-    from litellm.caching.caching import DualCache
+    from unittest.mock import AsyncMock, Mock, patch
+
     from fastapi import HTTPException
-    from unittest.mock import Mock, patch, AsyncMock
+
+    from litellm.caching.caching import DualCache
+    from litellm.proxy._types import ProxyErrorTypes
+    from litellm.proxy.utils import ProxyLogging
 
     # Setup
     cache = DualCache()
@@ -1980,11 +2070,13 @@ async def test_post_call_failure_hook_auth_error_llm_api_route():
     Test that post_call_failure_hook DOES call _handle_logging_proxy_only_error
     when we get an auth error from /v1/chat/completions route (since it is an LLM API route).
""" - from litellm.proxy.utils import ProxyLogging - from litellm.proxy._types import ProxyErrorTypes - from litellm.caching.caching import DualCache + from unittest.mock import AsyncMock, Mock, patch + from fastapi import HTTPException - from unittest.mock import Mock, patch, AsyncMock + + from litellm.caching.caching import DualCache + from litellm.proxy._types import ProxyErrorTypes + from litellm.proxy.utils import ProxyLogging # Setup cache = DualCache() diff --git a/tests/test_litellm/litellm_core_utils/test_provider_specific_headers.py b/tests/test_litellm/litellm_core_utils/test_provider_specific_headers.py index aa1d31c6166f..293d6268ebac 100644 --- a/tests/test_litellm/litellm_core_utils/test_provider_specific_headers.py +++ b/tests/test_litellm/litellm_core_utils/test_provider_specific_headers.py @@ -11,14 +11,17 @@ def test_get_provider_specific_headers_matching_provider(self): """Test that the method returns extra_headers when custom_llm_provider matches.""" provider_specific_header: ProviderSpecificHeader = { "custom_llm_provider": "openai", - "extra_headers": {"Authorization": "Bearer token123", "Custom-Header": "value"} + "extra_headers": { + "Authorization": "Bearer token123", + "Custom-Header": "value", + }, } custom_llm_provider = "openai" - + result = ProviderSpecificHeaderUtils.get_provider_specific_headers( provider_specific_header, custom_llm_provider ) - + expected = {"Authorization": "Bearer token123", "Custom-Header": "value"} assert result == expected @@ -27,17 +30,85 @@ def test_get_provider_specific_headers_no_match_or_none(self): # Test case 1: Provider doesn't match provider_specific_header: ProviderSpecificHeader = { "custom_llm_provider": "anthropic", - "extra_headers": {"Authorization": "Bearer token123"} + "extra_headers": {"Authorization": "Bearer token123"}, } custom_llm_provider = "openai" - + result = ProviderSpecificHeaderUtils.get_provider_specific_headers( provider_specific_header, custom_llm_provider ) assert result == {} - + # Test case 2: provider_specific_header is None result = ProviderSpecificHeaderUtils.get_provider_specific_headers( None, "openai" ) assert result == {} + + def test_get_provider_specific_headers_multi_provider_anthropic_to_bedrock(self): + """Test that anthropic headers work with bedrock provider (multi-provider support).""" + provider_specific_header: ProviderSpecificHeader = { + "custom_llm_provider": "anthropic,bedrock,bedrock_converse,vertex_ai", + "extra_headers": {"anthropic-beta": "context-1m-2025-08-07"}, + } + + # Test bedrock provider + result = ProviderSpecificHeaderUtils.get_provider_specific_headers( + provider_specific_header, "bedrock" + ) + assert result == {"anthropic-beta": "context-1m-2025-08-07"} + + # Test anthropic provider + result = ProviderSpecificHeaderUtils.get_provider_specific_headers( + provider_specific_header, "anthropic" + ) + assert result == {"anthropic-beta": "context-1m-2025-08-07"} + + # Test bedrock_converse provider + result = ProviderSpecificHeaderUtils.get_provider_specific_headers( + provider_specific_header, "bedrock_converse" + ) + assert result == {"anthropic-beta": "context-1m-2025-08-07"} + + # Test vertex_ai provider + result = ProviderSpecificHeaderUtils.get_provider_specific_headers( + provider_specific_header, "vertex_ai" + ) + assert result == {"anthropic-beta": "context-1m-2025-08-07"} + + def test_get_provider_specific_headers_multi_provider_no_match(self): + """Test that non-listed providers return empty dict with multi-provider list.""" + provider_specific_header: 
ProviderSpecificHeader = { + "custom_llm_provider": "anthropic,bedrock,vertex_ai", + "extra_headers": {"anthropic-beta": "test"}, + } + + # Test provider not in list + result = ProviderSpecificHeaderUtils.get_provider_specific_headers( + provider_specific_header, "openai" + ) + assert result == {} + + def test_get_provider_specific_headers_with_spaces(self): + """Test that comma-separated list with spaces is handled correctly.""" + provider_specific_header: ProviderSpecificHeader = { + "custom_llm_provider": "anthropic, bedrock, vertex_ai", + "extra_headers": {"anthropic-beta": "test"}, + } + + result = ProviderSpecificHeaderUtils.get_provider_specific_headers( + provider_specific_header, "bedrock" + ) + assert result == {"anthropic-beta": "test"} + + def test_get_provider_specific_headers_none_custom_llm_provider(self): + """Test that None custom_llm_provider returns empty dict.""" + provider_specific_header: ProviderSpecificHeader = { + "custom_llm_provider": "anthropic", + "extra_headers": {"anthropic-beta": "test"}, + } + + result = ProviderSpecificHeaderUtils.get_provider_specific_headers( + provider_specific_header, None + ) + assert result == {}
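Reviewer note: the behavioral change above reduces to membership matching over a comma-separated provider string. Below is a minimal, self-contained sketch of that logic for anyone reviewing without the repo checked out — the function body mirrors the patched get_provider_specific_headers, while the `header` dict and the assertions are illustrative, not part of the diff:

```python
from typing import Dict, Optional


def get_provider_specific_headers(
    provider_specific_header: Optional[Dict],
    custom_llm_provider: Optional[str],
) -> Dict:
    """Return extra_headers when custom_llm_provider appears in the stored list."""
    if provider_specific_header is None or custom_llm_provider is None:
        return {}

    stored_providers = provider_specific_header.get("custom_llm_provider", "")
    # "anthropic, bedrock" and "anthropic,bedrock" both parse the same way
    provider_list = [p.strip() for p in stored_providers.split(",")]

    if custom_llm_provider in provider_list:
        return provider_specific_header.get("extra_headers", {})

    return {}


# Illustrative only: a header stored by the proxy pre-call hook above
header = {
    "custom_llm_provider": "anthropic,bedrock,bedrock_converse,vertex_ai",
    "extra_headers": {"anthropic-beta": "context-1m-2025-08-07"},
}

assert get_provider_specific_headers(header, "bedrock") == {
    "anthropic-beta": "context-1m-2025-08-07"
}
assert get_provider_specific_headers(header, "openai") == {}
assert get_provider_specific_headers(header, None) == {}
```

Storing the provider list as a plain comma-separated string keeps the existing `custom_llm_provider: str` field of `ProviderSpecificHeader` unchanged, at the cost of a small string parse on every lookup; a single stored provider name degrades to a one-element list, so the pre-change behavior is preserved.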