Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
524 changes: 524 additions & 0 deletions MIGRATION_MCP_V2.md

Large diffs are not rendered by default.

50 changes: 9 additions & 41 deletions cpex/framework/external/mcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
# Third-Party
import httpx
import orjson
from mcp import ClientSession, McpError, StdioServerParameters
from mcp import ClientSession, MCPError, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import TextContent
from mcp.client.streamable_http import streamable_http_client
from mcp_types import TextContent

# First-Party
from cpex.framework.base import HookRef, Plugin, PluginRef
Expand Down Expand Up @@ -82,9 +82,6 @@ def __init__(self, config: PluginConfig) -> None:
self._stdio_ready: Optional[asyncio.Event] = None
self._stdio_stop: Optional[asyncio.Event] = None
self._stdio_error: Optional[BaseException] = None
self._get_session_id: Optional[Callable[[], str | None]] = None
self._session_id: Optional[str] = None
self._http_client_factory: Optional[Callable[..., httpx.AsyncClient]] = None
self._reconnect_attempts: int = 3
self._reconnect_delay: float = 0.1
self._reconnect_lock: asyncio.Lock = asyncio.Lock()
Expand Down Expand Up @@ -373,23 +370,20 @@ def _tls_httpx_client_factory(

return httpx.AsyncClient(**kwargs)

self._http_client_factory = _tls_httpx_client_factory
max_retries = 3
base_delay = 1.0

for attempt in range(max_retries):
try:
client_factory = _tls_httpx_client_factory
streamable_client = streamablehttp_client(
uri, httpx_client_factory=client_factory, terminate_on_close=False
http_client_instance = _tls_httpx_client_factory()
streamable_client = streamable_http_client(
uri, http_client=http_client_instance, terminate_on_close=True
)
http_transport = await self._exit_stack.enter_async_context(streamable_client)
self._http, self._write, get_session_id = http_transport
self._get_session_id = get_session_id
self._http, self._write = http_transport
self._session = await self._exit_stack.enter_async_context(ClientSession(self._http, self._write))

await self._session.initialize()
self._session_id = self._get_session_id() if self._get_session_id else None
response = await self._session.list_tools()
tools = response.tools
logger.info(
Expand Down Expand Up @@ -446,8 +440,6 @@ async def _cleanup_session(self) -> None:
self._http = None
self._write = None
self._stdio = None
self._get_session_id = None
self._session_id = None

async def _reconnect_session(self) -> None:
"""Tear down old session and reconnect to MCP server with linear backoff.
Expand Down Expand Up @@ -570,8 +562,8 @@ async def _execute_call() -> PluginResult:
) from reconn_err
logger.exception(pe)
raise
except McpError as e:
logger.warning("McpError for plugin %s: %s", self.name, e)
except MCPError as e:
logger.warning("MCPError for plugin %s: %s", self.name, e)
try:
async with self._reconnect_lock:
await self._reconnect_session()
Expand Down Expand Up @@ -637,30 +629,6 @@ async def shutdown(self) -> None:

if self._exit_stack:
await self._exit_stack.aclose()
if self._config and self._config.mcp and self._config.mcp.proto == TransportType.STREAMABLEHTTP:
await self.__terminate_http_session()
self._get_session_id = None
self._session_id = None
self._http_client_factory = None

async def __terminate_http_session(self) -> None:
"""Terminate streamable HTTP session explicitly to avoid lingering server state."""
if not self._session_id or not self._config or not self._config.mcp or not self._config.mcp.url:
return
# Third-Party
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER # pylint: disable=import-outside-toplevel

client_factory = self._http_client_factory
try:
if client_factory:
client = client_factory()
else:
client = httpx.AsyncClient(follow_redirects=True)
async with client:
headers = {MCP_SESSION_ID_HEADER: self._session_id}
await client.delete(self._config.mcp.url, headers=headers)
except Exception as exc:
logger.debug("Failed to terminate streamable HTTP session: %s", exc)


class ExternalHookRef(HookRef):
Expand Down
84 changes: 43 additions & 41 deletions cpex/framework/external/mcp/server/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
SPDX-License-Identifier: Apache-2.0
Authors: Fred Araujo, Teryl Taylor

MCP Plugin Runtime using FastMCP with SSL/TLS support.
MCP Plugin Runtime using MCPServer with SSL/TLS support.

This runtime does the following:
- Uses FastMCP from the MCP Python SDK
- Uses MCPServer from the MCP Python SDK
- Supports both mTLS and non-mTLS configurations
- Reads configuration from PLUGINS_SERVER_* environment variables or uses configurations
the plugin config.yaml
Expand All @@ -19,7 +19,7 @@

>>> from cpex.framework.models import MCPServerConfig
>>> config = MCPServerConfig(host="localhost", port=8000)
>>> server = SSLCapableFastMCP(server_config=config, name="TestServer")
>>> server = SSLCapableMCPServer(server_config=config, name="TestServer")
>>> server.settings.host
'localhost'
>>> server.settings.port
Expand All @@ -29,7 +29,7 @@

>>> from cpex.framework.models import MCPServerConfig
>>> config = MCPServerConfig(host="127.0.0.1", port=8000, tls=None)
>>> server = SSLCapableFastMCP(server_config=config, name="NoTLSServer")
>>> server = SSLCapableMCPServer(server_config=config, name="NoTLSServer")
>>> ssl_config = server._get_ssl_config()
>>> ssl_config
{}
Expand All @@ -38,17 +38,17 @@

>>> from cpex.framework.models import MCPServerConfig
>>> config = MCPServerConfig(host="localhost", port=9000)
>>> server = SSLCapableFastMCP(server_config=config, name="ConfigTest")
>>> server = SSLCapableMCPServer(server_config=config, name="ConfigTest")
>>> server.server_config.host
'localhost'
>>> server.server_config.port
9000

Settings are properly passed to FastMCP:
Settings are properly passed to MCPServer:

>>> from cpex.framework.models import MCPServerConfig
>>> config = MCPServerConfig(host="0.0.0.0", port=8080)
>>> server = SSLCapableFastMCP(server_config=config, name="SettingsTest")
>>> server = SSLCapableMCPServer(server_config=config, name="SettingsTest")
>>> server.settings.host
'0.0.0.0'
>>> server.settings.port
Expand All @@ -66,7 +66,7 @@

# Third-Party
from fastapi import Response, status
from mcp.server.fastmcp import FastMCP
from mcp.server.mcpserver import MCPServer
from mcp.server.transport_security import TransportSecuritySettings
from prometheus_client import REGISTRY, Gauge, generate_latest

Expand Down Expand Up @@ -185,15 +185,15 @@ async def invoke_hook(hook_type: str, plugin_name: str, payload: Dict[str, Any],
return await SERVER.invoke_hook(hook_type, plugin_name, payload, context)


class SSLCapableFastMCP(FastMCP):
"""FastMCP server with SSL/TLS support using MCPServerConfig.
class SSLCapableMCPServer(MCPServer):
"""MCPServer with SSL/TLS support using MCPServerConfig.

Examples:
Create an SSL-capable FastMCP server:
Create an SSL-capable MCPServer:

>>> from cpex.framework.models import MCPServerConfig
>>> config = MCPServerConfig(host="127.0.0.1", port=8000)
>>> server = SSLCapableFastMCP(server_config=config, name="TestServer")
>>> server = SSLCapableMCPServer(server_config=config, name="TestServer")
>>> server.settings.host
'127.0.0.1'
>>> server.settings.port
Expand All @@ -205,13 +205,13 @@ def __init__(self, server_config: MCPServerConfig, *args, **kwargs):

Args:
server_config: the MCP server configuration including mTLS information.
*args: Additional positional arguments passed to FastMCP.
**kwargs: Additional keyword arguments passed to FastMCP.
*args: Additional positional arguments passed to MCPServer.
**kwargs: Additional keyword arguments passed to MCPServer.

Examples:
>>> from cpex.framework.models import MCPServerConfig
>>> config = MCPServerConfig(host="0.0.0.0", port=9000)
>>> server = SSLCapableFastMCP(server_config=config, name="PluginServer")
>>> server = SSLCapableMCPServer(server_config=config, name="PluginServer")
>>> server.server_config.host
'0.0.0.0'
>>> server.server_config.port
Expand All @@ -220,13 +220,14 @@ def __init__(self, server_config: MCPServerConfig, *args, **kwargs):
# Load server config from environment

self.server_config = server_config
# Override FastMCP settings with our server config
if "host" not in kwargs:
kwargs["host"] = self.server_config.host
if "port" not in kwargs:
kwargs["port"] = self.server_config.port
if self.server_config.uds and kwargs.get("transport_security") is None:
kwargs["transport_security"] = TransportSecuritySettings(
# MCPServer v2 does not accept host/port/transport_security in __init__;
# transport_security is passed to streamable_http_app(), host/port to run methods.
kwargs.pop("host", None)
kwargs.pop("port", None)

transport_security = kwargs.pop("transport_security", None)
if self.server_config.uds and transport_security is None:
transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=True,
allowed_hosts=[
"127.0.0.1",
Expand All @@ -245,6 +246,7 @@ def __init__(self, server_config: MCPServerConfig, *args, **kwargs):
"http://[::1]:*",
],
)
self._transport_security = transport_security

super().__init__(*args, **kwargs)

Expand All @@ -257,7 +259,7 @@ def _get_ssl_config(self) -> dict:
Examples:
>>> from cpex.framework.models import MCPServerConfig
>>> config = MCPServerConfig(host="127.0.0.1", port=8000, tls=None)
>>> server = SSLCapableFastMCP(server_config=config, name="TestServer")
>>> server = SSLCapableMCPServer(server_config=config, name="TestServer")
>>> ssl_config = server._get_ssl_config()
>>> ssl_config
{}
Expand Down Expand Up @@ -361,10 +363,10 @@ async def metrics_disabled():
# Create a minimal Starlette app with only the health endpoint
health_app = Starlette(routes=routes)

logger.info(f"Starting HTTP health check server on {self.settings.host}:{health_port}")
logger.info(f"Starting HTTP health check server on { self.server_config.host}:{health_port}")
config = uvicorn.Config(
app=health_app,
host=self.settings.host,
host= self.server_config.host,
port=health_port,
log_level="warning", # Reduce noise from health checks
)
Expand All @@ -379,13 +381,13 @@ async def run_streamable_http_async(self) -> None:

>>> from cpex.framework.models import MCPServerConfig
>>> config = MCPServerConfig(host="0.0.0.0", port=9000)
>>> server = SSLCapableFastMCP(server_config=config, name="HTTPServer")
>>> server = SSLCapableMCPServer(server_config=config, name="HTTPServer")
>>> server.settings.host
'0.0.0.0'
>>> server.settings.port
9000
"""
starlette_app = self.streamable_http_app()
starlette_app = self.streamable_http_app(transport_security=getattr(self, '_transport_security', None))

# Add health check endpoint to main app
# Third-Party
Expand Down Expand Up @@ -438,8 +440,8 @@ async def metrics_disabled():
ssl_config = self._get_ssl_config()
config_kwargs = {
"app": starlette_app,
"host": self.settings.host,
"port": self.settings.port,
"host": self.server_config.host,
"port": self.server_config.port,
"log_level": self.settings.log_level.lower(),
}
config_kwargs.update(ssl_config)
Expand All @@ -450,13 +452,13 @@ async def metrics_disabled():
config_kwargs["uds"] = self.server_config.uds
logger.info(f"Starting plugin server on unix socket {self.server_config.uds}")
else:
logger.info(f"Starting plugin server on {self.settings.host}:{self.settings.port}")
logger.info(f"Starting plugin server on { self.server_config.host}:{ self.server_config.port}")
config = uvicorn.Config(**config_kwargs) # type: ignore[arg-type]
server = uvicorn.Server(config)

# If SSL is enabled, start a separate HTTP health check server
if ssl_config and not self.server_config.uds:
health_port = self.settings.port + 1000 # Use port+1000 for health checks
health_port = self.server_config.port + 1000 # Use port+1000 for health checks
logger.info(f"SSL enabled - starting separate HTTP health check on port {health_port}")
# Run both servers concurrently
await asyncio.gather(server.serve(), self._start_health_check_server(health_port))
Expand All @@ -466,7 +468,7 @@ async def metrics_disabled():


async def run() -> None:
"""Run the external plugin server with FastMCP.
"""Run the external plugin server with MCPServer.

Supports both stdio and HTTP transports. Auto-detects transport based on stdin
(if stdin is not a TTY, uses stdio mode), or you can explicitly set PLUGINS_TRANSPORT.
Expand All @@ -491,7 +493,7 @@ async def run() -> None:
>>> SERVER is None
True

FastMCP server names are defined as constants:
MCPServer names are defined as constants:

>>> from cpex.framework.constants import MCP_SERVER_NAME
>>> isinstance(MCP_SERVER_NAME, str)
Expand Down Expand Up @@ -524,33 +526,33 @@ async def run() -> None:

try:
if transport == "stdio":
# Create basic FastMCP server for stdio (no SSL support needed for stdio)
mcp = FastMCP(
# Create basic MCPServer for stdio (no SSL support needed for stdio)
mcp = MCPServer(
name=MCP_SERVER_NAME,
instructions=MCP_SERVER_INSTRUCTIONS,
)

# Register module-level tool functions with FastMCP
# Register module-level tool functions with MCPServer
mcp.tool(name=GET_PLUGIN_CONFIGS)(get_plugin_configs)
mcp.tool(name=GET_PLUGIN_CONFIG)(get_plugin_config)
mcp.tool(name=INVOKE_HOOK)(invoke_hook)
# set the plugin_info gauge on startup
PLUGIN_INFO.labels(server_name=MCP_SERVER_NAME, transport="stdio", ssl_enabled="false").set(1)

# Run with stdio transport
logger.info("Starting MCP plugin server with FastMCP (stdio transport)")
logger.info("Starting MCP plugin server with MCPServer (stdio transport)")
await mcp.run_stdio_async()

else: # http or streamablehttp
server_config: MCPServerConfig = SERVER.get_server_config()
# Create FastMCP server with SSL support
mcp = SSLCapableFastMCP(
# Create MCPServer with SSL support
mcp = SSLCapableMCPServer(
server_config,
name=MCP_SERVER_NAME,
instructions=MCP_SERVER_INSTRUCTIONS,
)

# Register module-level tool functions with FastMCP
# Register module-level tool functions with MCPServer
mcp.tool(name=GET_PLUGIN_CONFIGS)(get_plugin_configs)
mcp.tool(name=GET_PLUGIN_CONFIG)(get_plugin_config)
mcp.tool(name=INVOKE_HOOK)(invoke_hook)
Expand All @@ -564,7 +566,7 @@ async def run() -> None:
f"Prometheus metrics available at http://{server_config.host}:{server_config.port}/metrics/prometheus"
)
# Run with streamable-http transport
logger.info("Starting MCP plugin server with FastMCP (HTTP transport)")
logger.info("Starting MCP plugin server with MCPServer (HTTP transport)")
await mcp.run_streamable_http_async()

except Exception:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dependencies = [
"httpx>=0.28.1",
"httpx[http2]>=0.28.1",
"jinja2>=3.1.6",
"mcp>=1.26.0",
"mcp==2.0.0b1",
"mcp-types==2.0.0b1",
"orjson>=3.11.7",
"prometheus-fastapi-instrumentator>=7.1.0",
"prometheus_client>=0.24.1",
Expand Down
Loading