Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/advanced/transports.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,20 @@ async with httpx2.AsyncClient(transport=transport, base_url="http://testserver")

See [the ASGI documentation](https://asgi.readthedocs.io/en/latest/specs/www.html#connection-scope) for more details on the `client` and `root_path` keys.

### Streaming responses

The ASGI transport streams response bodies. The app runs in a separate task, and a response is returned as soon as the app sends the response start, which generally happens before the app has fully run. This makes it possible to test streaming endpoints, such as server-sent events, by iterating over the response:

```python
transport = httpx2.ASGITransport(app=app)
async with httpx2.AsyncClient(transport=transport, base_url="http://testserver") as client:
async with client.stream("GET", "/sse") as response:
async for chunk in response.aiter_text():
...
```

Because the app runs in a separate task, context variables set within the app are not visible to the caller.

### ASGI startup and shutdown

It is not in the scope of HTTPX to trigger ASGI lifespan events of your app.
Expand Down
155 changes: 70 additions & 85 deletions src/httpx2/httpx2/_transports/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,37 @@

import typing

import anyio

from .._models import Request, Response
from .._types import AsyncByteStream
from .base import AsyncBaseTransport

if typing.TYPE_CHECKING:
import asyncio

import trio

Event = asyncio.Event | trio.Event


_Message = typing.MutableMapping[str, typing.Any]
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
_Send = typing.Callable[[typing.MutableMapping[str, typing.Any]], typing.Awaitable[None]]
_Send = typing.Callable[[_Message], typing.Awaitable[None]]
_ASGIApp = typing.Callable[[typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None]]

__all__ = ["ASGITransport"]


def is_running_trio() -> bool:
try:
# sniffio is a dependency of trio.

# See https://github.com/python-trio/trio/issues/2802
import sniffio

if sniffio.current_async_library() == "trio":
return True
except ImportError: # pragma: no cover
pass

return False


def create_event() -> Event:
if is_running_trio():
import trio

return trio.Event()

import asyncio

return asyncio.Event()


class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: list[bytes]) -> None:
self._body = body
def __init__(self, ignore_body: bool, asgi_messages: typing.AsyncGenerator[_Message, None]) -> None:
self._ignore_body = ignore_body
self._asgi_messages = asgi_messages

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield b"".join(self._body)
more_body = True
async for message in self._asgi_messages:
if message["type"] == "http.response.body":
assert more_body
body = message.get("body", b"")
more_body = message.get("more_body", False)
if body and not self._ignore_body:
yield body

async def aclose(self) -> None:
await self._asgi_messages.aclose()


class ASGITransport(AsyncBaseTransport):
Expand All @@ -69,6 +48,10 @@ class ASGITransport(AsyncBaseTransport):
client = httpx2.AsyncClient(transport=transport)
```

The app is run in a separate task, and response events are streamed as soon as
they arrive. A response is returned as soon as the app sends the response start,
which generally happens before the app has fully run.

Arguments:
app: The ASGI application.
raise_app_exceptions: Boolean indicating if exceptions in the application
Expand All @@ -91,6 +74,22 @@ def __init__(
self.client = client

async def handle_async_request(self, request: Request) -> Response:
asgi_messages = self._run_app(request)

async for message in asgi_messages:
if message["type"] == "http.response.start":
return Response(
status_code=message["status"],
headers=message.get("headers", []),
stream=ASGIResponseStream(
ignore_body=request.method == "HEAD",
asgi_messages=asgi_messages,
),
)

return Response(status_code=500, headers=[])

async def _run_app(self, request: Request) -> typing.AsyncGenerator[_Message, None]:
assert isinstance(request.stream, AsyncByteStream)

# ASGI scope.
Expand All @@ -114,19 +113,17 @@ async def handle_async_request(self, request: Request) -> Response:
request_complete = False

# Response.
status_code = None
response_headers = None
body_parts: list[bytes] = []
response_started = False
response_complete = create_event()
disconnect_request = anyio.Event()
send_channel, receive_channel = anyio.create_memory_object_stream[_Message]()
app_exception: Exception | None = None

# ASGI callables.

async def receive() -> dict[str, typing.Any]:
async def receive() -> _Message:
nonlocal request_complete

if request_complete:
await response_complete.wait()
await disconnect_request.wait()
return {"type": "http.disconnect"}

try:
Expand All @@ -136,43 +133,31 @@ async def receive() -> dict[str, typing.Any]:
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
nonlocal status_code, response_headers, response_started

if message["type"] == "http.response.start":
assert not response_started

status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and request.method != "HEAD":
body_parts.append(body)

if not more_body:
response_complete.set()

try:
await self.app(scope, receive, send)
except Exception:
if self.raise_app_exceptions:
raise

response_complete.set()
if status_code is None:
status_code = 500
if response_headers is None:
response_headers = {}

assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None

stream = ASGIResponseStream(body_parts)

return Response(status_code, headers=response_headers, stream=stream)
async def run_app() -> None:
nonlocal app_exception
try:
await self.app(scope, receive, send_channel.send)
except Exception as exc:
app_exception = exc
finally:
send_channel.close()

closed = False

async with anyio.create_task_group() as task_group:
task_group.start_soon(run_app)
async with receive_channel:
try:
async for message in receive_channel:
if message["type"] == "http.response.body" and not message.get("more_body", False):
disconnect_request.set()
yield message
except GeneratorExit:
# A `GeneratorExit` must not propagate through the task group,
# which would wrap it in an exception group. Cancel the app and
# return instead.
closed = True
task_group.cancel_scope.cancel()

if not closed and app_exception is not None and self.raise_app_exceptions:
raise app_exception
117 changes: 117 additions & 0 deletions tests/httpx2/test_asgi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import typing

import anyio
import pytest

import httpx2
Expand All @@ -9,6 +10,15 @@
Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[typing.MutableMapping[str, typing.Any]], typing.Awaitable[None]]
Scope = typing.MutableMapping[str, typing.Any]
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]


def run_in_task_group(app: ASGIApp) -> ASGIApp:
async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
async with anyio.create_task_group() as task_group:
task_group.start_soon(app, scope, receive, send)

return wrapped_app


async def hello_world(scope: Scope, receive: Receive, send: Send) -> None:
Expand Down Expand Up @@ -64,6 +74,15 @@ async def raise_exc(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError()


async def raise_exc_after_response_start(scope: Scope, receive: Receive, send: Send) -> None:
status = 200
output = b"Hello, World!"
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]

await send({"type": "http.response.start", "status": status, "headers": headers})
raise RuntimeError()


async def raise_exc_after_response(scope: Scope, receive: Receive, send: Send) -> None:
status = 200
output = b"Hello, World!"
Expand Down Expand Up @@ -176,6 +195,14 @@ async def test_asgi_exc() -> None:
await client.get("http://www.example.org/")


@pytest.mark.anyio
async def test_asgi_exc_after_response_start() -> None:
transport = httpx2.ASGITransport(app=raise_exc_after_response_start)
async with httpx2.AsyncClient(transport=transport) as client:
with pytest.raises(RuntimeError):
await client.get("http://www.example.org/")


@pytest.mark.anyio
async def test_asgi_exc_after_response() -> None:
transport = httpx2.ASGITransport(app=raise_exc_after_response)
Expand Down Expand Up @@ -224,3 +251,93 @@ async def test_asgi_exc_no_raise() -> None:
response = await client.get("http://www.example.org/")

assert response.status_code == 500


@pytest.mark.anyio
async def test_asgi_exc_no_raise_after_response_start() -> None:
transport = httpx2.ASGITransport(app=raise_exc_after_response_start, raise_app_exceptions=False)
async with httpx2.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")

assert response.status_code == 200


@pytest.mark.anyio
async def test_asgi_exc_no_raise_after_response() -> None:
transport = httpx2.ASGITransport(app=raise_exc_after_response, raise_app_exceptions=False)
async with httpx2.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")

assert response.status_code == 200


@pytest.mark.parametrize(
"send_in_sub_task", [pytest.param(False, id="no_sub_task"), pytest.param(True, id="with_sub_task")]
)
@pytest.mark.anyio
async def test_asgi_stream_returns_before_waiting_for_body(send_in_sub_task: bool) -> None:
start_response_body = anyio.Event()

async def send_response_body_after_event(scope: Scope, receive: Receive, send: Send) -> None:
status = 200
headers = [(b"content-type", b"text/plain")]
await send({"type": "http.response.start", "status": status, "headers": headers})
await start_response_body.wait()
await send({"type": "http.response.body", "body": b"body", "more_body": False})

app = run_in_task_group(send_response_body_after_event) if send_in_sub_task else send_response_body_after_event

transport = httpx2.ASGITransport(app=app)
async with httpx2.AsyncClient(transport=transport) as client:
with anyio.fail_after(1):
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
start_response_body.set()
await response.aread()
assert response.text == "body"


@pytest.mark.parametrize(
"send_in_sub_task", [pytest.param(False, id="no_sub_task"), pytest.param(True, id="with_sub_task")]
)
@pytest.mark.anyio
async def test_asgi_stream_allows_iterative_streaming(send_in_sub_task: bool) -> None:
stream_events = [anyio.Event() for _ in range(4)]

async def send_response_body_after_event(scope: Scope, receive: Receive, send: Send) -> None:
status = 200
headers = [(b"content-type", b"text/plain")]
await send({"type": "http.response.start", "status": status, "headers": headers})
for event in stream_events:
await event.wait()
await send({"type": "http.response.body", "body": b"chunk", "more_body": event is not stream_events[-1]})

app = run_in_task_group(send_response_body_after_event) if send_in_sub_task else send_response_body_after_event

transport = httpx2.ASGITransport(app=app)
async with httpx2.AsyncClient(transport=transport) as client:
with anyio.fail_after(1):
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
iterator = response.aiter_raw()
for event in stream_events:
event.set()
assert await iterator.__anext__() == b"chunk"
with pytest.raises(StopAsyncIteration):
await iterator.__anext__()


@pytest.mark.anyio
async def test_asgi_stream_early_close() -> None:
async def stream_forever(scope: Scope, receive: Receive, send: Send) -> None:
status = 200
headers = [(b"content-type", b"text/plain")]
await send({"type": "http.response.start", "status": status, "headers": headers})
while True:
await send({"type": "http.response.body", "body": b"chunk", "more_body": True})

transport = httpx2.ASGITransport(app=stream_forever)
async with httpx2.AsyncClient(transport=transport) as client:
with anyio.fail_after(1):
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
Loading