Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 270 additions & 18 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import datetime
import hashlib
import json
import builtins
from typing import Any, Dict, List, Optional, Set, Union, cast

from fastapi import HTTPException
Expand Down Expand Up @@ -42,6 +43,8 @@
from litellm.types.mcp import MCPAuth, MCPStdioConfig
from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer

# Resolve the builtin BaseExceptionGroup type dynamically. The builtin only
# exists on Python 3.11+; on older interpreters the lookup yields None, which
# callers must treat as "exception groups are not available".
BASE_EXCEPTION_GROUP_TYPE = getattr(builtins, "BaseExceptionGroup", None)


def _deserialize_env_dict(env_data: Any) -> Optional[Dict[str, str]]:
"""
Expand Down Expand Up @@ -95,6 +98,20 @@ def __init__(self):
}
"""

def _flatten_exception_group(
    self, exception: BaseException
) -> List[BaseException]:
    """
    Return the leaf exceptions contained in ``exception``.

    A plain exception is returned as a single-element list. An exception
    group is expanded depth-first, so nested groups are replaced by their
    members in their original order.
    """
    # Guard: either exception groups are unavailable on this interpreter
    # (the type constant is None) or this is an ordinary exception.
    if not BASE_EXCEPTION_GROUP_TYPE or not isinstance(
        exception, BASE_EXCEPTION_GROUP_TYPE
    ):
        return [exception]

    # cast() only silences the type checker; ``exceptions`` is the group's
    # tuple of direct members.
    members = cast(Any, exception).exceptions
    return [
        leaf
        for member in members
        for leaf in self._flatten_exception_group(member)
    ]

def get_registry(self) -> Dict[str, MCPServer]:
"""
Get the registered MCP Servers from the registry and union with the config MCP Servers
Expand Down Expand Up @@ -468,8 +485,10 @@ async def get_allowed_mcp_servers(
if len(allowed_mcp_servers) > 0:
return allowed_mcp_servers
else:
user_info = f"User ID: {user_api_key_auth.user_id if user_api_key_auth else 'None'}, Permission ID: {user_api_key_auth.object_permission_id if user_api_key_auth else 'None'}"
verbose_logger.debug(
"No allowed MCP Servers found for user api key auth, returning default registry servers"
f"No allowed MCP Servers found for user api key auth ({user_info}), returning default registry servers. "
f"If you just created an MCP server, ensure your user has been granted access to it."
)
return list(self.get_registry().keys())
except Exception as e:
Expand Down Expand Up @@ -512,7 +531,7 @@ async def list_tools(
Returns:
List[MCPTool]: Combined list of tools from all servers
"""
allowed_mcp_servers = await self.get_allowed_mcp_servers(user_api_key_auth)
allowed_mcp_servers = await self.get_allowed_mcp_servers(user_api_key_auth)

list_tools_result: List[MCPTool] = []
verbose_logger.debug("SERVER MANAGER LISTING TOOLS")
Expand Down Expand Up @@ -1094,15 +1113,39 @@ async def _call_regular_mcp_tool(
"""
# Get server-specific auth header if available
server_auth_header: Optional[Union[Dict[str, str], str]] = None
if mcp_server_auth_headers and mcp_server.alias:
server_auth_header = mcp_server_auth_headers.get(mcp_server.alias)
elif mcp_server_auth_headers and mcp_server.server_name:
server_auth_header = mcp_server_auth_headers.get(mcp_server.server_name)
normalized_auth_headers: Optional[Dict[str, Union[Dict[str, str], str]]] = None
if mcp_server_auth_headers:
normalized_auth_headers = {
key.lower(): value for key, value in mcp_server_auth_headers.items()
}

if normalized_auth_headers and mcp_server.alias:
server_auth_header = normalized_auth_headers.get(
normalize_server_name(mcp_server.alias).lower()
)
elif normalized_auth_headers and mcp_server.server_name:
server_auth_header = normalized_auth_headers.get(
normalize_server_name(mcp_server.server_name).lower()
)

# Fall back to deprecated mcp_auth_header if no server-specific header found
if server_auth_header is None:
server_auth_header = mcp_auth_header

# Track which auth headers are already supplied so we don't overwrite them
primary_auth_header_names: Set[str] = set()
if isinstance(server_auth_header, dict):
primary_auth_header_names = {key.lower() for key in server_auth_header}
elif isinstance(server_auth_header, str):
if mcp_server.auth_type in (
MCPAuth.bearer_token,
MCPAuth.basic,
MCPAuth.authorization,
):
primary_auth_header_names = {"authorization"}
elif mcp_server.auth_type == MCPAuth.api_key:
primary_auth_header_names = {"x-api-key"}

# oauth2 headers
extra_headers: Optional[Dict[str, str]] = None
if mcp_server.auth_type == MCPAuth.oauth2:
Expand All @@ -1111,9 +1154,25 @@ async def _call_regular_mcp_tool(
if mcp_server.extra_headers and raw_headers:
if extra_headers is None:
extra_headers = {}
# Create a lowercase lookup for case-insensitive matching
lowercase_raw_headers = {
key.lower(): value for key, value in raw_headers.items()
}
for header in mcp_server.extra_headers:
if header in raw_headers:
extra_headers[header] = raw_headers[header]
header_lower = header.lower()
if header_lower in primary_auth_header_names:
verbose_logger.debug(
"Skipping forwarding header '%s' to avoid overriding MCP auth",
header,
)
continue

header_value = raw_headers.get(header)
if header_value is None:
header_value = lowercase_raw_headers.get(header_lower)

if header_value is not None:
extra_headers[header] = header_value

client = self._create_mcp_client(
server=mcp_server,
Expand Down Expand Up @@ -1147,6 +1206,38 @@ async def _call_tool_via_client(client, params):
f"Guardrail blocked MCP tool call during result check: {str(e)}"
)
raise e
except BaseException as exc:
if isinstance(exc, asyncio.CancelledError):
raise
if isinstance(exc, (SystemExit, KeyboardInterrupt)):
raise

flattened_exceptions = self._flatten_exception_group(exc)
for inner_exc in flattened_exceptions:
if isinstance(inner_exc, BlockedPiiEntityError):
raise inner_exc
if isinstance(inner_exc, GuardrailRaisedException):
raise inner_exc
if isinstance(inner_exc, HTTPException):
raise inner_exc

root_exception = flattened_exceptions[0] if flattened_exceptions else exc
verbose_logger.error(
"Unhandled errors during MCP tool call for %s: %s",
original_tool_name,
str(root_exception),
exc_info=True,
)

raise HTTPException(
status_code=500,
detail=(
"Failed to call MCP tool "
f"{original_tool_name}: {str(root_exception)}"
),
) from (
root_exception if isinstance(root_exception, Exception) else None
)

# If proxy_logging_obj is None, the tool call result is at index 0
# If proxy_logging_obj is not None, the tool call result is at index 1 (after the during hook task)
Expand Down Expand Up @@ -1190,6 +1281,26 @@ async def call_tool(

# Get the MCP server
mcp_server = self._get_mcp_server_from_tool_name(name)

if mcp_server is None and mcp_server_auth_headers:
mcp_server = self._resolve_server_from_auth_headers(
mcp_server_auth_headers,
preferred_name=server_name_from_prefix,
)

if mcp_server is None:
await self._warm_tool_mapping(
user_api_key_auth=user_api_key_auth,
mcp_auth_header=mcp_auth_header,
mcp_server_auth_headers=mcp_server_auth_headers,
oauth2_headers=oauth2_headers,
raw_headers=raw_headers,
)
# Retry lookup after warming mapping
mcp_server = self._get_mcp_server_from_tool_name(name)
if mcp_server is None:
mcp_server = self._get_mcp_server_from_tool_name(original_tool_name)

if mcp_server is None:
raise ValueError(f"Tool {name} not found")

Expand Down Expand Up @@ -1318,26 +1429,166 @@ def _get_mcp_server_from_tool_name(self, tool_name: str) -> Optional[MCPServer]:
Returns:
MCPServer if found, None otherwise
"""
normalized_tool_name = normalize_server_name(tool_name).lower()

# First try with the original tool name
if tool_name in self.tool_name_to_mcp_server_name_mapping:
server_name = self.tool_name_to_mcp_server_name_mapping[tool_name]
for server in self.get_registry().values():
if normalize_server_name(server.name) == normalize_server_name(
server_name
):
return server
for key, server_name in self.tool_name_to_mcp_server_name_mapping.items():
if normalize_server_name(key).lower() == normalized_tool_name:
target_name = normalize_server_name(server_name).lower()
for server in self.get_registry().values():
if (
normalize_server_name(server.name).lower() == target_name
or (
server.alias
and normalize_server_name(server.alias).lower()
== target_name
)
):
return server

# If not found and tool name is prefixed, try extracting server name from prefix
if is_tool_name_prefixed(tool_name):
_, server_name_from_prefix = get_server_name_prefix_tool_mcp(tool_name)
for server in self.get_registry().values():
if normalize_server_name(server.name) == normalize_server_name(
if normalize_server_name(server.name).lower() == normalize_server_name(
server_name_from_prefix
):
).lower():
return server

return None

def get_mcp_server_by_alias(self, alias: str) -> Optional[MCPServer]:
    """
    Look up a registered MCP server by its alias.

    Both the requested alias and each server's alias are passed through
    ``normalize_server_name`` and lower-cased before comparison, so the
    match is case-insensitive. Servers without an alias are ignored.

    Returns:
        The first matching MCPServer in registry order, or None.
    """
    wanted = normalize_server_name(alias).lower()
    return next(
        (
            candidate
            for candidate in self.get_registry().values()
            if candidate.alias
            and normalize_server_name(candidate.alias).lower() == wanted
        ),
        None,
    )

def _resolve_server_from_auth_headers(
    self,
    mcp_server_auth_headers: Dict[str, Dict[str, str]],
    preferred_name: Optional[str] = None,
) -> Optional[MCPServer]:
    """
    Try to identify an MCP server from the keys of the per-server auth headers.

    Candidate names are tried in this order: ``preferred_name`` (normalized,
    when given), followed by every key of ``mcp_server_auth_headers``
    (normalized and lower-cased). Each candidate is looked up first as an
    alias and then as a server name.

    Returns:
        The first server that matches a candidate, or None if none match.
    """
    header_names = [
        normalize_server_name(header_key).lower()
        for header_key in mcp_server_auth_headers
    ]

    ordered_names: List[str] = []
    if preferred_name:
        ordered_names.append(normalize_server_name(preferred_name))
    ordered_names += header_names

    # dict.fromkeys drops duplicates while keeping first-seen order.
    for name in dict.fromkeys(ordered_names):
        if not name:
            continue
        match = self.get_mcp_server_by_alias(name) or self.get_mcp_server_by_name(
            name
        )
        if match:
            return match

    return None

async def _warm_tool_mapping(
    self,
    user_api_key_auth: Optional[UserAPIKeyAuth],
    mcp_auth_header: Optional[str],
    mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]],
    oauth2_headers: Optional[Dict[str, str]],
    raw_headers: Optional[Dict[str, str]],
) -> None:
    """
    Fetch tools for allowed servers to refresh the tool-to-server mapping.

    For every server the caller is allowed to use, this resolves the auth
    header and forwarded headers and then calls ``_get_tools_from_server``.
    The call is best-effort: a failure for one server is logged at debug
    level and does not stop the remaining servers from being tried.

    Args:
        user_api_key_auth: Auth context used to compute the allowed servers.
        mcp_auth_header: Deprecated global auth header, used as a fallback
            when no server-specific header is found.
        mcp_server_auth_headers: Per-server auth headers keyed by server
            alias or name (matched case-insensitively after normalization).
        oauth2_headers: Headers passed as extra headers to OAuth2 servers.
            This dict is never mutated.
        raw_headers: Incoming request headers from which each server's
            configured ``extra_headers`` are forwarded.
    """
    try:
        allowed_server_ids = await self.get_allowed_mcp_servers(user_api_key_auth)
    except Exception as e:
        verbose_logger.warning(
            "Failed to fetch allowed MCP servers while warming tool mapping: %s",
            str(e),
        )
        return

    # Both lookups depend only on the arguments, so build them once instead
    # of once per server.
    normalized_auth_headers: Dict[str, Union[Dict[str, str], str]] = {
        normalize_server_name(key).lower(): value
        for key, value in (mcp_server_auth_headers or {}).items()
    }
    lowercase_raw_headers: Dict[str, str] = {
        key.lower(): value for key, value in (raw_headers or {}).items()
    }

    for server_id in allowed_server_ids:
        server = self.get_mcp_server_by_id(server_id)
        if server is None:
            continue

        # Server-specific auth header: alias takes precedence over name.
        server_auth_header: Optional[Union[Dict[str, str], str]] = None
        if normalized_auth_headers:
            if server.alias:
                server_auth_header = normalized_auth_headers.get(
                    normalize_server_name(server.alias).lower()
                )
            if server_auth_header is None and server.server_name:
                server_auth_header = normalized_auth_headers.get(
                    normalize_server_name(server.server_name).lower()
                )

        # Fall back to the deprecated global header.
        if server_auth_header is None:
            server_auth_header = mcp_auth_header

        # Header names already supplied by the auth header; forwarded
        # request headers must not overwrite them.
        primary_auth_header_names: Set[str] = set()
        if isinstance(server_auth_header, dict):
            primary_auth_header_names = {
                key.lower() for key in server_auth_header
            }
        elif isinstance(server_auth_header, str):
            if server.auth_type in (
                MCPAuth.bearer_token,
                MCPAuth.basic,
                MCPAuth.authorization,
            ):
                primary_auth_header_names = {"authorization"}
            elif server.auth_type == MCPAuth.api_key:
                primary_auth_header_names = {"x-api-key"}

        # Copy the OAuth2 headers: forwarded headers are added below, and
        # writing into the caller's dict would leak one server's forwarded
        # headers into every later OAuth2 server (and back to the caller).
        extra_headers: Optional[Dict[str, str]] = None
        if server.auth_type == MCPAuth.oauth2 and oauth2_headers is not None:
            extra_headers = dict(oauth2_headers)

        if server.extra_headers and raw_headers:
            if extra_headers is None:
                extra_headers = {}
            for header in server.extra_headers:
                header_lower = header.lower()
                if header_lower in primary_auth_header_names:
                    verbose_logger.debug(
                        "Skipping forwarding header '%s' while warming mapping to avoid overriding MCP auth",
                        header,
                    )
                    continue
                # Prefer an exact-case match, then fall back to a
                # case-insensitive one, mirroring the real tool-call path.
                if header in raw_headers:
                    extra_headers[header] = raw_headers[header]
                elif header_lower in lowercase_raw_headers:
                    extra_headers[header] = lowercase_raw_headers[header_lower]

        try:
            await self._get_tools_from_server(
                server=server,
                mcp_auth_header=server_auth_header,
                extra_headers=extra_headers,
            )
        except Exception as e:
            # Best-effort warm-up: one unreachable server must not block
            # the others.
            verbose_logger.debug(
                "Failed to warm tool mapping for server %s: %s",
                server.name,
                str(e),
            )

async def _add_mcp_servers_from_db_to_in_memory_registry(self):
from litellm.proxy._experimental.mcp_server.db import get_all_mcp_servers
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
Expand Down Expand Up @@ -1386,9 +1637,10 @@ def get_mcp_server_by_name(self, server_name: str) -> Optional[MCPServer]:
"""
Get the MCP Server from the server name
"""
target_name = normalize_server_name(server_name).lower()
registry = self.get_registry()
for server in registry.values():
if server.server_name == server_name:
if server.server_name and normalize_server_name(server.server_name).lower() == target_name:
return server
return None

Expand Down
Loading
Loading