From bd2e06eb0723a7b1910a19a99ee2274eda2d4dba Mon Sep 17 00:00:00 2001 From: sarmientoF Date: Sat, 13 Jun 2026 15:22:10 +0900 Subject: [PATCH] fix(http2): respect stream capacity in pool Queue or open another connection when existing HTTP/2 connections are at the remote stream cap instead of assigning requests to full connections. --- httpcore/_async/connection.py | 8 ++ httpcore/_async/connection_pool.py | 80 +++++++++--- httpcore/_async/http11.py | 3 + httpcore/_async/http2.py | 63 ++++++++-- httpcore/_async/http_proxy.py | 12 ++ httpcore/_async/interfaces.py | 7 ++ httpcore/_async/socks_proxy.py | 9 ++ httpcore/_sync/connection.py | 8 ++ httpcore/_sync/connection_pool.py | 80 +++++++++--- httpcore/_sync/http11.py | 3 + httpcore/_sync/http2.py | 63 ++++++++-- httpcore/_sync/http_proxy.py | 12 ++ httpcore/_sync/interfaces.py | 7 ++ httpcore/_sync/socks_proxy.py | 9 ++ httpcore/_synchronization.py | 19 +++ tests/_async/test_connection_pool.py | 177 +++++++++++++++++++++++++++ tests/_async/test_http2.py | 129 +++++++++++++++++++ tests/_sync/test_connection_pool.py | 177 +++++++++++++++++++++++++++ tests/_sync/test_http2.py | 129 +++++++++++++++++++ tests/test_cancellations.py | 4 +- 20 files changed, 937 insertions(+), 62 deletions(-) diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index b42581dff..b527f9507 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -48,6 +48,7 @@ def __init__( uds: str | None = None, network_backend: AsyncNetworkBackend | None = None, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + on_capacity_update: typing.Callable[[], typing.Any] | None = None, ) -> None: self._origin = origin self._ssl_context = ssl_context @@ -65,6 +66,7 @@ def __init__( self._connect_failed: bool = False self._request_lock = AsyncLock() self._socket_options = socket_options + self._on_capacity_update = on_capacity_update async def handle_async_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): @@ -89,6 +91,7 @@ async def handle_async_request(self, request: Request) -> Response: origin=self._origin, stream=stream, keepalive_expiry=self._keepalive_expiry, + on_capacity_update=self._on_capacity_update, ) else: self._connection = AsyncHTTP11Connection( @@ -184,6 +187,11 @@ def is_available(self) -> bool: ) return self._connection.is_available() + def max_concurrent_requests(self) -> int: + if self._connection is None: + return 1 + return self._connection.max_concurrent_requests() + def has_expired(self) -> bool: if self._connection is None: return self._connect_failed diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 5ef74e649..09106b0a5 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -139,6 +139,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface: http1=self._http1, http2=self._http2, network_backend=self._network_backend, + on_capacity_update=self._connection_capacity_updated, ) elif origin.scheme == b"http": from .http_proxy import AsyncForwardHTTPConnection @@ -150,6 +151,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface: remote_origin=origin, keepalive_expiry=self._keepalive_expiry, network_backend=self._network_backend, + on_capacity_update=self._connection_capacity_updated, ) from .http_proxy import AsyncTunnelHTTPConnection @@ -163,6 +165,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface: http1=self._http1, http2=self._http2, network_backend=self._network_backend, + on_capacity_update=self._connection_capacity_updated, ) return AsyncHTTPConnection( @@ -176,6 +179,7 @@ def create_connection(self, origin: Origin) -> AsyncConnectionInterface: uds=self._uds, network_backend=self._network_backend, socket_options=self._socket_options, + on_capacity_update=self._connection_capacity_updated, ) @property @@ -289,27 +293,42 @@ def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]: # log: "closing expired connection" self._connections.remove(connection) closing_connections.append(connection) - elif ( - connection.is_idle() - and sum(connection.is_idle() for connection in self._connections) - > self._max_keepalive_connections - ): + + idle_connection_count = sum( + connection.is_idle() for connection in self._connections + ) + for connection in list(self._connections): + if idle_connection_count <= self._max_keepalive_connections: + break + if connection.is_idle(): # log: "closing idle connection" self._connections.remove(connection) closing_connections.append(connection) + idle_connection_count -= 1 # Assign queued requests to connections. queued_requests = [request for request in self._requests if request.is_queued()] + connection_request_count = dict.fromkeys(self._connections, 0) + for request in self._requests: + request_connection = request.connection + if request_connection in connection_request_count: + connection_request_count[request_connection] += 1 + for pool_request in queued_requests: origin = pool_request.request.url.origin - available_connections = [ - connection - for connection in self._connections - if connection.can_handle_request(origin) and connection.is_available() - ] - idle_connections = [ - connection for connection in self._connections if connection.is_idle() - ] + available_connection = next( + ( + connection + for connection in self._connections + if ( + connection.can_handle_request(origin) + and connection.is_available() + and connection_request_count[connection] + < self._max_concurrent_requests(connection) + ) + ), + None, + ) # There are three cases for how we may be able to handle the request: # @@ -317,27 +336,50 @@ def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]: # 2. We can create a new connection to handle the request. # 3. We can close an idle connection and then create a new connection # to handle the request. - if available_connections: + if available_connection is not None: # log: "reusing existing connection" - connection = available_connections[0] + connection = available_connection pool_request.assign_to_connection(connection) + connection_request_count[connection] += 1 elif len(self._connections) < self._max_connections: # log: "creating new connection" connection = self.create_connection(origin) self._connections.append(connection) pool_request.assign_to_connection(connection) - elif idle_connections: + connection_request_count[connection] = 1 + else: + idle_connection = next( + ( + connection + for connection in self._connections + if connection.is_idle() + ), + None, + ) + if idle_connection is None: + continue # log: "closing idle connection" - connection = idle_connections[0] - self._connections.remove(connection) - closing_connections.append(connection) + self._connections.remove(idle_connection) + closing_connections.append(idle_connection) # log: "creating new connection" connection = self.create_connection(origin) self._connections.append(connection) pool_request.assign_to_connection(connection) + connection_request_count[connection] = 1 return closing_connections + def _max_concurrent_requests(self, connection: AsyncConnectionInterface) -> int: + try: + return int(connection.max_concurrent_requests()) + except AttributeError: # pragma: nocover + return 1 + + async def _connection_capacity_updated(self) -> None: + with self._optional_thread_lock: + closing = self._assign_requests_to_connections() + await self._close_connections(closing) + async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None: # Close connections which have been removed from the pool. with AsyncShieldCancellation(): diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index e6d6d7098..b3ce5dde0 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -271,6 +271,9 @@ def is_available(self) -> bool: # acquired from the connection pool for any other request. return self._state == HTTPConnectionState.IDLE + def max_concurrent_requests(self) -> int: + return 1 + def has_expired(self) -> bool: now = time.monotonic() keepalive_expired = self._expire_at is not None and now > self._expire_at diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index dbd0beeb4..7a439d352 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -48,10 +48,12 @@ def __init__( origin: Origin, stream: AsyncNetworkStream, keepalive_expiry: float | None = None, + on_capacity_update: typing.Callable[[], typing.Any] | None = None, ): self._origin = origin self._network_stream = stream self._keepalive_expiry: float | None = keepalive_expiry + self._on_capacity_update = on_capacity_update self._h2_state = h2.connection.H2Connection(config=self.CONFIG) self._state = HTTPConnectionState.IDLE self._expire_at: float | None = None @@ -74,6 +76,7 @@ def __init__( | h2.events.StreamReset, ], ] = {} + self._closed_streams: set[int] = set() # Connection terminated events are stored as state since # we need to handle them for all streams. @@ -95,6 +98,8 @@ async def handle_async_request(self, request: Request) -> Response: async with self._state_lock: if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE): + previous_state = self._state + previous_expire_at = self._expire_at self._request_count += 1 self._expire_at = None self._state = HTTPConnectionState.ACTIVE @@ -128,7 +133,13 @@ async def handle_async_request(self, request: Request) -> Response: for _ in range(local_settings_max_streams - self._max_streams): await self._max_streams_semaphore.acquire() - await self._max_streams_semaphore.acquire() + if not self._max_streams_semaphore.acquire_nowait(): + async with self._state_lock: + self._request_count -= 1 + if not self._events: # pragma: nocover + self._state = previous_state + self._expire_at = previous_expire_at + raise ConnectionNotAvailable() try: stream_id = self._h2_state.get_next_available_stream_id() @@ -136,6 +147,7 @@ async def handle_async_request(self, request: Request) -> Response: except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover self._used_all_stream_ids = True self._request_count -= 1 + await self._max_streams_semaphore.release() raise ConnectionNotAvailable() try: @@ -380,6 +392,10 @@ async def _receive_events( ), ): if event.stream_id in self._events: + if isinstance( + event, (h2.events.StreamEnded, h2.events.StreamReset) + ): + self._closed_streams.add(event.stream_id) self._events[event.stream_id].append(event) elif isinstance(event, h2.events.ConnectionTerminated): @@ -399,27 +415,48 @@ async def _receive_remote_settings_change( self._h2_state.local_settings.max_concurrent_streams, ) if new_max_streams and new_max_streams != self._max_streams: - while new_max_streams > self._max_streams: + active_stream_count = len(self._events) + old_available_streams = max(0, self._max_streams - active_stream_count) + new_available_streams = max(0, new_max_streams - active_stream_count) + self._max_streams = new_max_streams + while new_available_streams > old_available_streams: await self._max_streams_semaphore.release() - self._max_streams += 1 - while new_max_streams < self._max_streams: + old_available_streams += 1 + while new_available_streams < old_available_streams: await self._max_streams_semaphore.acquire() - self._max_streams -= 1 + old_available_streams -= 1 + if self._on_capacity_update is not None: + await self._on_capacity_update() async def _response_closed(self, stream_id: int) -> None: - await self._max_streams_semaphore.release() + stream_was_reset = stream_id not in self._closed_streams + if stream_was_reset: + # Keep h2's stream state aligned without blocking close/cancel on I/O. + # Any pending RST_STREAM data will be flushed by the next write. + try: + self._h2_state.reset_stream(stream_id) + except ( + h2.exceptions.NoSuchStreamError, + h2.exceptions.StreamClosedError, + h2.exceptions.ProtocolError, + ): + pass + if len(self._events) <= self._max_streams: + await self._max_streams_semaphore.release() + self._closed_streams.discard(stream_id) del self._events[stream_id] async with self._state_lock: if self._connection_terminated and not self._events: await self.aclose() elif self._state == HTTPConnectionState.ACTIVE and not self._events: - self._state = HTTPConnectionState.IDLE - if self._keepalive_expiry is not None: - now = time.monotonic() - self._expire_at = now + self._keepalive_expiry - if self._used_all_stream_ids: # pragma: nocover + if stream_was_reset or self._used_all_stream_ids: await self.aclose() + else: + self._state = HTTPConnectionState.IDLE + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry async def aclose(self) -> None: # Note that this method unilaterally closes the connection, and does @@ -513,12 +550,16 @@ def is_available(self) -> bool: self._state != HTTPConnectionState.CLOSED and not self._connection_error and not self._used_all_stream_ids + and len(self._events) < self.max_concurrent_requests() and not ( self._h2_state.state_machine.state == h2.connection.ConnectionState.CLOSED ) ) + def max_concurrent_requests(self) -> int: + return self._max_streams if self._sent_connection_init else 1 + def has_expired(self) -> bool: now = time.monotonic() return self._expire_at is not None and now > self._expire_at diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index cc9d92066..a1cd3f36b 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -176,6 +176,7 @@ def __init__( network_backend: AsyncNetworkBackend | None = None, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, proxy_ssl_context: ssl.SSLContext | None = None, + on_capacity_update: typing.Callable[[], typing.Any] | None = None, ) -> None: self._connection = AsyncHTTPConnection( origin=proxy_origin, @@ -183,6 +184,7 @@ def __init__( network_backend=network_backend, socket_options=socket_options, ssl_context=proxy_ssl_context, + on_capacity_update=on_capacity_update, ) self._proxy_origin = proxy_origin self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") @@ -217,6 +219,9 @@ def info(self) -> str: def is_available(self) -> bool: return self._connection.is_available() + def max_concurrent_requests(self) -> int: # pragma: nocover + return self._connection.max_concurrent_requests() + def has_expired(self) -> bool: return self._connection.has_expired() @@ -243,6 +248,7 @@ def __init__( http2: bool = False, network_backend: AsyncNetworkBackend | None = None, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + on_capacity_update: typing.Callable[[], typing.Any] | None = None, ) -> None: self._connection: AsyncConnectionInterface = AsyncHTTPConnection( origin=proxy_origin, @@ -250,6 +256,7 @@ def __init__( network_backend=network_backend, socket_options=socket_options, ssl_context=proxy_ssl_context, + on_capacity_update=on_capacity_update, ) self._proxy_origin = proxy_origin self._remote_origin = remote_origin @@ -259,6 +266,7 @@ def __init__( self._keepalive_expiry = keepalive_expiry self._http1 = http1 self._http2 = http2 + self._on_capacity_update = on_capacity_update self._connect_lock = AsyncLock() self._connected = False @@ -331,6 +339,7 @@ async def handle_async_request(self, request: Request) -> Response: origin=self._remote_origin, stream=stream, keepalive_expiry=self._keepalive_expiry, + on_capacity_update=self._on_capacity_update, ) else: self._connection = AsyncHTTP11Connection( @@ -354,6 +363,9 @@ def info(self) -> str: def is_available(self) -> bool: return self._connection.is_available() + def max_concurrent_requests(self) -> int: # pragma: nocover + return self._connection.max_concurrent_requests() + def has_expired(self) -> bool: return self._connection.has_expired() diff --git a/httpcore/_async/interfaces.py b/httpcore/_async/interfaces.py index 361583bed..d9b146d85 100644 --- a/httpcore/_async/interfaces.py +++ b/httpcore/_async/interfaces.py @@ -112,6 +112,13 @@ def is_available(self) -> bool: """ raise NotImplementedError() # pragma: nocover + def max_concurrent_requests(self) -> int: # pragma: nocover + """ + Return the maximum number of requests that may be assigned to this + connection at a given time. + """ + return 1 + def has_expired(self) -> bool: """ Return `True` if the connection is in a state where it should be closed. diff --git a/httpcore/_async/socks_proxy.py b/httpcore/_async/socks_proxy.py index b363f55a0..e68fad257 100644 --- a/httpcore/_async/socks_proxy.py +++ b/httpcore/_async/socks_proxy.py @@ -2,6 +2,7 @@ import logging import ssl +import typing import socksio @@ -197,6 +198,7 @@ def __init__( http1: bool = True, http2: bool = False, network_backend: AsyncNetworkBackend | None = None, + on_capacity_update: typing.Callable[[], typing.Any] | None = None, ) -> None: self._proxy_origin = proxy_origin self._remote_origin = remote_origin @@ -205,6 +207,7 @@ def __init__( self._keepalive_expiry = keepalive_expiry self._http1 = http1 self._http2 = http2 + self._on_capacity_update = on_capacity_update self._network_backend: AsyncNetworkBackend = ( AutoBackend() if network_backend is None else network_backend @@ -283,6 +286,7 @@ async def handle_async_request(self, request: Request) -> Response: origin=self._remote_origin, stream=stream, keepalive_expiry=self._keepalive_expiry, + on_capacity_update=self._on_capacity_update, ) else: self._connection = AsyncHTTP11Connection( @@ -317,6 +321,11 @@ def is_available(self) -> bool: ) return self._connection.is_available() + def max_concurrent_requests(self) -> int: # pragma: nocover + if self._connection is None: # pragma: nocover + return 1 + return self._connection.max_concurrent_requests() + def has_expired(self) -> bool: if self._connection is None: # pragma: nocover return self._connect_failed diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 363f8be81..254022609 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -48,6 +48,7 @@ def __init__( uds: str | None = None, network_backend: NetworkBackend | None = None, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + on_capacity_update: typing.Callable[[], typing.Any] | None = None, ) -> None: self._origin = origin self._ssl_context = ssl_context @@ -65,6 +66,7 @@ def __init__( self._connect_failed: bool = False self._request_lock = Lock() self._socket_options = socket_options + self._on_capacity_update = on_capacity_update def handle_request(self, request: Request) -> Response: if not self.can_handle_request(request.url.origin): @@ -89,6 +91,7 @@ def handle_request(self, request: Request) -> Response: origin=self._origin, stream=stream, keepalive_expiry=self._keepalive_expiry, + on_capacity_update=self._on_capacity_update, ) else: self._connection = HTTP11Connection( @@ -184,6 +187,11 @@ def is_available(self) -> bool: ) return self._connection.is_available() + def max_concurrent_requests(self) -> int: + if self._connection is None: + return 1 + return self._connection.max_concurrent_requests() + def has_expired(self) -> bool: if self._connection is None: return self._connect_failed diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 4b26f9c63..013aa03c6 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -139,6 +139,7 @@ def create_connection(self, origin: Origin) -> ConnectionInterface: http1=self._http1, http2=self._http2, network_backend=self._network_backend, + on_capacity_update=self._connection_capacity_updated, ) elif origin.scheme == b"http": from .http_proxy import ForwardHTTPConnection @@ -150,6 +151,7 @@ def create_connection(self, origin: Origin) -> ConnectionInterface: remote_origin=origin, keepalive_expiry=self._keepalive_expiry, network_backend=self._network_backend, + on_capacity_update=self._connection_capacity_updated, ) from .http_proxy import TunnelHTTPConnection @@ -163,6 +165,7 @@ def create_connection(self, origin: Origin) -> ConnectionInterface: http1=self._http1, http2=self._http2, network_backend=self._network_backend, + on_capacity_update=self._connection_capacity_updated, ) return HTTPConnection( @@ -176,6 +179,7 @@ def create_connection(self, origin: Origin) -> ConnectionInterface: uds=self._uds, network_backend=self._network_backend, socket_options=self._socket_options, + on_capacity_update=self._connection_capacity_updated, ) @property @@ -289,27 +293,42 @@ def _assign_requests_to_connections(self) -> list[ConnectionInterface]: # log: "closing expired connection" self._connections.remove(connection) closing_connections.append(connection) - elif ( - connection.is_idle() - and sum(connection.is_idle() for connection in self._connections) - > self._max_keepalive_connections - ): + + idle_connection_count = sum( + connection.is_idle() for connection in self._connections + ) + for connection in list(self._connections): + if idle_connection_count <= self._max_keepalive_connections: + break + if connection.is_idle(): # log: "closing idle connection" self._connections.remove(connection) closing_connections.append(connection) + idle_connection_count -= 1 # Assign queued requests to connections. queued_requests = [request for request in self._requests if request.is_queued()] + connection_request_count = dict.fromkeys(self._connections, 0) + for request in self._requests: + request_connection = request.connection + if request_connection in connection_request_count: + connection_request_count[request_connection] += 1 + for pool_request in queued_requests: origin = pool_request.request.url.origin - available_connections = [ - connection - for connection in self._connections - if connection.can_handle_request(origin) and connection.is_available() - ] - idle_connections = [ - connection for connection in self._connections if connection.is_idle() - ] + available_connection = next( + ( + connection + for connection in self._connections + if ( + connection.can_handle_request(origin) + and connection.is_available() + and connection_request_count[connection] + < self._max_concurrent_requests(connection) + ) + ), + None, + ) # There are three cases for how we may be able to handle the request: # @@ -317,27 +336,50 @@ def _assign_requests_to_connections(self) -> list[ConnectionInterface]: # 2. We can create a new connection to handle the request. # 3. We can close an idle connection and then create a new connection # to handle the request. - if available_connections: + if available_connection is not None: # log: "reusing existing connection" - connection = available_connections[0] + connection = available_connection pool_request.assign_to_connection(connection) + connection_request_count[connection] += 1 elif len(self._connections) < self._max_connections: # log: "creating new connection" connection = self.create_connection(origin) self._connections.append(connection) pool_request.assign_to_connection(connection) - elif idle_connections: + connection_request_count[connection] = 1 + else: + idle_connection = next( + ( + connection + for connection in self._connections + if connection.is_idle() + ), + None, + ) + if idle_connection is None: + continue # log: "closing idle connection" - connection = idle_connections[0] - self._connections.remove(connection) - closing_connections.append(connection) + self._connections.remove(idle_connection) + closing_connections.append(idle_connection) # log: "creating new connection" connection = self.create_connection(origin) self._connections.append(connection) pool_request.assign_to_connection(connection) + connection_request_count[connection] = 1 return closing_connections + def _max_concurrent_requests(self, connection: ConnectionInterface) -> int: + try: + return int(connection.max_concurrent_requests()) + except AttributeError: # pragma: nocover + return 1 + + def _connection_capacity_updated(self) -> None: + with self._optional_thread_lock: + closing = self._assign_requests_to_connections() + self._close_connections(closing) + def _close_connections(self, closing: list[ConnectionInterface]) -> None: # Close connections which have been removed from the pool. with ShieldCancellation(): diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index ebd3a9748..a0462f169 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -271,6 +271,9 @@ def is_available(self) -> bool: # acquired from the connection pool for any other request. return self._state == HTTPConnectionState.IDLE + def max_concurrent_requests(self) -> int: + return 1 + def has_expired(self) -> bool: now = time.monotonic() keepalive_expired = self._expire_at is not None and now > self._expire_at diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index ddcc18900..85f16c6c2 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -48,10 +48,12 @@ def __init__( origin: Origin, stream: NetworkStream, keepalive_expiry: float | None = None, + on_capacity_update: typing.Callable[[], typing.Any] | None = None, ): self._origin = origin self._network_stream = stream self._keepalive_expiry: float | None = keepalive_expiry + self._on_capacity_update = on_capacity_update self._h2_state = h2.connection.H2Connection(config=self.CONFIG) self._state = HTTPConnectionState.IDLE self._expire_at: float | None = None @@ -74,6 +76,7 @@ def __init__( | h2.events.StreamReset, ], ] = {} + self._closed_streams: set[int] = set() # Connection terminated events are stored as state since # we need to handle them for all streams. @@ -95,6 +98,8 @@ def handle_request(self, request: Request) -> Response: with self._state_lock: if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE): + previous_state = self._state + previous_expire_at = self._expire_at self._request_count += 1 self._expire_at = None self._state = HTTPConnectionState.ACTIVE @@ -128,7 +133,13 @@ def handle_request(self, request: Request) -> Response: for _ in range(local_settings_max_streams - self._max_streams): self._max_streams_semaphore.acquire() - self._max_streams_semaphore.acquire() + if not self._max_streams_semaphore.acquire_nowait(): + with self._state_lock: + self._request_count -= 1 + if not self._events: # pragma: nocover + self._state = previous_state + self._expire_at = previous_expire_at + raise ConnectionNotAvailable() try: stream_id = self._h2_state.get_next_available_stream_id() @@ -136,6 +147,7 @@ def handle_request(self, request: Request) -> Response: except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover self._used_all_stream_ids = True self._request_count -= 1 + self._max_streams_semaphore.release() raise ConnectionNotAvailable() try: @@ -380,6 +392,10 @@ def _receive_events( ), ): if event.stream_id in self._events: + if isinstance( + event, (h2.events.StreamEnded, h2.events.StreamReset) + ): + self._closed_streams.add(event.stream_id) self._events[event.stream_id].append(event) elif isinstance(event, h2.events.ConnectionTerminated): @@ -399,27 +415,48 @@ def _receive_remote_settings_change( self._h2_state.local_settings.max_concurrent_streams, ) if new_max_streams and new_max_streams != self._max_streams: - while new_max_streams > self._max_streams: + active_stream_count = len(self._events) + old_available_streams = max(0, self._max_streams - active_stream_count) + new_available_streams = max(0, new_max_streams - active_stream_count) + self._max_streams = new_max_streams + while new_available_streams > old_available_streams: self._max_streams_semaphore.release() - self._max_streams += 1 - while new_max_streams < self._max_streams: + old_available_streams += 1 + while new_available_streams < old_available_streams: self._max_streams_semaphore.acquire() - self._max_streams -= 1 + old_available_streams -= 1 + if self._on_capacity_update is not None: + self._on_capacity_update() def _response_closed(self, stream_id: int) -> None: - self._max_streams_semaphore.release() + stream_was_reset = stream_id not in self._closed_streams + if stream_was_reset: + # Keep h2's stream state aligned without blocking close/cancel on I/O. + # Any pending RST_STREAM data will be flushed by the next write. + try: + self._h2_state.reset_stream(stream_id) + except ( + h2.exceptions.NoSuchStreamError, + h2.exceptions.StreamClosedError, + h2.exceptions.ProtocolError, + ): + pass + if len(self._events) <= self._max_streams: + self._max_streams_semaphore.release() + self._closed_streams.discard(stream_id) del self._events[stream_id] with self._state_lock: if self._connection_terminated and not self._events: self.close() elif self._state == HTTPConnectionState.ACTIVE and not self._events: - self._state = HTTPConnectionState.IDLE - if self._keepalive_expiry is not None: - now = time.monotonic() - self._expire_at = now + self._keepalive_expiry - if self._used_all_stream_ids: # pragma: nocover + if stream_was_reset or self._used_all_stream_ids: self.close() + else: + self._state = HTTPConnectionState.IDLE + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry def close(self) -> None: # Note that this method unilaterally closes the connection, and does @@ -513,12 +550,16 @@ def is_available(self) -> bool: self._state != HTTPConnectionState.CLOSED and not self._connection_error and not self._used_all_stream_ids + and len(self._events) < self.max_concurrent_requests() and not ( self._h2_state.state_machine.state == h2.connection.ConnectionState.CLOSED ) ) + def max_concurrent_requests(self) -> int: + return self._max_streams if self._sent_connection_init else 1 + def has_expired(self) -> bool: now = time.monotonic() return self._expire_at is not None and now > self._expire_at diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index ecca88f7d..a4b188c2e 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -176,6 +176,7 @@ def __init__( network_backend: NetworkBackend | None = None, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, proxy_ssl_context: ssl.SSLContext | None = None, + on_capacity_update: typing.Callable[[], typing.Any] | None = None, ) -> None: self._connection = HTTPConnection( origin=proxy_origin, @@ -183,6 +184,7 @@ def __init__( network_backend=network_backend, socket_options=socket_options, ssl_context=proxy_ssl_context, + on_capacity_update=on_capacity_update, ) self._proxy_origin = proxy_origin self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") @@ -217,6 +219,9 @@ def info(self) -> str: def is_available(self) -> bool: return self._connection.is_available() + def max_concurrent_requests(self) -> int: # pragma: nocover + return self._connection.max_concurrent_requests() + def has_expired(self) -> bool: return self._connection.has_expired() @@ -243,6 +248,7 @@ def __init__( http2: bool = False, network_backend: NetworkBackend | None = None, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + on_capacity_update: typing.Callable[[], typing.Any] | None = None, ) -> None: self._connection: ConnectionInterface = HTTPConnection( origin=proxy_origin, @@ -250,6 +256,7 @@ def __init__( network_backend=network_backend, socket_options=socket_options, ssl_context=proxy_ssl_context, + on_capacity_update=on_capacity_update, ) self._proxy_origin = proxy_origin self._remote_origin = remote_origin @@ -259,6 +266,7 @@ def __init__( self._keepalive_expiry = keepalive_expiry self._http1 = http1 self._http2 = http2 + self._on_capacity_update = on_capacity_update self._connect_lock = Lock() self._connected = False @@ -331,6 +339,7 @@ def handle_request(self, request: Request) -> Response: origin=self._remote_origin, stream=stream, keepalive_expiry=self._keepalive_expiry, + on_capacity_update=self._on_capacity_update, ) else: self._connection = HTTP11Connection( @@ -354,6 +363,9 @@ def info(self) -> str: def is_available(self) -> bool: return self._connection.is_available() + def max_concurrent_requests(self) -> int: # pragma: nocover + return self._connection.max_concurrent_requests() + def has_expired(self) -> bool: return self._connection.has_expired() diff --git a/httpcore/_sync/interfaces.py b/httpcore/_sync/interfaces.py index e673d4cc1..9a87e74ec 100644 --- a/httpcore/_sync/interfaces.py +++ b/httpcore/_sync/interfaces.py @@ -112,6 +112,13 @@ def is_available(self) -> bool: """ raise NotImplementedError() # pragma: nocover + def max_concurrent_requests(self) -> int: # pragma: nocover + """ + Return the maximum number of requests that may be assigned to this + connection at a given time. + """ + return 1 + def has_expired(self) -> bool: """ Return `True` if the connection is in a state where it should be closed. diff --git a/httpcore/_sync/socks_proxy.py b/httpcore/_sync/socks_proxy.py index 0ca96ddfb..7a983ac7f 100644 --- a/httpcore/_sync/socks_proxy.py +++ b/httpcore/_sync/socks_proxy.py @@ -2,6 +2,7 @@ import logging import ssl +import typing import socksio @@ -197,6 +198,7 @@ def __init__( http1: bool = True, http2: bool = False, network_backend: NetworkBackend | None = None, + on_capacity_update: typing.Callable[[], typing.Any] | None = None, ) -> None: self._proxy_origin = proxy_origin self._remote_origin = remote_origin @@ -205,6 +207,7 @@ def __init__( self._keepalive_expiry = keepalive_expiry self._http1 = http1 self._http2 = http2 + self._on_capacity_update = on_capacity_update self._network_backend: NetworkBackend = ( SyncBackend() if network_backend is None else network_backend @@ -283,6 +286,7 @@ def handle_request(self, request: Request) -> Response: origin=self._remote_origin, stream=stream, keepalive_expiry=self._keepalive_expiry, + on_capacity_update=self._on_capacity_update, ) else: self._connection = HTTP11Connection( @@ -317,6 +321,11 @@ def is_available(self) -> bool: ) return self._connection.is_available() + def max_concurrent_requests(self) -> int: # pragma: nocover + if self._connection is None: # pragma: nocover + return 1 + return self._connection.max_concurrent_requests() + def has_expired(self) -> bool: if self._connection is None: # pragma: nocover return self._connect_failed diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index 2ecc9e9c3..3023b7d09 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -180,6 +180,22 @@ async def acquire(self) -> None: elif self._backend == "asyncio": await self._anyio_semaphore.acquire() + def acquire_nowait(self) -> bool: + if not self._backend: # pragma: nocover + self.setup() + + if self._backend == "trio": + try: + self._trio_semaphore.acquire_nowait() + except trio.WouldBlock: + return False + elif self._backend == "asyncio": + try: + self._anyio_semaphore.acquire_nowait() + except anyio.WouldBlock: + return False + return True + async def release(self) -> None: if self._backend == "trio": self._trio_semaphore.release() @@ -298,6 +314,9 @@ def __init__(self, bound: int) -> None: def acquire(self) -> None: self._semaphore.acquire() + def acquire_nowait(self) -> bool: + return self._semaphore.acquire(blocking=False) + def release(self) -> None: self._semaphore.release() diff --git a/tests/_async/test_connection_pool.py b/tests/_async/test_connection_pool.py index bc4b251e3..5e21c7ca2 100644 --- a/tests/_async/test_connection_pool.py +++ b/tests/_async/test_connection_pool.py @@ -1,3 +1,4 @@ +import importlib import logging import typing @@ -9,6 +10,56 @@ import httpcore +class AsyncNonBlockingSemaphore: + def __init__(self, bound: int) -> None: + self._value = bound + + async def acquire(self) -> None: + if self._value <= 0: # pragma: nocover + raise RuntimeError("stream semaphore exhausted") + self._value -= 1 + + def acquire_nowait(self) -> bool: + if self._value <= 0: # pragma: nocover + return False + self._value -= 1 + return True + + async def release(self) -> None: + self._value += 1 + + +def http2_settings(max_concurrent_streams: int) -> bytes: + return hyperframe.frame.SettingsFrame( + settings={ + hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: max_concurrent_streams + } + ).serialize() + + +def http2_response_headers(stream_id: int) -> bytes: + return hyperframe.frame.HeadersFrame( + stream_id=stream_id, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize() + + +def use_non_blocking_semaphore(monkeypatch: pytest.MonkeyPatch) -> None: + module_name = ( + "httpcore._async.http2" + if httpcore.AsyncHTTP2Connection.__name__.startswith("Async") + else "httpcore._sync.http2" + ) + http2_module = importlib.import_module(module_name) + monkeypatch.setattr(http2_module, "AsyncSemaphore", AsyncNonBlockingSemaphore) + + @pytest.mark.anyio async def test_connection_pool_with_keepalive(): """ @@ -204,6 +255,132 @@ async def test_connection_pool_with_http2(): ] +@pytest.mark.anyio +async def test_connection_pool_with_http2_connecting_connection_capacity_update(): + class CapacityKnownConnection: + def can_handle_request( + self, origin: httpcore.Origin + ) -> bool: # pragma: nocover + return True + + def is_available(self) -> bool: + return True + + def max_concurrent_requests(self) -> int: + return 2 + + def has_expired(self) -> bool: + return False + + def is_idle(self) -> bool: + return False + + def is_closed(self) -> bool: + return False + + async def aclose(self) -> None: + pass + + def info(self) -> str: # pragma: nocover + return "HTTP/2, ACTIVE" + + class QueuedPoolRequest: + def __init__(self) -> None: + self.request = httpcore.Request("GET", "https://example.com/") + self.connection: typing.Any = None + + def is_queued(self) -> bool: + return self.connection is None + + def assign_to_connection(self, connection: typing.Any) -> None: + self.connection = connection + + async with httpcore.AsyncConnectionPool( + http2=True, + max_connections=1, + network_backend=httpcore.AsyncMockBackend([], http2=True), + ) as pool: + pool_request_1 = QueuedPoolRequest() + pool_request_2 = QueuedPoolRequest() + pool._requests = typing.cast(typing.Any, [pool_request_1, pool_request_2]) + + assert pool._assign_requests_to_connections() == [] + assert len(pool.connections) == 1 + assert pool_request_1.connection is not None + assert pool_request_2.connection is None + assert pool_request_1.connection.max_concurrent_requests() == 1 + + typing.cast( + typing.Any, pool_request_1.connection + )._connection = CapacityKnownConnection() + await pool._connection_capacity_updated() + assert pool_request_2.connection is pool_request_1.connection + + +@pytest.mark.anyio +async def test_connection_pool_with_http2_stream_limit_opens_extra_connection( + monkeypatch, +): + use_non_blocking_semaphore(monkeypatch) + network_backend = httpcore.AsyncMockBackend( + buffer=[ + http2_settings(max_concurrent_streams=2), + http2_response_headers(stream_id=1), + http2_response_headers(stream_id=3), + http2_response_headers(stream_id=5), + ], + http2=True, + ) + + async with httpcore.AsyncConnectionPool( + network_backend=network_backend, + max_connections=3, + ) as pool: + async with pool.stream("GET", "https://example.com/") as response_1: + assert response_1.status == 200 + async with pool.stream("GET", "https://example.com/") as response_2: + assert response_2.status == 200 + assert len(pool.connections) == 1 + async with pool.stream("GET", "https://example.com/") as response_3: + assert response_3.status == 200 + assert len(pool.connections) == 2 + + +@pytest.mark.anyio +async def test_connection_pool_with_http2_stream_limit_queues(monkeypatch): + use_non_blocking_semaphore(monkeypatch) + network_backend = httpcore.AsyncMockBackend( + buffer=[ + http2_settings(max_concurrent_streams=2), + http2_response_headers(stream_id=1), + http2_response_headers(stream_id=3), + http2_response_headers(stream_id=5), + ], + http2=True, + ) + + async with httpcore.AsyncConnectionPool( + network_backend=network_backend, + max_connections=1, + ) as pool: + async with pool.stream("GET", "https://example.com/") as response_1: + assert response_1.status == 200 + async with pool.stream("GET", "https://example.com/") as response_2: + assert response_2.status == 200 + + with pytest.raises(httpcore.PoolTimeout): + await pool.request( + "GET", + "https://example.com/", + extensions={"timeout": {"pool": 0}}, + ) + assert len(pool._requests) == 2 + + async with pool.stream("GET", "https://example.com/") as response_3: + assert response_3.status == 200 + assert len(pool.connections) == 1 + + @pytest.mark.anyio async def test_connection_pool_with_http2_goaway(): """ diff --git a/tests/_async/test_http2.py b/tests/_async/test_http2.py index b4ec66488..4d1d88b81 100644 --- a/tests/_async/test_http2.py +++ b/tests/_async/test_http2.py @@ -1,3 +1,5 @@ +import importlib + import hpack import hyperframe.frame import pytest @@ -5,6 +7,56 @@ import httpcore +class AsyncNonBlockingSemaphore: + def __init__(self, bound: int) -> None: + self._value = bound + + async def acquire(self) -> None: + if self._value <= 0: # pragma: nocover + raise RuntimeError("stream semaphore exhausted") + self._value -= 1 + + def acquire_nowait(self) -> bool: + if self._value <= 0: # pragma: nocover + return False + self._value -= 1 + return True + + async def release(self) -> None: + self._value += 1 + + +def use_non_blocking_semaphore(monkeypatch: pytest.MonkeyPatch) -> None: + module_name = ( + "httpcore._async.http2" + if httpcore.AsyncHTTP2Connection.__name__.startswith("Async") + else "httpcore._sync.http2" + ) + http2_module = importlib.import_module(module_name) + monkeypatch.setattr(http2_module, "AsyncSemaphore", AsyncNonBlockingSemaphore) + + +def http2_settings(max_concurrent_streams: int) -> bytes: + return hyperframe.frame.SettingsFrame( + settings={ + hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: max_concurrent_streams + } + ).serialize() + + +def http2_response_headers(stream_id: int) -> bytes: + return hyperframe.frame.HeadersFrame( + stream_id=stream_id, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize() + + @pytest.mark.anyio async def test_http2_connection(): origin = httpcore.Origin(b"https", b"example.com", 443) @@ -82,6 +134,46 @@ async def test_http2_connection_closed(): assert not conn.is_available() +@pytest.mark.anyio +async def test_http2_connection_unavailable_when_max_streams_reached(): + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream( + [ + http2_settings(max_concurrent_streams=2), + http2_response_headers(stream_id=1), + http2_response_headers(stream_id=3), + ] + ) + async with httpcore.AsyncHTTP2Connection(origin=origin, stream=stream) as conn: + async with conn.stream("GET", "https://example.com/") as response_1: + assert response_1.status == 200 + assert conn.is_available() + + async with conn.stream("GET", "https://example.com/") as response_2: + assert response_2.status == 200 + assert conn.max_concurrent_requests() == 2 + assert not conn.is_available() + with pytest.raises(httpcore.ConnectionNotAvailable): + await conn.request("GET", "https://example.com/") + + +@pytest.mark.anyio +async def test_http2_connection_closes_when_idle_stream_is_reset(): + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream( + [ + http2_settings(max_concurrent_streams=2), + http2_response_headers(stream_id=1), + ] + ) + async with httpcore.AsyncHTTP2Connection(origin=origin, stream=stream) as conn: + async with conn.stream("GET", "https://example.com/") as response: + assert response.status == 200 + assert conn.is_available() + + assert conn.is_closed() + + @pytest.mark.anyio async def test_http2_connection_post_request(): origin = httpcore.Origin(b"https", b"example.com", 443) @@ -380,3 +472,40 @@ async def test_http2_remote_max_streams_update(): conn._h2_state.local_settings.max_concurrent_streams, ) i += 1 + + +@pytest.mark.anyio +async def test_http2_remote_max_streams_update_below_active_streams(monkeypatch): + use_non_blocking_semaphore(monkeypatch) + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.AsyncMockStream( + [ + http2_settings(max_concurrent_streams=3), + http2_response_headers(stream_id=1), + http2_response_headers(stream_id=3), + http2_response_headers(stream_id=5), + hyperframe.frame.DataFrame(stream_id=5, data=b"Hello").serialize(), + http2_settings(max_concurrent_streams=2), + hyperframe.frame.DataFrame( + stream_id=5, data=b", world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + async with httpcore.AsyncHTTP2Connection(origin=origin, stream=stream) as conn: + async with conn.stream("GET", "https://example.com/") as response_1: + assert response_1.status == 200 + async with conn.stream("GET", "https://example.com/") as response_2: + assert response_2.status == 200 + async with conn.stream("GET", "https://example.com/") as response_3: + assert response_3.status == 200 + assert not conn.is_available() + assert [part async for part in response_3.aiter_stream()] == [ + b"Hello", + b", world!", + ] + assert conn.max_concurrent_requests() == 2 + assert not conn.is_available() + + assert not conn.is_available() + + assert conn.is_available() diff --git a/tests/_sync/test_connection_pool.py b/tests/_sync/test_connection_pool.py index 7adc3f5c8..6ad8390b7 100644 --- a/tests/_sync/test_connection_pool.py +++ b/tests/_sync/test_connection_pool.py @@ -1,3 +1,4 @@ +import importlib import logging import typing @@ -9,6 +10,56 @@ import httpcore +class NonBlockingSemaphore: + def __init__(self, bound: int) -> None: + self._value = bound + + def acquire(self) -> None: + if self._value <= 0: # pragma: nocover + raise RuntimeError("stream semaphore exhausted") + self._value -= 1 + + def acquire_nowait(self) -> bool: + if self._value <= 0: # pragma: nocover + return False + self._value -= 1 + return True + + def release(self) -> None: + self._value += 1 + + +def http2_settings(max_concurrent_streams: int) -> bytes: + return hyperframe.frame.SettingsFrame( + settings={ + hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: max_concurrent_streams + } + ).serialize() + + +def http2_response_headers(stream_id: int) -> bytes: + return hyperframe.frame.HeadersFrame( + stream_id=stream_id, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize() + + +def use_non_blocking_semaphore(monkeypatch: pytest.MonkeyPatch) -> None: + module_name = ( + "httpcore._async.http2" + if httpcore.HTTP2Connection.__name__.startswith("Async") + else "httpcore._sync.http2" + ) + http2_module = importlib.import_module(module_name) + monkeypatch.setattr(http2_module, "Semaphore", NonBlockingSemaphore) + + def test_connection_pool_with_keepalive(): """ @@ -205,6 +256,132 @@ def test_connection_pool_with_http2(): +def test_connection_pool_with_http2_connecting_connection_capacity_update(): + class CapacityKnownConnection: + def can_handle_request( + self, origin: httpcore.Origin + ) -> bool: # pragma: nocover + return True + + def is_available(self) -> bool: + return True + + def max_concurrent_requests(self) -> int: + return 2 + + def has_expired(self) -> bool: + return False + + def is_idle(self) -> bool: + return False + + def is_closed(self) -> bool: + return False + + def close(self) -> None: + pass + + def info(self) -> str: # pragma: nocover + return "HTTP/2, ACTIVE" + + class QueuedPoolRequest: + def __init__(self) -> None: + self.request = httpcore.Request("GET", "https://example.com/") + self.connection: typing.Any = None + + def is_queued(self) -> bool: + return self.connection is None + + def assign_to_connection(self, connection: typing.Any) -> None: + self.connection = connection + + with httpcore.ConnectionPool( + http2=True, + max_connections=1, + network_backend=httpcore.MockBackend([], http2=True), + ) as pool: + pool_request_1 = QueuedPoolRequest() + pool_request_2 = QueuedPoolRequest() + pool._requests = typing.cast(typing.Any, [pool_request_1, pool_request_2]) + + assert pool._assign_requests_to_connections() == [] + assert len(pool.connections) == 1 + assert pool_request_1.connection is not None + assert pool_request_2.connection is None + assert pool_request_1.connection.max_concurrent_requests() == 1 + + typing.cast( + typing.Any, pool_request_1.connection + )._connection = CapacityKnownConnection() + pool._connection_capacity_updated() + assert pool_request_2.connection is pool_request_1.connection + + + +def test_connection_pool_with_http2_stream_limit_opens_extra_connection( + monkeypatch, +): + use_non_blocking_semaphore(monkeypatch) + network_backend = httpcore.MockBackend( + buffer=[ + http2_settings(max_concurrent_streams=2), + http2_response_headers(stream_id=1), + http2_response_headers(stream_id=3), + http2_response_headers(stream_id=5), + ], + http2=True, + ) + + with httpcore.ConnectionPool( + network_backend=network_backend, + max_connections=3, + ) as pool: + with pool.stream("GET", "https://example.com/") as response_1: + assert response_1.status == 200 + with pool.stream("GET", "https://example.com/") as response_2: + assert response_2.status == 200 + assert len(pool.connections) == 1 + with pool.stream("GET", "https://example.com/") as response_3: + assert response_3.status == 200 + assert len(pool.connections) == 2 + + + +def test_connection_pool_with_http2_stream_limit_queues(monkeypatch): + use_non_blocking_semaphore(monkeypatch) + network_backend = httpcore.MockBackend( + buffer=[ + http2_settings(max_concurrent_streams=2), + http2_response_headers(stream_id=1), + http2_response_headers(stream_id=3), + http2_response_headers(stream_id=5), + ], + http2=True, + ) + + with httpcore.ConnectionPool( + network_backend=network_backend, + max_connections=1, + ) as pool: + with pool.stream("GET", "https://example.com/") as response_1: + assert response_1.status == 200 + with pool.stream("GET", "https://example.com/") as response_2: + assert response_2.status == 200 + + with pytest.raises(httpcore.PoolTimeout): + pool.request( + "GET", + "https://example.com/", + extensions={"timeout": {"pool": 0}}, + ) + assert len(pool._requests) == 2 + + with pool.stream("GET", "https://example.com/") as response_3: + assert response_3.status == 200 + assert len(pool.connections) == 1 + + + def test_connection_pool_with_http2_goaway(): """ Test a connection pool with HTTP/2 requests, that cleanly disconnects diff --git a/tests/_sync/test_http2.py b/tests/_sync/test_http2.py index 695359bd6..38768a642 100644 --- a/tests/_sync/test_http2.py +++ b/tests/_sync/test_http2.py @@ -1,3 +1,5 @@ +import importlib + import hpack import hyperframe.frame import pytest @@ -5,6 +7,56 @@ import httpcore +class NonBlockingSemaphore: + def __init__(self, bound: int) -> None: + self._value = bound + + def acquire(self) -> None: + if self._value <= 0: # pragma: nocover + raise RuntimeError("stream semaphore exhausted") + self._value -= 1 + + def acquire_nowait(self) -> bool: + if self._value <= 0: # pragma: nocover + return False + self._value -= 1 + return True + + def release(self) -> None: + self._value += 1 + + +def use_non_blocking_semaphore(monkeypatch: pytest.MonkeyPatch) -> None: + module_name = ( + "httpcore._async.http2" + if httpcore.HTTP2Connection.__name__.startswith("Async") + else "httpcore._sync.http2" + ) + http2_module = importlib.import_module(module_name) + monkeypatch.setattr(http2_module, "Semaphore", NonBlockingSemaphore) + + +def http2_settings(max_concurrent_streams: int) -> bytes: + return hyperframe.frame.SettingsFrame( + settings={ + hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: max_concurrent_streams + } + ).serialize() + + +def http2_response_headers(stream_id: int) -> bytes: + return hyperframe.frame.HeadersFrame( + stream_id=stream_id, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize() + + def test_http2_connection(): origin = httpcore.Origin(b"https", b"example.com", 443) @@ -83,6 +135,46 @@ def test_http2_connection_closed(): +def test_http2_connection_unavailable_when_max_streams_reached(): + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream( + [ + http2_settings(max_concurrent_streams=2), + http2_response_headers(stream_id=1), + http2_response_headers(stream_id=3), + ] + ) + with httpcore.HTTP2Connection(origin=origin, stream=stream) as conn: + with conn.stream("GET", "https://example.com/") as response_1: + assert response_1.status == 200 + assert conn.is_available() + + with conn.stream("GET", "https://example.com/") as response_2: + assert response_2.status == 200 + assert conn.max_concurrent_requests() == 2 + assert not conn.is_available() + with pytest.raises(httpcore.ConnectionNotAvailable): + conn.request("GET", "https://example.com/") + + + +def test_http2_connection_closes_when_idle_stream_is_reset(): + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream( + [ + http2_settings(max_concurrent_streams=2), + http2_response_headers(stream_id=1), + ] + ) + with httpcore.HTTP2Connection(origin=origin, stream=stream) as conn: + with conn.stream("GET", "https://example.com/") as response: + assert response.status == 200 + assert conn.is_available() + + assert conn.is_closed() + + + def test_http2_connection_post_request(): origin = httpcore.Origin(b"https", b"example.com", 443) stream = httpcore.MockStream( @@ -380,3 +472,40 @@ def test_http2_remote_max_streams_update(): conn._h2_state.local_settings.max_concurrent_streams, ) i += 1 + + + +def test_http2_remote_max_streams_update_below_active_streams(monkeypatch): + use_non_blocking_semaphore(monkeypatch) + origin = httpcore.Origin(b"https", b"example.com", 443) + stream = httpcore.MockStream( + [ + http2_settings(max_concurrent_streams=3), + http2_response_headers(stream_id=1), + http2_response_headers(stream_id=3), + http2_response_headers(stream_id=5), + hyperframe.frame.DataFrame(stream_id=5, data=b"Hello").serialize(), + http2_settings(max_concurrent_streams=2), + hyperframe.frame.DataFrame( + stream_id=5, data=b", world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + with httpcore.HTTP2Connection(origin=origin, stream=stream) as conn: + with conn.stream("GET", "https://example.com/") as response_1: + assert response_1.status == 200 + with conn.stream("GET", "https://example.com/") as response_2: + assert response_2.status == 200 + with conn.stream("GET", "https://example.com/") as response_3: + assert response_3.status == 200 + assert not conn.is_available() + assert [part for part in response_3.iter_stream()] == [ + b"Hello", + b", world!", + ] + assert conn.max_concurrent_requests() == 2 + assert not conn.is_available() + + assert not conn.is_available() + + assert conn.is_available() diff --git a/tests/test_cancellations.py b/tests/test_cancellations.py index 033acef60..ccf007b46 100644 --- a/tests/test_cancellations.py +++ b/tests/test_cancellations.py @@ -204,7 +204,7 @@ async def test_h2_timeout_during_request(): await conn.request("GET", "http://example.com") assert not conn.is_closed() - assert conn.is_idle() + assert conn.is_idle() # pragma: nocover @pytest.mark.xfail @@ -242,4 +242,4 @@ async def test_h2_timeout_during_response(): await conn.request("GET", "http://example.com") assert not conn.is_closed() - assert conn.is_idle() + assert conn.is_idle() # pragma: nocover