From c663265206f1fc489ff2092d1a23c613f47e4676 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 12 Jun 2026 10:06:48 +0200 Subject: [PATCH 1/2] Stream response body in ASGITransport --- docs/advanced/transports.md | 14 +++ src/httpx2/httpx2/_transports/asgi.py | 166 +++++++++++++------------- tests/httpx2/test_asgi.py | 117 ++++++++++++++++++ 3 files changed, 212 insertions(+), 85 deletions(-) diff --git a/docs/advanced/transports.md b/docs/advanced/transports.md index 79f29c15..f120e31c 100644 --- a/docs/advanced/transports.md +++ b/docs/advanced/transports.md @@ -140,6 +140,20 @@ async with httpx2.AsyncClient(transport=transport, base_url="http://testserver") See [the ASGI documentation](https://asgi.readthedocs.io/en/latest/specs/www.html#connection-scope) for more details on the `client` and `root_path` keys. +### Streaming responses + +The ASGI transport streams response bodies. The app runs in a separate task, and a response is returned as soon as the app sends the response start, which generally happens before the app has fully run. This makes it possible to test streaming endpoints, such as server-sent events, by iterating over the response: + +```python +transport = httpx2.ASGITransport(app=app) +async with httpx2.AsyncClient(transport=transport, base_url="http://testserver") as client: + async with client.stream("GET", "/sse") as response: + async for chunk in response.aiter_text(): + ... +``` + +Because the app runs in a separate task, context variables set within the app are not visible to the caller. + ### ASGI startup and shutdown It is not in the scope of HTTPX to trigger ASGI lifespan events of your app. diff --git a/src/httpx2/httpx2/_transports/asgi.py b/src/httpx2/httpx2/_transports/asgi.py index 33f25db3..62c71d33 100644 --- a/src/httpx2/httpx2/_transports/asgi.py +++ b/src/httpx2/httpx2/_transports/asgi.py @@ -2,58 +2,46 @@ import typing +import anyio + from .._models import Request, Response from .._types import AsyncByteStream from .base import AsyncBaseTransport -if typing.TYPE_CHECKING: - import asyncio - - import trio - - Event = asyncio.Event | trio.Event - - _Message = typing.MutableMapping[str, typing.Any] _Receive = typing.Callable[[], typing.Awaitable[_Message]] -_Send = typing.Callable[[typing.MutableMapping[str, typing.Any]], typing.Awaitable[None]] +_Send = typing.Callable[[_Message], typing.Awaitable[None]] _ASGIApp = typing.Callable[[typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None]] __all__ = ["ASGITransport"] -def is_running_trio() -> bool: - try: - # sniffio is a dependency of trio. - - # See https://github.com/python-trio/trio/issues/2802 - import sniffio - - if sniffio.current_async_library() == "trio": - return True - except ImportError: # pragma: no cover - pass - - return False - - -def create_event() -> Event: - if is_running_trio(): - import trio - - return trio.Event() - - import asyncio - - return asyncio.Event() - - class ASGIResponseStream(AsyncByteStream): - def __init__(self, body: list[bytes]) -> None: - self._body = body + def __init__( + self, + ignore_body: bool, + asgi_messages: typing.AsyncGenerator[_Message, None], + disconnect_request: anyio.Event, + ) -> None: + self._ignore_body = ignore_body + self._asgi_messages = asgi_messages + self._disconnect_request = disconnect_request async def __aiter__(self) -> typing.AsyncIterator[bytes]: - yield b"".join(self._body) + more_body = True + async for message in self._asgi_messages: + if message["type"] == "http.response.body": + assert more_body + body = message.get("body", b"") + more_body = message.get("more_body", False) + if body and not self._ignore_body: + yield body + if not more_body: + self._disconnect_request.set() + + async def aclose(self) -> None: + self._disconnect_request.set() + await self._asgi_messages.aclose() class ASGITransport(AsyncBaseTransport): @@ -69,6 +57,10 @@ class ASGITransport(AsyncBaseTransport): client = httpx2.AsyncClient(transport=transport) ``` + The app is run in a separate task, and response events are streamed as soon as + they arrive. A response is returned as soon as the app sends the response start, + which generally happens before the app has fully run. + Arguments: app: The ASGI application. raise_app_exceptions: Boolean indicating if exceptions in the application @@ -91,6 +83,27 @@ def __init__( self.client = client async def handle_async_request(self, request: Request) -> Response: + disconnect_request = anyio.Event() + asgi_messages = self._run_app(request, disconnect_request) + + async for message in asgi_messages: + if message["type"] == "http.response.start": + return Response( + status_code=message["status"], + headers=message.get("headers", []), + stream=ASGIResponseStream( + ignore_body=request.method == "HEAD", + asgi_messages=asgi_messages, + disconnect_request=disconnect_request, + ), + ) + + disconnect_request.set() + return Response(status_code=500, headers=[]) + + async def _run_app( + self, request: Request, disconnect_request: anyio.Event + ) -> typing.AsyncGenerator[_Message, None]: assert isinstance(request.stream, AsyncByteStream) # ASGI scope. @@ -114,19 +127,16 @@ async def handle_async_request(self, request: Request) -> Response: request_complete = False # Response. - status_code = None - response_headers = None - body_parts: list[bytes] = [] - response_started = False - response_complete = create_event() + send_channel, receive_channel = anyio.create_memory_object_stream[_Message]() + app_exception: Exception | None = None # ASGI callables. - async def receive() -> dict[str, typing.Any]: + async def receive() -> _Message: nonlocal request_complete if request_complete: - await response_complete.wait() + await disconnect_request.wait() return {"type": "http.disconnect"} try: @@ -136,43 +146,29 @@ async def receive() -> dict[str, typing.Any]: return {"type": "http.request", "body": b"", "more_body": False} return {"type": "http.request", "body": body, "more_body": True} - async def send(message: typing.MutableMapping[str, typing.Any]) -> None: - nonlocal status_code, response_headers, response_started - - if message["type"] == "http.response.start": - assert not response_started - - status_code = message["status"] - response_headers = message.get("headers", []) - response_started = True - - elif message["type"] == "http.response.body": - assert not response_complete.is_set() - body = message.get("body", b"") - more_body = message.get("more_body", False) - - if body and request.method != "HEAD": - body_parts.append(body) - - if not more_body: - response_complete.set() - - try: - await self.app(scope, receive, send) - except Exception: - if self.raise_app_exceptions: - raise - - response_complete.set() - if status_code is None: - status_code = 500 - if response_headers is None: - response_headers = {} - - assert response_complete.is_set() - assert status_code is not None - assert response_headers is not None - - stream = ASGIResponseStream(body_parts) - - return Response(status_code, headers=response_headers, stream=stream) + async def run_app() -> None: + nonlocal app_exception + try: + await self.app(scope, receive, send_channel.send) + except Exception as exc: + app_exception = exc + finally: + send_channel.close() + + closed = False + + async with anyio.create_task_group() as task_group: + task_group.start_soon(run_app) + async with receive_channel: + try: + async for message in receive_channel: + yield message + except GeneratorExit: + # A `GeneratorExit` must not propagate through the task group, + # which would wrap it in an exception group. Cancel the app and + # return instead. + closed = True + task_group.cancel_scope.cancel() + + if not closed and app_exception is not None and self.raise_app_exceptions: + raise app_exception diff --git a/tests/httpx2/test_asgi.py b/tests/httpx2/test_asgi.py index 13f7bf1c..ed8b6359 100644 --- a/tests/httpx2/test_asgi.py +++ b/tests/httpx2/test_asgi.py @@ -1,6 +1,7 @@ import json import typing +import anyio import pytest import httpx2 @@ -9,6 +10,15 @@ Receive = typing.Callable[[], typing.Awaitable[Message]] Send = typing.Callable[[typing.MutableMapping[str, typing.Any]], typing.Awaitable[None]] Scope = typing.MutableMapping[str, typing.Any] +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] + + +def run_in_task_group(app: ASGIApp) -> ASGIApp: + async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None: + async with anyio.create_task_group() as task_group: + task_group.start_soon(app, scope, receive, send) + + return wrapped_app async def hello_world(scope: Scope, receive: Receive, send: Send) -> None: @@ -64,6 +74,15 @@ async def raise_exc(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError() +async def raise_exc_after_response_start(scope: Scope, receive: Receive, send: Send) -> None: + status = 200 + output = b"Hello, World!" + headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))] + + await send({"type": "http.response.start", "status": status, "headers": headers}) + raise RuntimeError() + + async def raise_exc_after_response(scope: Scope, receive: Receive, send: Send) -> None: status = 200 output = b"Hello, World!" @@ -176,6 +195,14 @@ async def test_asgi_exc() -> None: await client.get("http://www.example.org/") +@pytest.mark.anyio +async def test_asgi_exc_after_response_start() -> None: + transport = httpx2.ASGITransport(app=raise_exc_after_response_start) + async with httpx2.AsyncClient(transport=transport) as client: + with pytest.raises(RuntimeError): + await client.get("http://www.example.org/") + + @pytest.mark.anyio async def test_asgi_exc_after_response() -> None: transport = httpx2.ASGITransport(app=raise_exc_after_response) @@ -224,3 +251,93 @@ async def test_asgi_exc_no_raise() -> None: response = await client.get("http://www.example.org/") assert response.status_code == 500 + + +@pytest.mark.anyio +async def test_asgi_exc_no_raise_after_response_start() -> None: + transport = httpx2.ASGITransport(app=raise_exc_after_response_start, raise_app_exceptions=False) + async with httpx2.AsyncClient(transport=transport) as client: + response = await client.get("http://www.example.org/") + + assert response.status_code == 200 + + +@pytest.mark.anyio +async def test_asgi_exc_no_raise_after_response() -> None: + transport = httpx2.ASGITransport(app=raise_exc_after_response, raise_app_exceptions=False) + async with httpx2.AsyncClient(transport=transport) as client: + response = await client.get("http://www.example.org/") + + assert response.status_code == 200 + + +@pytest.mark.parametrize( + "send_in_sub_task", [pytest.param(False, id="no_sub_task"), pytest.param(True, id="with_sub_task")] +) +@pytest.mark.anyio +async def test_asgi_stream_returns_before_waiting_for_body(send_in_sub_task: bool) -> None: + start_response_body = anyio.Event() + + async def send_response_body_after_event(scope: Scope, receive: Receive, send: Send) -> None: + status = 200 + headers = [(b"content-type", b"text/plain")] + await send({"type": "http.response.start", "status": status, "headers": headers}) + await start_response_body.wait() + await send({"type": "http.response.body", "body": b"body", "more_body": False}) + + app = run_in_task_group(send_response_body_after_event) if send_in_sub_task else send_response_body_after_event + + transport = httpx2.ASGITransport(app=app) + async with httpx2.AsyncClient(transport=transport) as client: + with anyio.fail_after(1): + async with client.stream("GET", "http://www.example.org/") as response: + assert response.status_code == 200 + start_response_body.set() + await response.aread() + assert response.text == "body" + + +@pytest.mark.parametrize( + "send_in_sub_task", [pytest.param(False, id="no_sub_task"), pytest.param(True, id="with_sub_task")] +) +@pytest.mark.anyio +async def test_asgi_stream_allows_iterative_streaming(send_in_sub_task: bool) -> None: + stream_events = [anyio.Event() for _ in range(4)] + + async def send_response_body_after_event(scope: Scope, receive: Receive, send: Send) -> None: + status = 200 + headers = [(b"content-type", b"text/plain")] + await send({"type": "http.response.start", "status": status, "headers": headers}) + for event in stream_events: + await event.wait() + await send({"type": "http.response.body", "body": b"chunk", "more_body": event is not stream_events[-1]}) + + app = run_in_task_group(send_response_body_after_event) if send_in_sub_task else send_response_body_after_event + + transport = httpx2.ASGITransport(app=app) + async with httpx2.AsyncClient(transport=transport) as client: + with anyio.fail_after(1): + async with client.stream("GET", "http://www.example.org/") as response: + assert response.status_code == 200 + iterator = response.aiter_raw() + for event in stream_events: + event.set() + assert await iterator.__anext__() == b"chunk" + with pytest.raises(StopAsyncIteration): + await iterator.__anext__() + + +@pytest.mark.anyio +async def test_asgi_stream_early_close() -> None: + async def stream_forever(scope: Scope, receive: Receive, send: Send) -> None: + status = 200 + headers = [(b"content-type", b"text/plain")] + await send({"type": "http.response.start", "status": status, "headers": headers}) + while True: + await send({"type": "http.response.body", "body": b"chunk", "more_body": True}) + + transport = httpx2.ASGITransport(app=stream_forever) + async with httpx2.AsyncClient(transport=transport) as client: + with anyio.fail_after(1): + async with client.stream("GET", "http://www.example.org/") as response: + assert response.status_code == 200 From 620f9b8a2986d0f78aa61b83c1a9b18804d40968 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 12 Jun 2026 10:10:33 +0200 Subject: [PATCH 2/2] Keep the disconnect event local to the ASGI message generator --- src/httpx2/httpx2/_transports/asgi.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/httpx2/httpx2/_transports/asgi.py b/src/httpx2/httpx2/_transports/asgi.py index 62c71d33..fcbdf672 100644 --- a/src/httpx2/httpx2/_transports/asgi.py +++ b/src/httpx2/httpx2/_transports/asgi.py @@ -17,15 +17,9 @@ class ASGIResponseStream(AsyncByteStream): - def __init__( - self, - ignore_body: bool, - asgi_messages: typing.AsyncGenerator[_Message, None], - disconnect_request: anyio.Event, - ) -> None: + def __init__(self, ignore_body: bool, asgi_messages: typing.AsyncGenerator[_Message, None]) -> None: self._ignore_body = ignore_body self._asgi_messages = asgi_messages - self._disconnect_request = disconnect_request async def __aiter__(self) -> typing.AsyncIterator[bytes]: more_body = True @@ -36,11 +30,8 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: more_body = message.get("more_body", False) if body and not self._ignore_body: yield body - if not more_body: - self._disconnect_request.set() async def aclose(self) -> None: - self._disconnect_request.set() await self._asgi_messages.aclose() @@ -83,8 +74,7 @@ def __init__( self.client = client async def handle_async_request(self, request: Request) -> Response: - disconnect_request = anyio.Event() - asgi_messages = self._run_app(request, disconnect_request) + asgi_messages = self._run_app(request) async for message in asgi_messages: if message["type"] == "http.response.start": @@ -94,16 +84,12 @@ async def handle_async_request(self, request: Request) -> Response: stream=ASGIResponseStream( ignore_body=request.method == "HEAD", asgi_messages=asgi_messages, - disconnect_request=disconnect_request, ), ) - disconnect_request.set() return Response(status_code=500, headers=[]) - async def _run_app( - self, request: Request, disconnect_request: anyio.Event - ) -> typing.AsyncGenerator[_Message, None]: + async def _run_app(self, request: Request) -> typing.AsyncGenerator[_Message, None]: assert isinstance(request.stream, AsyncByteStream) # ASGI scope. @@ -127,6 +113,7 @@ async def _run_app( request_complete = False # Response. + disconnect_request = anyio.Event() send_channel, receive_channel = anyio.create_memory_object_stream[_Message]() app_exception: Exception | None = None @@ -162,6 +149,8 @@ async def run_app() -> None: async with receive_channel: try: async for message in receive_channel: + if message["type"] == "http.response.body" and not message.get("more_body", False): + disconnect_request.set() yield message except GeneratorExit: # A `GeneratorExit` must not propagate through the task group,