diff --git a/pyproject.toml b/pyproject.toml index 08859dbc..6b2079b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,10 +24,13 @@ dev = [ "pytest-codspeed>=4.1.1", "pytest-httpbin==2.0.0", "pytest-trio==0.8.0", + "starlette>=0.49", "trio==0.31.0", "trio-typing==0.10.0", "trustme==1.2.1", "uvicorn>=0.35", + "websockets>=15", + "wsproto>=1.2", "werkzeug>=3.1.6", # Linting "mypy==1.17.1", diff --git a/src/httpx2/httpx2/__init__.py b/src/httpx2/httpx2/__init__.py index 068e0a25..d7a8d8d0 100644 --- a/src/httpx2/httpx2/__init__.py +++ b/src/httpx2/httpx2/__init__.py @@ -10,16 +10,19 @@ from ._transports import * from ._types import * from ._urls import * +from ._websockets import * __all__ = [ "__description__", "__title__", "__version__", "ASGITransport", + "ASGIWebSocketTransport", "AsyncBaseTransport", "AsyncByteStream", "AsyncClient", "AsyncHTTPTransport", + "AsyncWebSocketSession", "Auth", "BaseTransport", "BasicAuth", @@ -78,6 +81,13 @@ "UnsupportedProtocol", "URL", "USE_CLIENT_DEFAULT", + "websocket", + "WebSocketDisconnect", + "WebSocketException", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketSession", + "WebSocketUpgradeError", "WriteError", "WriteTimeout", "WSGITransport", diff --git a/src/httpx2/httpx2/_api.py b/src/httpx2/httpx2/_api.py index 25171cbc..05316c95 100644 --- a/src/httpx2/httpx2/_api.py +++ b/src/httpx2/httpx2/_api.py @@ -19,6 +19,13 @@ TimeoutTypes, ) from ._urls import URL +from ._websockets._session import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, + WebSocketSession, +) if typing.TYPE_CHECKING: import ssl # pragma: no cover @@ -34,6 +41,7 @@ "put", "request", "stream", + "websocket", ] @@ -159,6 +167,61 @@ def stream( yield response +@contextmanager +def websocket( + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | None = None, + proxy: ProxyTypes | None = None, + timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + follow_redirects: bool = False, + verify: ssl.SSLContext | str | bool = True, + trust_env: bool = True, + subprotocols: list[str] | None = None, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, +) -> Generator[WebSocketSession]: + """ + Open a WebSocket session. + + The session is closed automatically when exiting the context manager. + + ``` + >>> import httpx2 + >>> with httpx2.websocket("wss://echo.websocket.org") as ws: + ... ws.send_text("Hello!") + ... message = ws.receive_text() + ``` + + **Parameters**: See `httpx2.request` and `httpx2.Client.websocket`. + """ + with Client( + cookies=cookies, + proxy=proxy, + verify=verify, + timeout=timeout, + trust_env=trust_env, + ) as client: + with client.websocket( + url, + params=params, + headers=headers, + auth=auth, + follow_redirects=follow_redirects, + subprotocols=subprotocols, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + ) as session: + yield session + + def get( url: URL | str, *, diff --git a/src/httpx2/httpx2/_client.py b/src/httpx2/httpx2/_client.py index 18720ee6..f935d8d3 100644 --- a/src/httpx2/httpx2/_client.py +++ b/src/httpx2/httpx2/_client.py @@ -48,6 +48,16 @@ ) from ._urls import URL, QueryParams from ._utils import URLPattern, get_environment_proxies +from ._websockets._session import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, + AsyncWebSocketSession, + WebSocketSession, + aconnect_ws, + connect_ws, +) if typing.TYPE_CHECKING: import ssl # pragma: no cover @@ -845,6 +855,71 @@ def stream( finally: response.close() + @contextmanager + def websocket( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + subprotocols: list[str] | None = None, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + ) -> Generator[WebSocketSession]: + """ + Open a WebSocket session. + + The session is closed automatically when exiting the context manager. + + ```python + with httpx2.Client() as client: + with client.websocket("wss://example.com/ws") as ws: + ws.send_text("Hello!") + message = ws.receive_text() + ``` + + **Parameters**: See `httpx2.request` for the request parameters, plus: + + * **subprotocols** - *(optional)* A list of subprotocols to negotiate with the server. + * **max_message_size_bytes** - Message size in bytes to receive from the server. Defaults to 64 KiB. + * **queue_size** - Size of the queue where the received messages will be held + until they are consumed. If the queue is full, the client will stop receiving + messages from the server until the queue has room available. Defaults to 512. + * **keepalive_ping_interval_seconds** - Interval at which the client will automatically + send a Ping event to keep the connection alive. Set it to `None` to disable + this mechanism. Defaults to 20 seconds. + * **keepalive_ping_timeout_seconds** - Maximum delay the client will wait for an answer + to its Ping event. If the delay is exceeded, `httpx2.WebSocketNetworkError` will be + raised and the connection closed. Defaults to 20 seconds. + + Raises `httpx2.WebSocketUpgradeError` if the connection didn't correctly + upgrade to a WebSocket session. + """ + with connect_ws( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + subprotocols=subprotocols, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + ) as session: + yield session + def send( self, request: Request, @@ -1548,6 +1623,76 @@ async def stream( finally: await response.aclose() + @asynccontextmanager + async def websocket( + self, + url: URL | str, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: RequestExtensions | None = None, + subprotocols: list[str] | None = None, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + ) -> AsyncGenerator[AsyncWebSocketSession]: + """ + Open a WebSocket session. + + The session is closed automatically when exiting the context manager. + + ```python + async with httpx2.AsyncClient() as client: + async with client.websocket("wss://example.com/ws") as ws: + await ws.send_text("Hello!") + message = await ws.receive_text() + ``` + + Internally, the session uses an anyio task group to manage background tasks. + As a result, exceptions that are not caught inside the context manager and + propagate out of the `async with` block will be wrapped in an `ExceptionGroup`. + Use the `except*` syntax to handle them. + + **Parameters**: See `httpx2.request` for the request parameters, plus: + + * **subprotocols** - *(optional)* A list of subprotocols to negotiate with the server. + * **max_message_size_bytes** - Message size in bytes to receive from the server. Defaults to 64 KiB. + * **queue_size** - Size of the queue where the received messages will be held + until they are consumed. If the queue is full, the client will stop receiving + messages from the server until the queue has room available. Defaults to 512. + * **keepalive_ping_interval_seconds** - Interval at which the client will automatically + send a Ping event to keep the connection alive. Set it to `None` to disable + this mechanism. Defaults to 20 seconds. + * **keepalive_ping_timeout_seconds** - Maximum delay the client will wait for an answer + to its Ping event. If the delay is exceeded, `httpx2.WebSocketNetworkError` will be + raised and the connection closed. Defaults to 20 seconds. + + Raises `httpx2.WebSocketUpgradeError` if the connection didn't correctly + upgrade to a WebSocket session. + """ + async with aconnect_ws( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + subprotocols=subprotocols, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + ) as session: + yield session + async def send( self, request: Request, diff --git a/src/httpx2/httpx2/_websockets/__init__.py b/src/httpx2/httpx2/_websockets/__init__.py new file mode 100644 index 00000000..1a227924 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/__init__.py @@ -0,0 +1,20 @@ +from ._exceptions import ( + WebSocketDisconnect, + WebSocketException, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, +) +from ._session import AsyncWebSocketSession, WebSocketSession +from ._transport import ASGIWebSocketTransport + +__all__ = [ + "ASGIWebSocketTransport", + "AsyncWebSocketSession", + "WebSocketDisconnect", + "WebSocketException", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketSession", + "WebSocketUpgradeError", +] diff --git a/src/httpx2/httpx2/_websockets/_exceptions.py b/src/httpx2/httpx2/_websockets/_exceptions.py new file mode 100644 index 00000000..a3ed3fa3 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_exceptions.py @@ -0,0 +1,64 @@ +""" +Our exception hierarchy: + +* WebSocketException + x WebSocketUpgradeError + x WebSocketDisconnect + x WebSocketInvalidTypeReceived + x WebSocketNetworkError +""" + +from __future__ import annotations + +import typing + +if typing.TYPE_CHECKING: + from .._models import Response # pragma: no cover + +__all__ = [ + "WebSocketDisconnect", + "WebSocketException", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketUpgradeError", +] + + +class WebSocketException(Exception): + """ + Base class for all WebSocket exceptions. + """ + + +class WebSocketUpgradeError(WebSocketException): + """ + The initial connection didn't correctly upgrade to a WebSocket session. + """ + + def __init__(self, response: Response) -> None: + self.response = response + + +class WebSocketDisconnect(WebSocketException): + """ + The server closed the WebSocket session. + """ + + def __init__(self, code: int = 1000, reason: str | None = None) -> None: + self.code = code + self.reason = reason or "" + + +class WebSocketInvalidTypeReceived(WebSocketException): + """ + A received message was not of the expected type. + """ + + def __init__(self, message: str | bytes) -> None: + self.message = message + + +class WebSocketNetworkError(WebSocketException): + """ + A network error occurred, typically because the underlying stream has closed or timed out. + """ diff --git a/src/httpx2/httpx2/_websockets/_ping.py b/src/httpx2/httpx2/_websockets/_ping.py new file mode 100644 index 00000000..f7460e67 --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_ping.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import secrets +import threading + +import anyio + + +class PingManager: + def __init__(self) -> None: + self._pings: dict[bytes, threading.Event] = {} + + def create(self, ping_id: bytes | None = None) -> tuple[bytes, threading.Event]: + ping_id = secrets.token_bytes() if not ping_id else ping_id + event = threading.Event() + self._pings[ping_id] = event + return ping_id, event + + def ack(self, ping_id: bytes | bytearray | memoryview) -> None: + event = self._pings.pop(bytes(ping_id)) + event.set() + + +class AsyncPingManager: + def __init__(self) -> None: + self._pings: dict[bytes, anyio.Event] = {} + + def create(self, ping_id: bytes | None = None) -> tuple[bytes, anyio.Event]: + ping_id = secrets.token_bytes() if not ping_id else ping_id + event = anyio.Event() + self._pings[ping_id] = event + return ping_id, event + + def ack(self, ping_id: bytes | bytearray | memoryview) -> None: + event = self._pings.pop(bytes(ping_id)) + event.set() diff --git a/src/httpx2/httpx2/_websockets/_session.py b/src/httpx2/httpx2/_websockets/_session.py new file mode 100644 index 00000000..fb1e66fb --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_session.py @@ -0,0 +1,841 @@ +from __future__ import annotations + +import base64 +import concurrent.futures +import contextlib +import json +import queue +import secrets +import threading +import typing +from collections.abc import AsyncGenerator, Generator +from types import TracebackType + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from websockets.exceptions import InvalidState +from websockets.frames import Close, Frame, Opcode +from websockets.protocol import Protocol, Side, State + +from .._models import Headers +from .._urls import URL +from ._exceptions import ( + WebSocketDisconnect, + WebSocketException, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, +) +from ._ping import AsyncPingManager, PingManager +from ._transport import ASGIWebSocketAsyncNetworkStream + +if typing.TYPE_CHECKING: + from httpcore2 import AsyncNetworkStream, NetworkStream + + from .._client import AsyncClient, Client, UseClientDefault + from .._models import Response + from .._types import AuthTypes, CookieTypes, HeaderTypes, QueryParamTypes, RequestExtensions, TimeoutTypes + +JSONMode = typing.Literal["text", "binary"] +TaskResult = typing.TypeVar("TaskResult") + +DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536 +DEFAULT_QUEUE_SIZE = 512 +DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = 20.0 +DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 + +INTERNAL_ERROR = 1011 + + +class ShouldClose(Exception): + pass + + +class EndOfStream(Exception): + pass + + +class MessageAssembler: + """ + Assembles data frames, possibly fragmented, into complete messages. + """ + + def __init__(self) -> None: + self._buffer = bytearray() + self._text = False + + def feed(self, frame: Frame) -> str | bytes | None: + if frame.opcode is Opcode.TEXT or frame.opcode is Opcode.BINARY: + self._buffer = bytearray(frame.data) + self._text = frame.opcode is Opcode.TEXT + else: + self._buffer += frame.data + if not frame.fin: + return None + data = bytes(self._buffer) + self._buffer = bytearray() + return data.decode("utf-8") if self._text else data + + +class WebSocketSession: + """ + Sync context manager representing an opened WebSocket session. + + Attributes: + subprotocol: Optional protocol that has been accepted by the server. + response: The WebSocket handshake response. + """ + + subprotocol: str | None + response: Response | None + + def __init__( + self, + stream: NetworkStream, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + response: Response | None = None, + ) -> None: + self.stream = stream + self.protocol = Protocol(Side.CLIENT, state=State.OPEN, max_size=None) + self.response = response + if self.response is not None: + self.subprotocol = self.response.headers.get("sec-websocket-protocol") + else: + self.subprotocol = None + + self._events: queue.Queue[str | bytes | WebSocketException] = queue.Queue(queue_size) + self._assembler = MessageAssembler() + + self._ping_manager = PingManager() + self._should_close = threading.Event() + self._write_lock = threading.Lock() + self._should_close_task: concurrent.futures.Future[bool] | None = None + self._executor: concurrent.futures.ThreadPoolExecutor | None = None + + self._max_message_size_bytes = max_message_size_bytes + self._queue_size = queue_size + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + + def _get_executor_should_close_task( + self, + ) -> tuple[concurrent.futures.ThreadPoolExecutor, concurrent.futures.Future[bool]]: + if self._should_close_task is None: + self._executor = concurrent.futures.ThreadPoolExecutor() + self._should_close_task = self._executor.submit(self._should_close.wait) + assert self._executor is not None + return self._executor, self._should_close_task + + def __enter__(self) -> WebSocketSession: + self._background_receive_task = threading.Thread( + target=self._background_receive, args=(self._max_message_size_bytes,) + ) + self._background_receive_task.start() + + self._background_keepalive_ping_task: threading.Thread | None = None + if self._keepalive_ping_interval_seconds is not None: + self._background_keepalive_ping_task = threading.Thread( + target=self._background_keepalive_ping, + args=( + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, + ), + ) + self._background_keepalive_ping_task.start() + + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close() + self._background_receive_task.join() + if self._background_keepalive_ping_task is not None: + self._background_keepalive_ping_task.join() + + def ping(self, payload: bytes = b"") -> threading.Event: + """ + Send a Ping message. + + The payload is used internally to track this specific event. + If left empty, a random one will be generated. + + Returns an event that can be used to wait for the corresponding Pong response: + + ```python + pong_callback = ws.ping() + pong_callback.wait() + ``` + """ + ping_id, callback = self._ping_manager.create(payload) + self._send(self.protocol.send_ping, ping_id) + return callback + + def send_text(self, data: str) -> None: + """ + Send a text message. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + self._send(self.protocol.send_text, data.encode("utf-8")) + + def send_bytes(self, data: bytes) -> None: + """ + Send a bytes message. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + self._send(self.protocol.send_binary, data) + + def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + """ + Send JSON data, serialized with `json.dumps()`, in `'text'` or `'binary'` mode. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + assert mode in ["text", "binary"] + serialized_data = json.dumps(data) + if mode == "text": + self.send_text(serialized_data) + else: + self.send_bytes(serialized_data.encode("utf-8")) + + def receive(self, timeout: float | None = None) -> str | bytes: + """ + Receive a message from the server, either text or bytes. + + If `timeout` is `None`, this blocks until a message is available. + + Raises: + TimeoutError: No message was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + """ + try: + event = self._events.get(block=True, timeout=timeout) + except queue.Empty as e: + raise TimeoutError from e + if isinstance(event, WebSocketException): + raise event + return event + + def receive_text(self, timeout: float | None = None) -> str: + """ + Receive text from the server. + + If `timeout` is `None`, this blocks until a message is available. + + Raises: + TimeoutError: No message was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received message was not a text message. + """ + message = self.receive(timeout) + if isinstance(message, str): + return message + raise WebSocketInvalidTypeReceived(message) + + def receive_bytes(self, timeout: float | None = None) -> bytes: + """ + Receive bytes from the server. + + If `timeout` is `None`, this blocks until a message is available. + + Raises: + TimeoutError: No message was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received message was not a bytes message. + """ + message = self.receive(timeout) + if isinstance(message, bytes): + return message + raise WebSocketInvalidTypeReceived(message) + + def receive_json(self, timeout: float | None = None, mode: JSONMode = "text") -> typing.Any: + """ + Receive JSON data from the server, parsed with `json.loads()`, in `'text'` or `'binary'` mode. + + If `timeout` is `None`, this blocks until a message is available. + + Raises: + TimeoutError: No message was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received message didn't correspond to the specified mode. + """ + assert mode in ["text", "binary"] + data: str | bytes + if mode == "text": + data = self.receive_text(timeout) + else: + data = self.receive_bytes(timeout) + return json.loads(data) + + def close(self, code: int = 1000, reason: str | None = None) -> None: + """ + Close the WebSocket session. + + Internally, it'll send a Close frame. + + *This method is automatically called when exiting the context manager.* + """ + import httpcore2 + + self._should_close.set() + if self._executor is not None: + self._executor.shutdown(False) + try: + with self._write_lock: + if self.protocol.state is State.OPEN: + self.protocol.send_close(code, reason or "") + self._write_protocol_data() + except (httpcore2.WriteError, InvalidState): + pass + self.stream.close() + + def _send(self, send_event: typing.Callable[[bytes], None], data: bytes) -> None: + import httpcore2 + + try: + with self._write_lock: + send_event(data) + self._write_protocol_data() + except httpcore2.WriteError as e: + self.close(INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e + + def _write_protocol_data(self) -> None: + for data in self.protocol.data_to_send(): + if data: + self.stream.write(data) + + def _background_receive(self, max_bytes: int) -> None: + """ + Background thread listening for data from the server. + + Internally, it'll: + + * Answer to Ping frames. + * Acknowledge Pong frames. + * Put messages in the `_events` queue that'll eventually be consumed by the user. + """ + import httpcore2 + + try: + while not self._should_close.is_set(): + data = self._wait_until_closed(self._read_stream, max_bytes) + # The protocol is not thread-safe: keep every interaction with it + # under the write lock, so it can't race user sends and closes. + with self._write_lock: + self.protocol.receive_data(data) + frames = self.protocol.events_received() + try: + self._write_protocol_data() + except httpcore2.WriteError: + # Tolerate failing to reply once the peer started the closing handshake. + if self.protocol.state is State.OPEN: + raise + for frame in frames: + assert isinstance(frame, Frame) + if frame.opcode is Opcode.PING: + continue + if frame.opcode is Opcode.PONG: + self._ping_manager.ack(frame.data) + continue + if frame.opcode is Opcode.CLOSE: + self._should_close.set() + close = Close.parse(frame.data) + self._events.put(WebSocketDisconnect(close.code, close.reason)) + continue + message = self._assembler.feed(frame) + if message is not None: + self._events.put(message) + except (httpcore2.ReadError, httpcore2.WriteError, EndOfStream): + self.close(INTERNAL_ERROR, "Stream error") + self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass + + def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: float | None = None) -> None: + try: + while not self._should_close.is_set(): + should_close = self._wait_until_closed(self._should_close.wait, interval_seconds) + if should_close: # pragma: no cover + raise ShouldClose() + + try: + pong_callback = self.ping() + # Connection is closing, exit the task + except InvalidState: + return + + if timeout_seconds is not None: + acknowledged = self._wait_until_closed(pong_callback.wait, timeout_seconds) + if not acknowledged: + self.close(INTERNAL_ERROR, "Keepalive ping timeout") + self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass + + def _wait_until_closed( + self, callable: typing.Callable[..., TaskResult], *args: typing.Any, **kwargs: typing.Any + ) -> TaskResult: + try: + executor, should_close_task = self._get_executor_should_close_task() + todo_task = executor.submit(callable, *args, **kwargs) + except RuntimeError as e: + raise ShouldClose() from e + else: + done, _ = concurrent.futures.wait( + (todo_task, should_close_task), # type: ignore[misc] + return_when=concurrent.futures.FIRST_COMPLETED, + ) + if should_close_task in done: + raise ShouldClose() + assert todo_task in done + result = todo_task.result() + return result + + def _read_stream(self, max_bytes: int) -> bytes: + data = self.stream.read(max_bytes) + if data == b"": + raise EndOfStream() + return data + + +class AsyncWebSocketSession(anyio.AsyncContextManagerMixin): + """ + Async context manager representing an opened WebSocket session. + + Internally, this session uses an anyio task group to manage background tasks. + As a result, exceptions that are not caught inside the context manager + and propagate out of the `async with` block will be wrapped in an `ExceptionGroup`. + + To handle them, use the `except*` syntax: + + ```python + async with AsyncWebSocketSession(stream) as ws: + try: + data = await ws.receive_text() + except WebSocketDisconnect: + # Caught inside the context manager: plain exception. + print("Connection closed") + + # If not caught inside: + try: + async with AsyncWebSocketSession(stream) as ws: + data = await ws.receive_text() + except* WebSocketDisconnect: + # Propagated out of the context manager: wrapped in ExceptionGroup. + print("Connection closed") + ``` + + Attributes: + subprotocol: Optional protocol that has been accepted by the server. + response: The WebSocket handshake response. + """ + + subprotocol: str | None + response: Response | None + _send_event: MemoryObjectSendStream[str | bytes | WebSocketException] + _receive_event: MemoryObjectReceiveStream[str | bytes | WebSocketException] + + def __init__( + self, + stream: AsyncNetworkStream, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: float | None = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: float | None = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + response: Response | None = None, + ) -> None: + self.stream = stream + self.protocol = Protocol(Side.CLIENT, state=State.OPEN, max_size=None) + self.response = response + if self.response is not None: + self.subprotocol = self.response.headers.get("sec-websocket-protocol") + else: + self.subprotocol = None + + self._ping_manager = AsyncPingManager() + self._should_close = anyio.Event() + self._write_lock = anyio.Lock() + self._assembler = MessageAssembler() + + self._max_message_size_bytes = max_message_size_bytes + self._queue_size = queue_size + + # Always disable keepalive ping when emulating ASGI + if isinstance(stream, ASGIWebSocketAsyncNetworkStream): + self._keepalive_ping_interval_seconds = None + self._keepalive_ping_timeout_seconds = None + else: + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + + @contextlib.asynccontextmanager + async def __asynccontextmanager__(self) -> AsyncGenerator[AsyncWebSocketSession]: + self._send_event, self._receive_event = anyio.create_memory_object_stream[str | bytes | WebSocketException]() + self._background_task_group = anyio.create_task_group() + + async with self._send_event, self._receive_event, self._background_task_group: + self._background_task_group.start_soon(self._background_receive, self._max_message_size_bytes) + if self._keepalive_ping_interval_seconds is not None: + self._background_task_group.start_soon( + self._background_keepalive_ping, + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, + ) + + try: + yield self + finally: + self._background_task_group.cancel_scope.cancel() + with anyio.CancelScope(shield=True): + await self.close() + + async def ping(self, payload: bytes = b"") -> anyio.Event: + """ + Send a Ping message. + + The payload is used internally to track this specific event. + If left empty, a random one will be generated. + + Returns an event that can be used to wait for the corresponding Pong response: + + ```python + pong_callback = await ws.ping() + await pong_callback.wait() + ``` + """ + ping_id, callback = self._ping_manager.create(payload) + await self._send(self.protocol.send_ping, ping_id) + return callback + + async def send_text(self, data: str) -> None: + """ + Send a text message. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + await self._send(self.protocol.send_text, data.encode("utf-8")) + + async def send_bytes(self, data: bytes) -> None: + """ + Send a bytes message. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + await self._send(self.protocol.send_binary, data) + + async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + """ + Send JSON data, serialized with `json.dumps()`, in `'text'` or `'binary'` mode. + + Raises `WebSocketNetworkError` if a network error occurred. + """ + assert mode in ["text", "binary"] + serialized_data = json.dumps(data) + if mode == "text": + await self.send_text(serialized_data) + else: + await self.send_bytes(serialized_data.encode("utf-8")) + + async def receive(self, timeout: float | None = None) -> str | bytes: + """ + Receive a message from the server, either text or bytes. + + If `timeout` is `None`, this blocks until a message is available. + + Raises: + TimeoutError: No message was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + """ + with anyio.fail_after(timeout): + event = await self._receive_event.receive() + if isinstance(event, WebSocketException): + raise event + return event + + async def receive_text(self, timeout: float | None = None) -> str: + """ + Receive text from the server. + + If `timeout` is `None`, this blocks until a message is available. + + Raises: + TimeoutError: No message was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received message was not a text message. + """ + message = await self.receive(timeout) + if isinstance(message, str): + return message + raise WebSocketInvalidTypeReceived(message) + + async def receive_bytes(self, timeout: float | None = None) -> bytes: + """ + Receive bytes from the server. + + If `timeout` is `None`, this blocks until a message is available. + + Raises: + TimeoutError: No message was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received message was not a bytes message. + """ + message = await self.receive(timeout) + if isinstance(message, bytes): + return message + raise WebSocketInvalidTypeReceived(message) + + async def receive_json(self, timeout: float | None = None, mode: JSONMode = "text") -> typing.Any: + """ + Receive JSON data from the server, parsed with `json.loads()`, in `'text'` or `'binary'` mode. + + If `timeout` is `None`, this blocks until a message is available. + + Raises: + TimeoutError: No message was received before the timeout delay. + WebSocketDisconnect: The server closed the WebSocket. + WebSocketNetworkError: A network error occurred. + WebSocketInvalidTypeReceived: The received message didn't correspond to the specified mode. + """ + assert mode in ["text", "binary"] + data: str | bytes + if mode == "text": + data = await self.receive_text(timeout) + else: + data = await self.receive_bytes(timeout) + return json.loads(data) + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + """ + Close the WebSocket session. + + Internally, it'll send a Close frame. + + *This method is automatically called when exiting the context manager.* + """ + import httpcore2 + + self._should_close.set() + try: + async with self._write_lock: + if self.protocol.state is State.OPEN: + self.protocol.send_close(code, reason or "") + await self._write_protocol_data() + except (httpcore2.WriteError, InvalidState): + pass + await self.stream.aclose() + + async def _send(self, send_event: typing.Callable[[bytes], None], data: bytes) -> None: + import httpcore2 + + try: + async with self._write_lock: + send_event(data) + await self._write_protocol_data() + except httpcore2.WriteError as e: + await self.close(INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e + + async def _write_protocol_data(self) -> None: + for data in self.protocol.data_to_send(): + if data: + await self.stream.write(data) + + async def _background_receive(self, max_bytes: int) -> None: + """ + Background task listening for data from the server. + + Internally, it'll: + + * Answer to Ping frames. + * Acknowledge Pong frames. + * Put messages in the `_events` queue that'll eventually be consumed by the user. + """ + import httpcore2 + + try: + while not self._should_close.is_set(): + data = await self._read_stream(max_bytes) + async with self._write_lock: + self.protocol.receive_data(data) + frames = self.protocol.events_received() + try: + await self._write_protocol_data() + except httpcore2.WriteError: + # Tolerate failing to reply once the peer started the closing handshake. + if self.protocol.state is State.OPEN: + raise + for frame in frames: + assert isinstance(frame, Frame) + if frame.opcode is Opcode.PING: + continue + if frame.opcode is Opcode.PONG: + self._ping_manager.ack(frame.data) + continue + if frame.opcode is Opcode.CLOSE: + self._should_close.set() + close = Close.parse(frame.data) + await self._send_event.send(WebSocketDisconnect(close.code, close.reason)) + continue + message = self._assembler.feed(frame) + if message is not None: + await self._send_event.send(message) + except (httpcore2.ReadError, httpcore2.WriteError, EndOfStream): + await self.close(INTERNAL_ERROR, "Stream error") + await self._send_event.send(WebSocketNetworkError()) + + async def _background_keepalive_ping(self, interval_seconds: float, timeout_seconds: float | None = None) -> None: + while not self._should_close.is_set(): + await anyio.sleep(interval_seconds) + + try: + pong_callback = await self.ping() + # Connection is closing, exit the task + except InvalidState: + return + + if timeout_seconds is not None: + try: + with anyio.fail_after(timeout_seconds): + await pong_callback.wait() + except TimeoutError: + await self.close(INTERNAL_ERROR, "Keepalive ping timeout") + await self._send_event.send(WebSocketNetworkError()) + + async def _read_stream(self, max_bytes: int) -> bytes: + data = await self.stream.read(max_bytes) + if data == b"": + raise EndOfStream() + return data + + +def _get_headers(subprotocols: list[str] | None) -> dict[str, str]: + headers = { + "connection": "upgrade", + "upgrade": "websocket", + "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), + "sec-websocket-version": "13", + } + if subprotocols is not None: + headers["sec-websocket-protocol"] = ", ".join(subprotocols) + return headers + + +def _get_url(url: URL | str) -> URL: + url = URL(url) + if url.scheme == "ws": + return url.copy_with(scheme="http") + if url.scheme == "wss": + return url.copy_with(scheme="https") + return url + + +@contextlib.contextmanager +def connect_ws( + client: Client, + url: URL | str, + *, + params: QueryParamTypes | None, + headers: HeaderTypes | None, + cookies: CookieTypes | None, + auth: AuthTypes | UseClientDefault | None, + follow_redirects: bool | UseClientDefault, + timeout: TimeoutTypes | UseClientDefault, + extensions: RequestExtensions | None, + subprotocols: list[str] | None, + max_message_size_bytes: int, + queue_size: int, + keepalive_ping_interval_seconds: float | None, + keepalive_ping_timeout_seconds: float | None, +) -> Generator[WebSocketSession]: + merged_headers = Headers(headers) + merged_headers.update(_get_headers(subprotocols)) + + with client.stream( + "GET", + _get_url(url), + params=params, + headers=merged_headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + session = WebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + response=response, + ) + with session: + yield session + + +@contextlib.asynccontextmanager +async def aconnect_ws( + client: AsyncClient, + url: URL | str, + *, + params: QueryParamTypes | None, + headers: HeaderTypes | None, + cookies: CookieTypes | None, + auth: AuthTypes | UseClientDefault | None, + follow_redirects: bool | UseClientDefault, + timeout: TimeoutTypes | UseClientDefault, + extensions: RequestExtensions | None, + subprotocols: list[str] | None, + max_message_size_bytes: int, + queue_size: int, + keepalive_ping_interval_seconds: float | None, + keepalive_ping_timeout_seconds: float | None, +) -> AsyncGenerator[AsyncWebSocketSession]: + merged_headers = Headers(headers) + merged_headers.update(_get_headers(subprotocols)) + + async with client.stream( + "GET", + _get_url(url), + params=params, + headers=merged_headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + session = AsyncWebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + response=response, + ) + async with session: + yield session diff --git a/src/httpx2/httpx2/_websockets/_transport.py b/src/httpx2/httpx2/_websockets/_transport.py new file mode 100644 index 00000000..ba6dd39c --- /dev/null +++ b/src/httpx2/httpx2/_websockets/_transport.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +import contextlib +import math +import typing +from types import TracebackType + +import anyio +import anyio.abc +import anyio.streams.stapled +from websockets.frames import Close, Frame, Opcode +from websockets.protocol import Protocol, Side, State +from websockets.utils import accept_key + +from .._models import Request, Response +from .._transports.asgi import ASGITransport +from .._types import AsyncByteStream +from ._exceptions import WebSocketDisconnect, WebSocketUpgradeError + +Scope = typing.MutableMapping[str, typing.Any] +Message = typing.MutableMapping[str, typing.Any] +Receive = typing.Callable[[], typing.Awaitable[Message]] +Send = typing.Callable[[Message], typing.Awaitable[None]] +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] + +INTERNAL_ERROR = 1011 + + +class ASGIWebSocketTransportError(Exception): + pass + + +class UnhandledASGIMessageType(ASGIWebSocketTransportError): + def __init__(self, message: Message) -> None: + self.message = message + + +class UnhandledWebSocketFrame(ASGIWebSocketTransportError): + def __init__(self, frame: Frame) -> None: + self.frame = frame + + +class ASGIWebSocketAsyncNetworkStream: + """ + An `httpcore2.AsyncNetworkStream` lookalike that translates reads and writes + into ASGI messages exchanged with the wrapped app. + """ + + def __init__( + self, + app: ASGIApp, + scope: Scope, + task_group: anyio.abc.TaskGroup, + initial_receive_timeout: float = 1.0, + ) -> None: + self.app = app + self.scope = scope + self._receive_queue = anyio.streams.stapled.StapledObjectStream( + *anyio.create_memory_object_stream[Message](max_buffer_size=math.inf) + ) + self._send_queue = anyio.streams.stapled.StapledObjectStream( + *anyio.create_memory_object_stream[Message](max_buffer_size=math.inf) + ) + self._task_group = task_group + self._initial_receive_timeout = initial_receive_timeout + self.protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + headers = {key.lower(): value for key, value in scope["headers"]} + self._websocket_key: bytes = headers[b"sec-websocket-key"] + self._aentered = False + + async def __aenter__(self) -> tuple[ASGIWebSocketAsyncNetworkStream, bytes]: + if self._aentered: + raise RuntimeError("Cannot use ASGIWebSocketAsyncNetworkStream in a context manager twice") + self._aentered = True + self._task_group.start_soon(self._run) + async with contextlib.AsyncExitStack() as stack: + stack.push_async_callback(self.aclose) + + await self.send({"type": "websocket.connect"}) + + try: + message = await self.receive(self._initial_receive_timeout) + except TimeoutError as e: + raise RuntimeError( + "WebSocket didn't accept the connection in time. Did you forget to call accept()?" + ) from e + + if message["type"] == "websocket.close": + await stack.aclose() + raise WebSocketDisconnect(message["code"], message.get("reason")) + + # Websocket Denial Response extension + # Ref: https://asgi.readthedocs.io/en/latest/extensions.html#websocket-denial-response + if message["type"] == "websocket.http.response.start": + status_code: int = message["status"] + headers: list[tuple[bytes, bytes]] = message["headers"] + body: list[bytes] = [] + while True: + message = await self.receive() + assert message["type"] == "websocket.http.response.body" + body.append(message["body"]) + if not message.get("more_body", False): + break + + await stack.aclose() + raise WebSocketUpgradeError(Response(status_code, headers=headers, content=b"".join(body))) + + assert message["type"] == "websocket.accept" + retval = self, self._build_accept_response(message) + self._exit_stack = stack.pop_all() + return retval + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + message = await self.receive(timeout=timeout) + message_type = message["type"] + + if message_type not in {"websocket.send", "websocket.close"}: + raise UnhandledASGIMessageType(message) + + if message_type == "websocket.send": + data_str: str | None = message.get("text") + if data_str is not None: + self.protocol.send_text(data_str.encode("utf-8")) + data_bytes: bytes | None = message.get("bytes") + if data_bytes is not None: + self.protocol.send_binary(data_bytes) + else: + self.protocol.send_close(message["code"], message.get("reason") or "") + + return b"".join(data for data in self.protocol.data_to_send() if data) + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.protocol.receive_data(buffer) + for frame in self.protocol.events_received(): + assert isinstance(frame, Frame) + if frame.opcode is Opcode.CLOSE: + close = Close.parse(frame.data) + await self.send( + { + "type": "websocket.disconnect", + "code": close.code, + "reason": close.reason, + } + ) + elif frame.opcode is Opcode.TEXT: + await self.send({"type": "websocket.receive", "text": bytes(frame.data).decode("utf-8")}) + elif frame.opcode is Opcode.BINARY: + await self.send({"type": "websocket.receive", "bytes": bytes(frame.data)}) + else: + raise UnhandledWebSocketFrame(frame) + + async def aclose(self) -> None: + with contextlib.suppress(anyio.ClosedResourceError): + await self.send({"type": "websocket.disconnect"}) + await self._receive_queue.aclose() + await self._send_queue.aclose() + + async def send(self, message: Message) -> None: + await self._receive_queue.send(message) + + async def receive(self, timeout: float | None = None) -> Message: + if timeout is None: + timeout = math.inf + with anyio.fail_after(timeout): + return await self._send_queue.receive() + + async def _run(self) -> None: + """ + The task in which the websocket session runs. + """ + scope = self.scope + receive = self._receive_queue.receive + send = self._send_queue.send + try: + await self.app(scope, receive, send) + except Exception as e: + message: Message = { + "type": "websocket.close", + "code": INTERNAL_ERROR, + "reason": str(e), + } + with contextlib.suppress(anyio.ClosedResourceError): + await send(message) + + def _build_accept_response(self, message: Message) -> bytes: + subprotocol: str | None = message.get("subprotocol", None) + headers: list[tuple[bytes, bytes]] = message.get("headers", []) + response_headers = [ + (b"Upgrade", b"websocket"), + (b"Connection", b"Upgrade"), + (b"Sec-WebSocket-Accept", accept_key(self._websocket_key.decode("utf-8")).encode("utf-8")), + ] + if subprotocol is not None: + response_headers.append((b"Sec-WebSocket-Protocol", subprotocol.encode("utf-8"))) + response_headers.extend(headers) + return b"".join( + [ + b"HTTP/1.1 101 Switching Protocols\r\n", + b"".join(key + b": " + value + b"\r\n" for key, value in response_headers), + b"\r\n", + ] + ) + + +class ASGIWebSocketTransport(ASGITransport): + """ + A custom `ASGITransport` that handles WebSocket upgrade requests + by emulating the WebSocket protocol against the ASGI app. + + Plain HTTP requests are handled as usual by `ASGITransport`. + + ```python + transport = httpx2.ASGIWebSocketTransport(app=app) + client = httpx2.AsyncClient(transport=transport) + ``` + """ + + scope: Scope + + def __init__( + self, + app: ASGIApp, + raise_app_exceptions: bool = True, + root_path: str = "", + client: tuple[str, int] = ("127.0.0.1", 123), + initial_receive_timeout: float = 1.0, + ) -> None: + super().__init__(app, raise_app_exceptions, root_path, client) + self._exit_stack: contextlib.AsyncExitStack | None = None + self._initial_receive_timeout = initial_receive_timeout + + async def __aenter__(self) -> ASGIWebSocketTransport: + async with contextlib.AsyncExitStack() as stack: + self._task_group = await stack.enter_async_context(anyio.create_task_group()) + self._exit_stack = stack.pop_all() + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_val: BaseException | None = None, + exc_tb: TracebackType | None = None, + ) -> None: + await super().__aexit__(exc_type, exc_val, exc_tb) + assert self._exit_stack is not None + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + + async def handle_async_request(self, request: Request) -> Response: + scheme = request.url.scheme + headers = request.headers + + if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": + subprotocols: list[str] = [] + if (subprotocols_header := headers.get("sec-websocket-protocol")) is not None: + subprotocols = subprotocols_header.split(",") + + scope: Scope = { + "type": "websocket", + "path": request.url.path, + "raw_path": request.url.raw_path, + "root_path": self.root_path, + "scheme": {"http": "ws", "https": "wss"}.get(scheme, scheme), + "query_string": request.url.query, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "client": self.client, + "server": (request.url.host, request.url.port), + "subprotocols": subprotocols, + } + return await self._handle_ws_request(request, scope) + + return await super().handle_async_request(request) + + async def _create_asgi_websocket_async_network_stream( + self, + *, + task_status: anyio.abc.TaskStatus[tuple[ASGIWebSocketAsyncNetworkStream, bytes]], + ) -> None: + stream = ASGIWebSocketAsyncNetworkStream( + self.app, + self.scope, + self._task_group, + self._initial_receive_timeout, + ) + assert self._exit_stack is not None + result = await self._exit_stack.enter_async_context(stream) + task_status.started(result) + + async def _handle_ws_request(self, request: Request, scope: Scope) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + self.scope = scope + stream, accept_response = await self._task_group.start(self._create_asgi_websocket_async_network_stream) + accept_response_lines = accept_response.decode("utf-8").splitlines() + headers = [ + typing.cast(tuple[str, str], line.split(": ", 1)) for line in accept_response_lines[1:] if line.strip() + ] + + return Response( + status_code=101, + headers=headers, + extensions={"network_stream": stream}, + ) diff --git a/src/httpx2/pyproject.toml b/src/httpx2/pyproject.toml index dc194f7f..e2a38c63 100644 --- a/src/httpx2/pyproject.toml +++ b/src/httpx2/pyproject.toml @@ -46,9 +46,10 @@ dynamic = ["readme", "version", "dependencies"] dependencies = [ "truststore>=0.10", "httpcore2=={{ version }}", - "anyio", + "anyio>=4.10", "idna>=3.18", "typing_extensions>=4.5.0; python_version < '3.13'", + "websockets>=15", ] [project.optional-dependencies] diff --git a/tests/httpx2/conftest.py b/tests/httpx2/conftest.py index 156c5e3a..363193ce 100644 --- a/tests/httpx2/conftest.py +++ b/tests/httpx2/conftest.py @@ -274,6 +274,8 @@ def serve_in_thread(server: TestServer) -> typing.Iterator[TestServer]: @pytest.fixture(scope="session") def server(free_tcp_port_factory: typing.Callable[[], int]) -> typing.Iterator[TestServer]: - config = Config(app=app, lifespan="off", loop="asyncio", port=free_tcp_port_factory()) + # `ws="auto"` would pick uvicorn's implementation based on the deprecated `websockets.legacy`, + # which warns on import, and warnings are errors here. This server only handles plain HTTP. + config = Config(app=app, lifespan="off", loop="asyncio", port=free_tcp_port_factory(), ws="none") server = TestServer(config=config) yield from serve_in_thread(server) diff --git a/tests/httpx2/websockets/__init__.py b/tests/httpx2/websockets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/httpx2/websockets/conftest.py b/tests/httpx2/websockets/conftest.py new file mode 100644 index 00000000..732bf1f7 --- /dev/null +++ b/tests/httpx2/websockets/conftest.py @@ -0,0 +1,67 @@ +import contextlib +import pathlib +import queue +import tempfile +import time +import typing +from unittest.mock import MagicMock + +import pytest +import uvicorn +from anyio.from_thread import start_blocking_portal +from starlette.applications import Starlette +from starlette.routing import WebSocketRoute +from starlette.websockets import WebSocket + +WebSocketEndpoint = typing.Callable[[WebSocket], typing.Awaitable[None]] + + +@pytest.fixture +def on_receive_message() -> MagicMock: + return MagicMock() + + +@pytest.fixture(params=("wsproto", "websockets-sansio")) +def websocket_implementation(request: pytest.FixtureRequest) -> typing.Literal["wsproto", "websockets-sansio"]: + return request.param # type: ignore[no-any-return] + + +class ServerFactoryFixture(typing.Protocol): + def __call__(self, endpoint: WebSocketEndpoint) -> contextlib.AbstractContextManager[str]: ... + + +@pytest.fixture +def server_factory(websocket_implementation: typing.Literal["wsproto", "websockets-sansio"]) -> ServerFactoryFixture: + @contextlib.contextmanager + def _server_factory(endpoint: WebSocketEndpoint) -> typing.Iterator[str]: + shutdown_queue: queue.Queue[bool] = queue.Queue() + + def create_app() -> Starlette: + routes = [ + WebSocketRoute("/ws", endpoint=endpoint), + ] + return Starlette(routes=routes) + + def create_server(app: Starlette, socket: str) -> uvicorn.Server: + config = uvicorn.Config(app, uds=socket, ws=websocket_implementation, lifespan="off") + return uvicorn.Server(config) + + def on_server_stopped(_task: object) -> None: + shutdown_queue.put(True) + + with start_blocking_portal(backend="asyncio") as portal: + with tempfile.TemporaryDirectory() as socket_directory: + socket = str(pathlib.Path(socket_directory) / "socket.sock") + app = create_app() + server = create_server(app, socket) + task = portal.start_task_soon(server.serve) + task.add_done_callback(on_server_stopped) + while not server.started and not task.done(): + time.sleep(0.01) + if task.done() and task.exception() is not None: # pragma: no cover + raise typing.cast(BaseException, task.exception()) + yield socket + server.should_exit = True + shutdown_queue.get(True) + + return _server_factory diff --git a/tests/httpx2/websockets/test_session.py b/tests/httpx2/websockets/test_session.py new file mode 100644 index 00000000..dbe200d5 --- /dev/null +++ b/tests/httpx2/websockets/test_session.py @@ -0,0 +1,928 @@ +import concurrent.futures +import queue +import threading +import time +from unittest.mock import MagicMock, call, patch + +import anyio +import pytest +from starlette.websockets import WebSocket, WebSocketDisconnect as StarletteWebSocketDisconnect +from websockets.frames import Frame, Opcode +from websockets.protocol import Protocol, Side, State + +import httpcore2 +import httpx2 +from httpcore2 import AsyncNetworkStream, NetworkStream +from httpx2 import ( + AsyncWebSocketSession, + WebSocketDisconnect, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketSession, + WebSocketUpgradeError, + _api, +) +from httpx2._websockets._session import JSONMode +from tests.httpx2.websockets.conftest import ServerFactoryFixture + + +def wire(protocol: Protocol) -> bytes: + return b"".join(data for data in protocol.data_to_send() if data) + + +def mock_network_stream(spec: type) -> MagicMock: + stream = MagicMock(spec=spec) + stream.read.return_value = b"" + return stream + + +@pytest.mark.anyio +async def test_upgrade_error() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(400) + + with httpx2.Client(base_url="http://localhost:8000", transport=httpx2.MockTransport(handler)) as client: + with pytest.raises(WebSocketUpgradeError): + with client.websocket("http://socket/ws"): + pass # pragma: no cover + + async with httpx2.AsyncClient(base_url="http://localhost:8000", transport=httpx2.MockTransport(handler)) as aclient: + with pytest.raises(WebSocketUpgradeError): + async with aclient.websocket("http://socket/ws"): + pass # pragma: no cover + + +def test_top_level_websocket() -> None: + with patch.object(_api, "Client") as mock_client_cls: + mock_client = mock_client_cls.return_value.__enter__.return_value + with httpx2.websocket("ws://socket/ws", subprotocols=["custom_protocol"]): + pass + mock_client.websocket.assert_called_once() + assert mock_client.websocket.call_args[1]["subprotocols"] == ["custom_protocol"] + + +@pytest.mark.anyio +class TestSend: + async def test_send_error(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self._should_close = False + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: # pragma: no cover + while not self._should_close: + time.sleep(0.1) + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise httpcore2.WriteError() + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession(stream) as websocket_session: + websocket_session.send_text("CLIENT_MESSAGE") + + async def test_async_send_error(self) -> None: + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self._should_close = False + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: # pragma: no cover + while not self._should_close: + await anyio.sleep(0.1) + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise httpcore2.WriteError() + + async def aclose(self) -> None: + self._should_close = True + + stream = AsyncMockNetworkStream() + with pytest.RaisesGroup(WebSocketNetworkError): + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.send_text("CLIENT_MESSAGE") + + async def test_send_text( + self, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_text() + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ws.send_text("CLIENT_MESSAGE") + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + await aws.send_text("CLIENT_MESSAGE") + + on_receive_message.assert_has_calls([call("CLIENT_MESSAGE"), call("CLIENT_MESSAGE")]) + + async def test_send_bytes( + self, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_bytes() + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ws.send_bytes(b"CLIENT_MESSAGE") + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + await aws.send_bytes(b"CLIENT_MESSAGE") + + on_receive_message.assert_has_calls([call(b"CLIENT_MESSAGE"), call(b"CLIENT_MESSAGE")]) + + @pytest.mark.parametrize("mode", ["text", "binary"]) + async def test_send_json( + self, + mode: JSONMode, + server_factory: ServerFactoryFixture, + on_receive_message: MagicMock, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + message = await websocket.receive_json(mode=mode) + on_receive_message(message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + await aws.send_json({"message": "CLIENT_MESSAGE"}, mode=mode) + + on_receive_message.assert_has_calls([call({"message": "CLIENT_MESSAGE"}), call({"message": "CLIENT_MESSAGE"})]) + + +@pytest.mark.anyio +class TestReceive: + async def test_receive_error(self) -> None: + class MockNetworkStream(NetworkStream): + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession(stream) as websocket_session: + websocket_session.receive() + + def test_receive_closed_socket(self) -> None: + class MockNetworkStream(NetworkStream): + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + return b"" + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession(stream) as websocket_session: + websocket_session.receive() + + def test_receive_timeout(self) -> None: + class MockNetworkStream(NetworkStream): + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + time.sleep(0.2) + return b"" + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with pytest.raises(TimeoutError): + with WebSocketSession(stream) as websocket_session: + websocket_session.receive(timeout=0.1) + + async def test_async_receive_error(self) -> None: + class AsyncMockNetworkStream(AsyncNetworkStream): + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + pass + + stream = AsyncMockNetworkStream() + with pytest.RaisesGroup(WebSocketNetworkError): + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.receive() + + async def test_async_receive_closed_socket(self) -> None: + class AsyncMockNetworkStream(AsyncNetworkStream): + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + return b"" + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + pass + + stream = AsyncMockNetworkStream() + with pytest.RaisesGroup(WebSocketNetworkError): + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.receive() + + async def test_receive(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + message = ws.receive() + assert message == "SERVER_MESSAGE" + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + message = await aws.receive() + assert message == "SERVER_MESSAGE" + + @pytest.mark.parametrize( + "full_message,send_method", + [ + pytest.param(b"A" * 1024 * 4, "send_bytes", id="bytes"), + pytest.param("A" * 1024 * 4, "send_text", id="text"), + ], + ) + async def test_receive_oversized_message( + self, + full_message: str | bytes, + send_method: str, + server_factory: ServerFactoryFixture, + ) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + method = getattr(websocket, send_method) + await method(full_message) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws", max_message_size_bytes=1024) as ws: + message = ws.receive() + assert message == full_message + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws", max_message_size_bytes=1024) as aws: + message = await aws.receive() + assert message == full_message + + async def test_receive_fragmented_message(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + protocol.send_text(b"SERVER", fin=False) + first = wire(protocol) + protocol.send_continuation(b"_MESSAGE", fin=True) + second = wire(protocol) + self.data_to_send = [first, second] + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + return self.data_to_send.pop(0) + except IndexError: + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with WebSocketSession(stream) as websocket_session: + assert websocket_session.receive() == "SERVER_MESSAGE" + + async def test_async_receive_fragmented_message(self) -> None: + class AsyncMockNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + protocol.send_text(b"SERVER", fin=False) + first = wire(protocol) + protocol.send_continuation(b"_MESSAGE", fin=True) + second = wire(protocol) + self.data_to_send = [first, second] + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + return self.data_to_send.pop(0) + except IndexError: + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + pass + + stream = AsyncMockNetworkStream() + async with AsyncWebSocketSession(stream) as websocket_session: + assert await websocket_session.receive() == "SERVER_MESSAGE" + + async def test_receive_text(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + data = ws.receive_text() + assert data == "SERVER_MESSAGE" + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + data = await aws.receive_text() + assert data == "SERVER_MESSAGE" + + async def test_receive_text_invalid_type(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_bytes(b"SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + with pytest.raises(WebSocketInvalidTypeReceived): + ws.receive_text() + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + with pytest.raises(WebSocketInvalidTypeReceived): + await aws.receive_text() + + async def test_receive_bytes(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_bytes(b"SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + data = ws.receive_bytes() + assert data == b"SERVER_MESSAGE" + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + data = await aws.receive_bytes() + assert data == b"SERVER_MESSAGE" + + async def test_receive_bytes_invalid_type(self, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_text("SERVER_MESSAGE") + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + with pytest.raises(WebSocketInvalidTypeReceived): + ws.receive_bytes() + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + with pytest.raises(WebSocketInvalidTypeReceived): + await aws.receive_bytes() + + @pytest.mark.parametrize("mode", ["text", "binary"]) + async def test_receive_json(self, mode: JSONMode, server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + + await websocket.send_json({"message": "SERVER_MESSAGE"}, mode=mode) + + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + data = ws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + data = await aws.receive_json(mode=mode) + assert data == {"message": "SERVER_MESSAGE"} + + +@pytest.mark.anyio +class TestReceivePing: + async def test_receive_ping(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + self.received_frames: list[Frame] = [] + self.protocol.send_ping(b"SERVER_PING") + self.protocol.send_close(1000) + self.data_to_send = [wire(self.protocol)] + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + return self.data_to_send.pop(0) + except IndexError: # pragma: no cover + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.protocol.receive_data(buffer) + self.received_frames.extend(e for e in self.protocol.events_received() if isinstance(e, Frame)) + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with WebSocketSession(stream): + await anyio.sleep(0.1) + + assert [frame.opcode for frame in stream.received_frames] == [Opcode.PONG, Opcode.CLOSE] + assert bytes(stream.received_frames[0].data) == b"SERVER_PING" + + async def test_async_receive_ping(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + self.received_frames: list[Frame] = [] + self.protocol.send_ping(b"SERVER_PING") + self.protocol.send_close(1000) + self.data_to_send = [wire(self.protocol)] + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + return self.data_to_send.pop(0) + except IndexError: # pragma: no cover + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.protocol.receive_data(buffer) + self.received_frames.extend(e for e in self.protocol.events_received() if isinstance(e, Frame)) + + async def aclose(self) -> None: + pass + + stream = MockAsyncNetworkStream() + async with AsyncWebSocketSession(stream): + await anyio.sleep(0.1) + + assert [frame.opcode for frame in stream.received_frames] == [Opcode.PONG, Opcode.CLOSE] + assert bytes(stream.received_frames[0].data) == b"SERVER_PING" + + async def test_receive_ping_reply_write_error(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + protocol.send_ping(b"SERVER_PING") + self.data_to_send = [wire(protocol)] + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + return self.data_to_send.pop(0) + except IndexError: # pragma: no cover + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise httpcore2.WriteError() + + def close(self) -> None: + pass + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession(stream) as websocket_session: + websocket_session.receive() + + async def test_async_receive_ping_reply_write_error(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + protocol.send_ping(b"SERVER_PING") + self.data_to_send = [wire(protocol)] + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + try: + return self.data_to_send.pop(0) + except IndexError: # pragma: no cover + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise httpcore2.WriteError() + + async def aclose(self) -> None: + pass + + stream = MockAsyncNetworkStream() + with pytest.RaisesGroup(WebSocketNetworkError): + async with AsyncWebSocketSession(stream) as websocket_session: + await websocket_session.receive() + + +class NoopNetworkStream(NetworkStream): + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise NotImplementedError + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise NotImplementedError + + def close(self) -> None: + pass + + +class NoopAsyncNetworkStream(AsyncNetworkStream): + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + raise NotImplementedError + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + raise NotImplementedError + + async def aclose(self) -> None: + pass + + +@pytest.mark.anyio +class TestKeepalivePing: + async def test_keepalive_ping_closing_connection(self) -> None: + session = WebSocketSession(NoopNetworkStream()) + session.protocol.receive_eof() + session._background_keepalive_ping(0.01) + session.close() + + async def test_async_keepalive_ping_closing_connection(self) -> None: + session = AsyncWebSocketSession(NoopAsyncNetworkStream()) + session.protocol.receive_eof() + await session._background_keepalive_ping(0.01) + await session.close() + + async def test_keepalive_ping(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self.protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + self._should_close = False + self.ping_received = 0 + self.ping_answered = 0 + self.data_to_send: queue.Queue[bytes] = queue.Queue() + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: + try: + data = self.data_to_send.get_nowait() + self.ping_answered += 1 + return data + except queue.Empty: + pass + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.protocol.receive_data(buffer) + for frame in self.protocol.events_received(): + if isinstance(frame, Frame) and frame.opcode is Opcode.PING: + self.ping_received += 1 + self.data_to_send.put(wire(self.protocol)) + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + with WebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ): + await anyio.sleep(0.2) + + assert stream.ping_received >= 1 + assert stream.ping_answered >= 1 + + async def test_keepalive_ping_timeout(self) -> None: + class MockNetworkStream(NetworkStream): + def __init__(self) -> None: + self._should_close = False + + def read(self, max_bytes: int, timeout: float | None = None) -> bytes: # pragma: no cover + while not self._should_close: + time.sleep(0.1) + raise httpcore2.ReadError() + + def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + def close(self) -> None: + self._should_close = True + + stream = MockNetworkStream() + with pytest.raises(WebSocketNetworkError): + with WebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ) as websocket_session: + websocket_session.receive() + + async def test_async_keepalive_ping(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self.protocol = Protocol(Side.SERVER, state=State.OPEN, max_size=None) + self._should_close = False + self.ping_received = 0 + self.ping_answered = 0 + ( + self.send_data, + self.receive_data, + ) = anyio.create_memory_object_stream[bytes]() + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: + while not self._should_close: + try: + data = self.receive_data.receive_nowait() + self.ping_answered += 1 + return data + except anyio.WouldBlock: + await anyio.sleep(0.1) + raise httpcore2.ReadError() # pragma: no cover + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + self.protocol.receive_data(buffer) + for frame in self.protocol.events_received(): + if isinstance(frame, Frame) and frame.opcode is Opcode.PING: + self.ping_received += 1 + await self.send_data.send(wire(self.protocol)) + + async def aclose(self) -> None: + self._should_close = True + self.send_data.close() + self.receive_data.close() + + stream = MockAsyncNetworkStream() + async with AsyncWebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ): + await anyio.sleep(0.3) + + assert stream.ping_received >= 1 + assert stream.ping_answered >= 1 + + async def test_async_keepalive_ping_timeout(self) -> None: + class MockAsyncNetworkStream(AsyncNetworkStream): + def __init__(self) -> None: + self._should_close = False + + async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: # pragma: no cover + while not self._should_close: + await anyio.sleep(0.1) + raise httpcore2.ReadError() + + async def write(self, buffer: bytes, timeout: float | None = None) -> None: + pass + + async def aclose(self) -> None: + self._should_close = True + + stream = MockAsyncNetworkStream() + with pytest.RaisesGroup(WebSocketNetworkError): + async with AsyncWebSocketSession( + stream, + keepalive_ping_interval_seconds=0.1, + keepalive_ping_timeout_seconds=0.1, + ) as websocket_session: + await websocket_session.receive() + + +@pytest.mark.anyio +async def test_ping_pong(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + try: + await websocket.receive_text() + except StarletteWebSocketDisconnect: + pass + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ping_callback = ws.ping() + result = ping_callback.wait() + assert result is True + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + aping_callback = await aws.ping() + await aping_callback.wait() + assert aping_callback.is_set() + + +@pytest.mark.anyio +async def test_send_close(server_factory: ServerFactoryFixture, on_receive_message: MagicMock) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + try: + await websocket.receive_text() + except StarletteWebSocketDisconnect as e: + on_receive_message(e.code, e.reason) + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + ws.close(code=1001, reason="CLOSE_REASON") + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + await aws.close(code=1001, reason="CLOSE_REASON") + + on_receive_message.assert_has_calls([call(1001, "CLOSE_REASON"), call(1001, "CLOSE_REASON")]) + + +@pytest.mark.anyio +async def test_receive_close(server_factory: ServerFactoryFixture) -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + with pytest.raises(WebSocketDisconnect): + ws.receive() + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + with pytest.raises(WebSocketDisconnect): + await aws.receive() + + +@pytest.mark.anyio +async def test_subprotocol_and_response() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + assert request.headers["sec-websocket-protocol"] == "custom_protocol, unsupported_protocol" + + return httpx2.Response( + 101, + headers={"sec-websocket-protocol": "custom_protocol"}, + extensions={"network_stream": mock_network_stream(NetworkStream)}, + ) + + def async_handler(request: httpx2.Request) -> httpx2.Response: + assert request.headers["sec-websocket-protocol"] == "custom_protocol, unsupported_protocol" + + return httpx2.Response( + 101, + headers={"sec-websocket-protocol": "custom_protocol"}, + extensions={"network_stream": mock_network_stream(AsyncNetworkStream)}, + ) + + with httpx2.Client(base_url="http://localhost:8000", transport=httpx2.MockTransport(handler)) as client: + with client.websocket( + "http://socket/ws", + subprotocols=["custom_protocol", "unsupported_protocol"], + ) as ws: + assert isinstance(ws.response, httpx2.Response) + assert ws.subprotocol == "custom_protocol" + assert ws.response.headers["sec-websocket-protocol"] == ws.subprotocol + + async with httpx2.AsyncClient( + base_url="http://localhost:8000", transport=httpx2.MockTransport(async_handler) + ) as aclient: + async with aclient.websocket( + "http://socket/ws", + subprotocols=["custom_protocol", "unsupported_protocol"], + ) as aws: + assert isinstance(aws.response, httpx2.Response) + assert aws.subprotocol == "custom_protocol" + assert aws.response.headers["sec-websocket-protocol"] == aws.subprotocol + + +@pytest.mark.anyio +async def test_threads_wont_hang(server_factory: ServerFactoryFixture) -> None: + """ + Check that all threads spawned in WebSocketSession are properly terminated during + a series of messages exchange. This used to be the cause of a memory leak in the + connect_ws client, see https://github.com/frankie567/httpx-ws/issues/76. + """ + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + for _ in range(50): + await websocket.send_text("SERVER_MESSAGE") + await websocket.receive_text() + await websocket.close() + + with server_factory(websocket_endpoint) as socket: + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + initial_threads = set(threading.enumerate()) + with client.websocket("http://socket/ws", keepalive_ping_interval_seconds=None) as ws: + for _ in range(50): + ws.receive() + ws.send_text("CLIENT_MESSAGE") + session_threads = set(threading.enumerate()) - initial_threads + assert session_threads + deadline = time.time() + 5 + while any(thread.is_alive() for thread in session_threads) and time.time() < deadline: + time.sleep(0.01) # pragma: no cover + assert not any(thread.is_alive() for thread in session_threads) + + +@pytest.mark.anyio +async def test_concurrency_write(server_factory: ServerFactoryFixture) -> None: + """ + Check that there is no error because of two tasks trying to write the stream at the + same time. Typically, this is when a background ping tries to send a ping while the + main task is sending a message. + + See: https://github.com/frankie567/httpx-ws/issues/29 + """ + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + async for message in websocket.iter_text(): + await websocket.send_text(message) + + with server_factory(websocket_endpoint) as socket: + # Added for completeness, but were not able to reproduce the issue with the sync client + with httpx2.Client(transport=httpx2.HTTPTransport(uds=socket)) as client: + with client.websocket("http://socket/ws") as ws: + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + for _ in range(10): + executor.submit(ws.send_text, "CLIENT_MESSAGE") + + async with httpx2.AsyncClient(transport=httpx2.AsyncHTTPTransport(uds=socket)) as aclient: + async with aclient.websocket("http://socket/ws") as aws: + async with anyio.create_task_group() as tg: + for _ in range(10): + tg.start_soon(aws.send_text, "CLIENT_MESSAGE") + + +@pytest.mark.anyio +async def test_client_websocket_with_wss_scheme() -> None: + def handler(request: httpx2.Request) -> httpx2.Response: + assert request.url.scheme == "https" + return httpx2.Response(101, extensions={"network_stream": mock_network_stream(NetworkStream)}) + + def async_handler(request: httpx2.Request) -> httpx2.Response: + assert request.url.scheme == "https" + return httpx2.Response(101, extensions={"network_stream": mock_network_stream(AsyncNetworkStream)}) + + with httpx2.Client(base_url="http://localhost:8000", transport=httpx2.MockTransport(handler)) as client: + with client.websocket("wss://socket/ws") as ws: + assert isinstance(ws.response, httpx2.Response) + + async with httpx2.AsyncClient( + base_url="http://localhost:8000", transport=httpx2.MockTransport(async_handler) + ) as aclient: + async with aclient.websocket("wss://socket/ws") as aws: + assert isinstance(aws.response, httpx2.Response) diff --git a/tests/httpx2/websockets/test_transport.py b/tests/httpx2/websockets/test_transport.py new file mode 100644 index 00000000..2ead09e5 --- /dev/null +++ b/tests/httpx2/websockets/test_transport.py @@ -0,0 +1,394 @@ +import base64 +import secrets +import sys +import typing + +import anyio +import pytest +from anyio import CancelScope, ClosedResourceError, create_task_group +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import PlainTextResponse +from starlette.routing import Route, WebSocketRoute +from starlette.websockets import WebSocket +from websockets.frames import Frame, Opcode +from websockets.protocol import Protocol, Side, State + +import httpx2 +from httpx2 import ASGIWebSocketTransport, WebSocketDisconnect, WebSocketUpgradeError +from httpx2._websockets._transport import ( + ASGIWebSocketAsyncNetworkStream, + Message, + Receive, + Scope, + Send, + UnhandledASGIMessageType, + UnhandledWebSocketFrame, +) + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup # pragma: no cover + + +def wire(protocol: Protocol) -> bytes: + return b"".join(data for data in protocol.data_to_send() if data) + + +@pytest.fixture +def websocket_request_headers() -> dict[str, str]: + return { + "connection": "upgrade", + "upgrade": "websocket", + "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), + "sec-websocket-version": "13", + } + + +@pytest.fixture +def scope(websocket_request_headers: dict[str, str]) -> Scope: + return { + "type": "websocket", + "path": "/ws", + "raw_path": "/ws", + "root_path": "/", + "scheme": "ws", + "headers": [ + (b"host", b"localhost"), + *((key.encode("utf-8"), value.encode("utf-8")) for key, value in websocket_request_headers.items()), + ], + "subprotocols": [], + "server": ("localhost", 8000), + } + + +@pytest.mark.anyio +class TestASGIWebSocketAsyncNetworkStream: + async def test_write(self, scope: Scope) -> None: + received_messages: list[Message] = [] + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + message = await receive() + received_messages.append(message) + while message["type"] != "websocket.disconnect": + message = await receive() + received_messages.append(message) + + protocol = Protocol(Side.CLIENT, state=State.OPEN, max_size=None) + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + protocol.send_text(b"CLIENT_MESSAGE") + await stream.write(wire(protocol)) + + protocol.send_binary(b"CLIENT_MESSAGE") + await stream.write(wire(protocol)) + + protocol.send_close(1000) + await stream.write(wire(protocol)) + + # Add a small delay to ensure the app has processed all messages + await anyio.sleep(0.1) + + assert received_messages == [ + {"type": "websocket.connect"}, + {"type": "websocket.receive", "text": "CLIENT_MESSAGE"}, + {"type": "websocket.receive", "bytes": b"CLIENT_MESSAGE"}, + {"type": "websocket.disconnect", "code": 1000, "reason": ""}, + ] + + async def test_write_unhandled_event(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await receive() + + protocol = Protocol(Side.CLIENT, state=State.OPEN, max_size=None) + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + with pytest.raises(UnhandledWebSocketFrame): + protocol.send_ping(b"PING") + await stream.write(wire(protocol)) + + async def test_read(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await send({"type": "websocket.send", "text": "SERVER_MESSAGE"}) + await send({"type": "websocket.send", "bytes": b"SERVER_MESSAGE"}) + await send({"type": "websocket.close", "code": 1000, "reason": ""}) + + protocol = Protocol(Side.CLIENT, state=State.OPEN, max_size=None) + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + for _ in range(3): + data = await stream.read(4096) + protocol.receive_data(data) + + frames = [event for event in protocol.events_received() if isinstance(event, Frame)] + assert [frame.opcode for frame in frames] == [Opcode.TEXT, Opcode.BINARY, Opcode.CLOSE] + assert bytes(frames[0].data) == b"SERVER_MESSAGE" + assert bytes(frames[1].data) == b"SERVER_MESSAGE" + + async def test_read_unhandled_asgi_message(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await send({"type": "websocket.foo"}) + + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + with pytest.raises(UnhandledASGIMessageType): + await stream.read(4096) + + async def test_close_immediately(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.close", "code": 1000, "reason": ""}) + + with pytest.raises(ExceptionGroup) as excinfo: + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg), + ): + pass # pragma: no cover + assert excinfo.group_contains(WebSocketDisconnect) + + async def test_denial_response(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.http.response.start", "status": 401, "headers": []}) + await send({"type": "websocket.http.response.body", "body": b"Unauthorized"}) + + with pytest.raises(ExceptionGroup) as excinfo: + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg), + ): + pass # pragma: no cover + assert excinfo.group_contains(WebSocketUpgradeError) + upgrade_error = excinfo.value.exceptions[0] + assert isinstance(upgrade_error, WebSocketUpgradeError) + assert upgrade_error.response.status_code == 401 + assert upgrade_error.response.content == b"Unauthorized" + + async def test_exception(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + raise Exception("Error") + + with pytest.raises(ExceptionGroup) as excinfo: + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg), + ): + pass # pragma: no cover + assert excinfo.group_contains(WebSocketDisconnect) + disconnect = excinfo.value.exceptions[0] + assert isinstance(disconnect, WebSocketDisconnect) + assert disconnect.code == 1011 + assert disconnect.reason == "Error" + + async def test_never_accepts(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + return + + with pytest.raises(ExceptionGroup) as excinfo: + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg), + ): + pass # pragma: no cover + + assert excinfo.group_contains(RuntimeError) + + async def test_context_manager_twice(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await receive() + + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + with pytest.raises(RuntimeError): + await stream.__aenter__() + + async def test_app_exception_with_closed_send_queue(self, scope: Scope) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.accept"}) + await receive() + raise Exception("App error") + + async with ( + create_task_group() as tg, + ASGIWebSocketAsyncNetworkStream(app, scope, tg) as (stream, _), + ): + await stream._send_queue.aclose() + await stream.send({"type": "websocket.receive", "text": "trigger"}) + + +@pytest.fixture +def test_app() -> Starlette: + async def http_endpoint(request: Request) -> PlainTextResponse: + return PlainTextResponse("Hello, world!") + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.receive_text() + await websocket.close() # pragma: no cover + + routes = [ + Route("/http", endpoint=http_endpoint), + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + + return Starlette(routes=routes) + + +@pytest.mark.anyio +class TestASGIWebSocketTransport: + async def test_http(self, test_app: Starlette) -> None: + async with ASGIWebSocketTransport(app=test_app) as transport: + request = httpx2.Request("GET", "http://localhost:8000/http") + response = await transport.handle_async_request(request) + assert response.status_code == 200 + + @pytest.mark.parametrize( + "url,headers", + [ + ("ws://localhost:8000/ws", {}), + ("wss://localhost:8000/ws", {}), + ("http://localhost:8000/ws", {"upgrade": "websocket"}), + ], + ) + async def test_websocket( + self, + url: str, + headers: dict[str, typing.Any], + test_app: Starlette, + websocket_request_headers: dict[str, str], + ) -> None: + async with ASGIWebSocketTransport(app=test_app) as transport: + request = httpx2.Request("GET", url, headers={**websocket_request_headers, **headers}) + response = await transport.handle_async_request(request) + assert response.status_code == 101 + + assert isinstance(response.extensions["network_stream"], ASGIWebSocketAsyncNetworkStream) + + @pytest.mark.parametrize("stream_count", [1, 3]) + async def test_transport_exit_closes_stream_queues( + self, + stream_count: int, + test_app: Starlette, + websocket_request_headers: dict[str, str], + ) -> None: + async with ASGIWebSocketTransport(app=test_app) as transport: + streams: list[ASGIWebSocketAsyncNetworkStream] = [] + for _ in range(stream_count): + request = httpx2.Request( + "GET", + "ws://localhost:8000/ws", + headers=websocket_request_headers, + ) + response = await transport.handle_async_request(request) + streams.append(response.extensions["network_stream"]) + + for stream in streams: + with pytest.raises(ClosedResourceError): + await stream._receive_queue.send({}) + with pytest.raises(ClosedResourceError): + await stream._send_queue.send({}) + + async def test_aclose_after_transport_exit_does_not_raise( + self, + test_app: Starlette, + websocket_request_headers: dict[str, str], + ) -> None: + async with ASGIWebSocketTransport(app=test_app) as transport: + request = httpx2.Request("GET", "ws://localhost:8000/ws", headers=websocket_request_headers) + response = await transport.handle_async_request(request) + stream = response.extensions["network_stream"] + + await stream.aclose() + + +@pytest.mark.anyio +async def test_subprotocol_support() -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept("custom_protocol") + assert websocket.scope.get("subprotocols") == ["custom_protocol"] + await websocket.send_text("SERVER_MESSAGE") + await websocket.close() + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx2.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + async with client.websocket("ws://localhost:8000/ws", subprotocols=["custom_protocol"]) as ws: + await ws.receive_text() + assert ws.subprotocol == "custom_protocol" + + +@pytest.mark.anyio +async def test_keepalive_ping_disabled() -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.receive_text() + await websocket.close() # pragma: no cover + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx2.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + async with client.websocket("ws://localhost:8000/ws") as ws: + assert ws._keepalive_ping_interval_seconds is None + + +@pytest.mark.anyio +async def test_cancel_scope_integrity() -> None: + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + await websocket.receive_text() + await websocket.close() # pragma: no cover + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx2.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + with CancelScope(): + async with client.websocket("ws://localhost:8000/ws"): + pass + + +@pytest.mark.anyio +async def test_receive() -> None: + messages: list[str] = [] + + async def websocket_endpoint(websocket: WebSocket) -> None: + await websocket.accept() + messages.append(await websocket.receive_text()) + await websocket.close() + + app = Starlette( + routes=[ + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ] + ) + + async with httpx2.AsyncClient(transport=ASGIWebSocketTransport(app)) as client: + async with client.websocket("ws://localhost:8000/ws") as ws: + await ws.send_text("RESULT") + + assert len(messages) == 1 + assert messages[0] == "RESULT" diff --git a/uv.lock b/uv.lock index ef708e99..feeae90d 100644 --- a/uv.lock +++ b/uv.lock @@ -41,12 +41,15 @@ dev = [ { name = "pytest-httpbin", specifier = "==2.0.0" }, { name = "pytest-trio", specifier = "==0.8.0" }, { name = "ruff", specifier = "==0.15.13" }, + { name = "starlette", specifier = ">=0.49" }, { name = "trio", specifier = "==0.31.0" }, { name = "trio-typing", specifier = "==0.10.0" }, { name = "trustme", specifier = "==1.2.1" }, { name = "twine", specifier = "==6.1.0" }, { name = "uvicorn", specifier = ">=0.35" }, + { name = "websockets", specifier = ">=15" }, { name = "werkzeug", specifier = ">=3.1.6" }, + { name = "wsproto", specifier = ">=1.2" }, ] docs = [ { name = "mkdocstrings", extras = ["python"], specifier = ">=0.27" }, @@ -1354,6 +1357,7 @@ dependencies = [ { name = "idna" }, { name = "truststore" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "websockets" }, ] [package.optional-dependencies] @@ -1378,7 +1382,7 @@ zstd = [ [package.metadata] requires-dist = [ - { name = "anyio" }, + { name = "anyio", specifier = ">=4.10" }, { name = "brotli", marker = "platform_python_implementation == 'CPython' and extra == 'brotli'" }, { name = "brotlicffi", marker = "platform_python_implementation != 'CPython' and extra == 'brotli'" }, { name = "click", marker = "extra == 'cli'", specifier = "==8.*" }, @@ -1390,6 +1394,7 @@ requires-dist = [ { name = "socksio", marker = "extra == 'socks'", specifier = "==1.*" }, { name = "truststore", specifier = ">=0.10" }, { name = "typing-extensions", marker = "python_full_version < '3.13'", specifier = ">=4.5.0" }, + { name = "websockets", specifier = ">=15" }, { name = "zstandard", marker = "python_full_version < '3.14' and extra == 'zstd'", specifier = ">=0.18.0" }, ] provides-extras = ["brotli", "cli", "http2", "socks", "zstd"] @@ -3175,6 +3180,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, ] +[[package]] +name = "starlette" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/44/ec35f1b6e83094b997da438a02c8c9b0ade2b1e84cfc48bd4656780760a6/starlette-1.2.1.tar.gz", hash = "sha256:9b9b5ebb992e67d6093741e63c2f59e4f6fff986f81163c087867bd7b924b3f6", size = 2701854, upload-time = "2026-05-31T01:07:51.847Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/54/196d0c1db10af76baa4f64894448505d60d3cdf70ef92cbb35f46a4e4c71/starlette-1.2.1-py3-none-any.whl", hash = "sha256:4de0082d08c8f6764a85a54cf1120d6939507a19905c7768acad2a9f875d2b89", size = 73350, upload-time = "2026-05-31T01:07:50.09Z" }, +] + [[package]] name = "tomli" version = "2.4.1" @@ -3370,6 +3388,74 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "websockets" +version = "16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/74/221f58decd852f4b59cc3354cccaf87e8ef695fede361d03dc9a7396573b/websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a", size = 177343, upload-time = "2026-01-10T09:22:21.28Z" }, + { url = "https://files.pythonhosted.org/packages/19/0f/22ef6107ee52ab7f0b710d55d36f5a5d3ef19e8a205541a6d7ffa7994e5a/websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0", size = 175021, upload-time = "2026-01-10T09:22:22.696Z" }, + { url = "https://files.pythonhosted.org/packages/10/40/904a4cb30d9b61c0e278899bf36342e9b0208eb3c470324a9ecbaac2a30f/websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957", size = 175320, upload-time = "2026-01-10T09:22:23.94Z" }, + { url = "https://files.pythonhosted.org/packages/9d/2f/4b3ca7e106bc608744b1cdae041e005e446124bebb037b18799c2d356864/websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72", size = 183815, upload-time = "2026-01-10T09:22:25.469Z" }, + { url = "https://files.pythonhosted.org/packages/86/26/d40eaa2a46d4302becec8d15b0fc5e45bdde05191e7628405a19cf491ccd/websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde", size = 185054, upload-time = "2026-01-10T09:22:27.101Z" }, + { url = "https://files.pythonhosted.org/packages/b0/ba/6500a0efc94f7373ee8fefa8c271acdfd4dca8bd49a90d4be7ccabfc397e/websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3", size = 184565, upload-time = "2026-01-10T09:22:28.293Z" }, + { url = "https://files.pythonhosted.org/packages/04/b4/96bf2cee7c8d8102389374a2616200574f5f01128d1082f44102140344cc/websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3", size = 183848, upload-time = "2026-01-10T09:22:30.394Z" }, + { url = "https://files.pythonhosted.org/packages/02/8e/81f40fb00fd125357814e8c3025738fc4ffc3da4b6b4a4472a82ba304b41/websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9", size = 178249, upload-time = "2026-01-10T09:22:32.083Z" }, + { url = "https://files.pythonhosted.org/packages/b4/5f/7e40efe8df57db9b91c88a43690ac66f7b7aa73a11aa6a66b927e44f26fa/websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35", size = 178685, upload-time = "2026-01-10T09:22:33.345Z" }, + { url = "https://files.pythonhosted.org/packages/f2/db/de907251b4ff46ae804ad0409809504153b3f30984daf82a1d84a9875830/websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8", size = 177340, upload-time = "2026-01-10T09:22:34.539Z" }, + { url = "https://files.pythonhosted.org/packages/f3/fa/abe89019d8d8815c8781e90d697dec52523fb8ebe308bf11664e8de1877e/websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad", size = 175022, upload-time = "2026-01-10T09:22:36.332Z" }, + { url = "https://files.pythonhosted.org/packages/58/5d/88ea17ed1ded2079358b40d31d48abe90a73c9e5819dbcde1606e991e2ad/websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d", size = 175319, upload-time = "2026-01-10T09:22:37.602Z" }, + { url = "https://files.pythonhosted.org/packages/d2/ae/0ee92b33087a33632f37a635e11e1d99d429d3d323329675a6022312aac2/websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe", size = 184631, upload-time = "2026-01-10T09:22:38.789Z" }, + { url = "https://files.pythonhosted.org/packages/c8/c5/27178df583b6c5b31b29f526ba2da5e2f864ecc79c99dae630a85d68c304/websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b", size = 185870, upload-time = "2026-01-10T09:22:39.893Z" }, + { url = "https://files.pythonhosted.org/packages/87/05/536652aa84ddc1c018dbb7e2c4cbcd0db884580bf8e95aece7593fde526f/websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5", size = 185361, upload-time = "2026-01-10T09:22:41.016Z" }, + { url = "https://files.pythonhosted.org/packages/6d/e2/d5332c90da12b1e01f06fb1b85c50cfc489783076547415bf9f0a659ec19/websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64", size = 184615, upload-time = "2026-01-10T09:22:42.442Z" }, + { url = "https://files.pythonhosted.org/packages/77/fb/d3f9576691cae9253b51555f841bc6600bf0a983a461c79500ace5a5b364/websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6", size = 178246, upload-time = "2026-01-10T09:22:43.654Z" }, + { url = "https://files.pythonhosted.org/packages/54/67/eaff76b3dbaf18dcddabc3b8c1dba50b483761cccff67793897945b37408/websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac", size = 178684, upload-time = "2026-01-10T09:22:44.941Z" }, + { url = "https://files.pythonhosted.org/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00", size = 177365, upload-time = "2026-01-10T09:22:46.787Z" }, + { url = "https://files.pythonhosted.org/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79", size = 175038, upload-time = "2026-01-10T09:22:47.999Z" }, + { url = "https://files.pythonhosted.org/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39", size = 175328, upload-time = "2026-01-10T09:22:49.809Z" }, + { url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" }, + { url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" }, + { url = "https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" }, + { url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" }, + { url = "https://files.pythonhosted.org/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89", size = 178261, upload-time = "2026-01-10T09:22:56.251Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea", size = 178693, upload-time = "2026-01-10T09:22:57.478Z" }, + { url = "https://files.pythonhosted.org/packages/cc/9c/baa8456050d1c1b08dd0ec7346026668cbc6f145ab4e314d707bb845bf0d/websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9", size = 177364, upload-time = "2026-01-10T09:22:59.333Z" }, + { url = "https://files.pythonhosted.org/packages/7e/0c/8811fc53e9bcff68fe7de2bcbe75116a8d959ac699a3200f4847a8925210/websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230", size = 175039, upload-time = "2026-01-10T09:23:01.171Z" }, + { url = "https://files.pythonhosted.org/packages/aa/82/39a5f910cb99ec0b59e482971238c845af9220d3ab9fa76dd9162cda9d62/websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c", size = 175323, upload-time = "2026-01-10T09:23:02.341Z" }, + { url = "https://files.pythonhosted.org/packages/bd/28/0a25ee5342eb5d5f297d992a77e56892ecb65e7854c7898fb7d35e9b33bd/websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5", size = 184975, upload-time = "2026-01-10T09:23:03.756Z" }, + { url = "https://files.pythonhosted.org/packages/f9/66/27ea52741752f5107c2e41fda05e8395a682a1e11c4e592a809a90c6a506/websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82", size = 186203, upload-time = "2026-01-10T09:23:05.01Z" }, + { url = "https://files.pythonhosted.org/packages/37/e5/8e32857371406a757816a2b471939d51c463509be73fa538216ea52b792a/websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8", size = 185653, upload-time = "2026-01-10T09:23:06.301Z" }, + { url = "https://files.pythonhosted.org/packages/9b/67/f926bac29882894669368dc73f4da900fcdf47955d0a0185d60103df5737/websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f", size = 184920, upload-time = "2026-01-10T09:23:07.492Z" }, + { url = "https://files.pythonhosted.org/packages/3c/a1/3d6ccdcd125b0a42a311bcd15a7f705d688f73b2a22d8cf1c0875d35d34a/websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a", size = 178255, upload-time = "2026-01-10T09:23:09.245Z" }, + { url = "https://files.pythonhosted.org/packages/6b/ae/90366304d7c2ce80f9b826096a9e9048b4bb760e44d3b873bb272cba696b/websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156", size = 178689, upload-time = "2026-01-10T09:23:10.483Z" }, + { url = "https://files.pythonhosted.org/packages/f3/1d/e88022630271f5bd349ed82417136281931e558d628dd52c4d8621b4a0b2/websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0", size = 177406, upload-time = "2026-01-10T09:23:12.178Z" }, + { url = "https://files.pythonhosted.org/packages/f2/78/e63be1bf0724eeb4616efb1ae1c9044f7c3953b7957799abb5915bffd38e/websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904", size = 175085, upload-time = "2026-01-10T09:23:13.511Z" }, + { url = "https://files.pythonhosted.org/packages/bb/f4/d3c9220d818ee955ae390cf319a7c7a467beceb24f05ee7aaaa2414345ba/websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4", size = 175328, upload-time = "2026-01-10T09:23:14.727Z" }, + { url = "https://files.pythonhosted.org/packages/63/bc/d3e208028de777087e6fb2b122051a6ff7bbcca0d6df9d9c2bf1dd869ae9/websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e", size = 185044, upload-time = "2026-01-10T09:23:15.939Z" }, + { url = "https://files.pythonhosted.org/packages/ad/6e/9a0927ac24bd33a0a9af834d89e0abc7cfd8e13bed17a86407a66773cc0e/websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4", size = 186279, upload-time = "2026-01-10T09:23:17.148Z" }, + { url = "https://files.pythonhosted.org/packages/b9/ca/bf1c68440d7a868180e11be653c85959502efd3a709323230314fda6e0b3/websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1", size = 185711, upload-time = "2026-01-10T09:23:18.372Z" }, + { url = "https://files.pythonhosted.org/packages/c4/f8/fdc34643a989561f217bb477cbc47a3a07212cbda91c0e4389c43c296ebf/websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3", size = 184982, upload-time = "2026-01-10T09:23:19.652Z" }, + { url = "https://files.pythonhosted.org/packages/dd/d1/574fa27e233764dbac9c52730d63fcf2823b16f0856b3329fc6268d6ae4f/websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8", size = 177915, upload-time = "2026-01-10T09:23:21.458Z" }, + { url = "https://files.pythonhosted.org/packages/8a/f1/ae6b937bf3126b5134ce1f482365fde31a357c784ac51852978768b5eff4/websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d", size = 178381, upload-time = "2026-01-10T09:23:22.715Z" }, + { url = "https://files.pythonhosted.org/packages/06/9b/f791d1db48403e1f0a27577a6beb37afae94254a8c6f08be4a23e4930bc0/websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244", size = 177737, upload-time = "2026-01-10T09:23:24.523Z" }, + { url = "https://files.pythonhosted.org/packages/bd/40/53ad02341fa33b3ce489023f635367a4ac98b73570102ad2cdd770dacc9a/websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e", size = 175268, upload-time = "2026-01-10T09:23:25.781Z" }, + { url = "https://files.pythonhosted.org/packages/74/9b/6158d4e459b984f949dcbbb0c5d270154c7618e11c01029b9bbd1bb4c4f9/websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641", size = 175486, upload-time = "2026-01-10T09:23:27.033Z" }, + { url = "https://files.pythonhosted.org/packages/e5/2d/7583b30208b639c8090206f95073646c2c9ffd66f44df967981a64f849ad/websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8", size = 185331, upload-time = "2026-01-10T09:23:28.259Z" }, + { url = "https://files.pythonhosted.org/packages/45/b0/cce3784eb519b7b5ad680d14b9673a31ab8dcb7aad8b64d81709d2430aa8/websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e", size = 186501, upload-time = "2026-01-10T09:23:29.449Z" }, + { url = "https://files.pythonhosted.org/packages/19/60/b8ebe4c7e89fb5f6cdf080623c9d92789a53636950f7abacfc33fe2b3135/websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944", size = 186062, upload-time = "2026-01-10T09:23:31.368Z" }, + { url = "https://files.pythonhosted.org/packages/88/a8/a080593f89b0138b6cba1b28f8df5673b5506f72879322288b031337c0b8/websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206", size = 185356, upload-time = "2026-01-10T09:23:32.627Z" }, + { url = "https://files.pythonhosted.org/packages/c2/b6/b9afed2afadddaf5ebb2afa801abf4b0868f42f8539bfe4b071b5266c9fe/websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6", size = 178085, upload-time = "2026-01-10T09:23:33.816Z" }, + { url = "https://files.pythonhosted.org/packages/9f/3e/28135a24e384493fa804216b79a6a6759a38cc4ff59118787b9fb693df93/websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd", size = 178531, upload-time = "2026-01-10T09:23:35.016Z" }, + { url = "https://files.pythonhosted.org/packages/72/07/c98a68571dcf256e74f1f816b8cc5eae6eb2d3d5cfa44d37f801619d9166/websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d", size = 174947, upload-time = "2026-01-10T09:23:36.166Z" }, + { url = "https://files.pythonhosted.org/packages/7e/52/93e166a81e0305b33fe416338be92ae863563fe7bce446b0f687b9df5aea/websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03", size = 175260, upload-time = "2026-01-10T09:23:37.409Z" }, + { url = "https://files.pythonhosted.org/packages/56/0c/2dbf513bafd24889d33de2ff0368190a0e69f37bcfa19009ef819fe4d507/websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da", size = 176071, upload-time = "2026-01-10T09:23:39.158Z" }, + { url = "https://files.pythonhosted.org/packages/a5/8f/aea9c71cc92bf9b6cc0f7f70df8f0b420636b6c96ef4feee1e16f80f75dd/websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c", size = 176968, upload-time = "2026-01-10T09:23:41.031Z" }, + { url = "https://files.pythonhosted.org/packages/9a/3f/f70e03f40ffc9a30d817eef7da1be72ee4956ba8d7255c399a01b135902a/websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767", size = 178735, upload-time = "2026-01-10T09:23:42.259Z" }, + { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, +] + [[package]] name = "werkzeug" version = "3.1.8" @@ -3382,6 +3468,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/93/8c/2e650f2afeb7ee576912636c23ddb621c91ac6a98e66dc8d29c3c69446e1/werkzeug-3.1.8-py3-none-any.whl", hash = "sha256:63a77fb8892bf28ebc3178683445222aa500e48ebad5ec77b0ad80f8726b1f50", size = 226459, upload-time = "2026-04-02T18:49:12.72Z" }, ] +[[package]] +name = "wsproto" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" }, +] + [[package]] name = "yarl" version = "1.23.0"