From 2996b5b225c05a2f0d9c00e2727d5f8f0105e9c9 Mon Sep 17 00:00:00 2001 From: luolingchun Date: Thu, 21 May 2026 15:20:10 +0800 Subject: [PATCH] Fix with starlette > 1 --- star_openapi/cli.py | 2 +- star_openapi/openapi.py | 21 ++++++++++++++++++--- star_openapi/router.py | 2 +- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/star_openapi/cli.py b/star_openapi/cli.py index dae4f63..5e7af67 100644 --- a/star_openapi/cli.py +++ b/star_openapi/cli.py @@ -431,7 +431,7 @@ def run_command( ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs, ssl_ciphers=ssl_ciphers, - headers=[header.split(":", 1) for header in headers], # type: ignore[misc] + headers=[header.split(":", 1) for header in headers], # type: ignore use_colors=use_colors, factory=factory, app_dir=app_dir, diff --git a/star_openapi/openapi.py b/star_openapi/openapi.py index 5af1142..5542510 100644 --- a/star_openapi/openapi.py +++ b/star_openapi/openapi.py @@ -1,4 +1,4 @@ -from collections.abc import Callable +from collections.abc import Awaitable, Callable from http import HTTPMethod from importlib import import_module from importlib.metadata import entry_points @@ -10,6 +10,7 @@ from starlette.applications import Starlette from starlette.responses import HTMLResponse, JSONResponse from starlette.routing import Mount, Route +from starlette.websockets import WebSocket from .cli import cli from .config import Config @@ -61,7 +62,7 @@ def __init__( doc_ui: bool = True, doc_prefix: str = "/openapi", doc_url: str = "/openapi.json", - **kwargs, + **kwargs: Any, ): """ OpenAPI class that provides REST API functionality along with Swagger UI and Redoc, etc. @@ -360,6 +361,20 @@ def _collect_openapi_info( else: return parse_parameters(func, doc_ui=False) + def _add_websocket_route( + self, + path: str, + endpoint: Callable[[WebSocket], Awaitable[None]], + name: str | None = None, + ) -> None: + if not path.startswith("/"): + origin_path = path + path = "/" + path + else: + origin_path = path + route = APIWebSocketRoute(path=path, origin_path=origin_path, endpoint=endpoint, name=name) + self.routes.append(route) + def get( self, rule: str, @@ -685,7 +700,7 @@ def websocket( ): def decorator(func) -> Callable: endpoint = create_websocket_endpoint(func) - self.add_websocket_route(rule, endpoint, name=name) + self._add_websocket_route(rule, endpoint, name=name) return func diff --git a/star_openapi/router.py b/star_openapi/router.py index 056e3db..d8f2570 100644 --- a/star_openapi/router.py +++ b/star_openapi/router.py @@ -76,7 +76,7 @@ def __init__( operation_id_callback: Callable = get_operation_id_for_path, responses: ResponseDict | None = None, doc_ui: bool = True, - **kwargs, + **kwargs: Any, ): """ Based on Router