Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion star_openapi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def run_command(
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
ssl_ciphers=ssl_ciphers,
headers=[header.split(":", 1) for header in headers], # type: ignore[misc]
headers=[header.split(":", 1) for header in headers], # type: ignore
use_colors=use_colors,
factory=factory,
app_dir=app_dir,
Expand Down
21 changes: 18 additions & 3 deletions star_openapi/openapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from http import HTTPMethod
from importlib import import_module
from importlib.metadata import entry_points
Expand All @@ -10,6 +10,7 @@
from starlette.applications import Starlette
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import Mount, Route
from starlette.websockets import WebSocket

from .cli import cli
from .config import Config
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(
doc_ui: bool = True,
doc_prefix: str = "/openapi",
doc_url: str = "/openapi.json",
**kwargs,
**kwargs: Any,
):
"""
OpenAPI class that provides REST API functionality along with Swagger UI and Redoc, etc.
Expand Down Expand Up @@ -360,6 +361,20 @@ def _collect_openapi_info(
else:
return parse_parameters(func, doc_ui=False)

def _add_websocket_route(
self,
path: str,
endpoint: Callable[[WebSocket], Awaitable[None]],
name: str | None = None,
) -> None:
if not path.startswith("/"):
origin_path = path
path = "/" + path
else:
origin_path = path
route = APIWebSocketRoute(path=path, origin_path=origin_path, endpoint=endpoint, name=name)
self.routes.append(route)

def get(
self,
rule: str,
Expand Down Expand Up @@ -685,7 +700,7 @@ def websocket(
):
def decorator(func) -> Callable:
endpoint = create_websocket_endpoint(func)
self.add_websocket_route(rule, endpoint, name=name)
self._add_websocket_route(rule, endpoint, name=name)

return func

Expand Down
2 changes: 1 addition & 1 deletion star_openapi/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
operation_id_callback: Callable = get_operation_id_for_path,
responses: ResponseDict | None = None,
doc_ui: bool = True,
**kwargs,
**kwargs: Any,
):
"""
Based on Router
Expand Down