diff --git a/aiohttp/client.py b/aiohttp/client.py index dcbdd23dfd4..f4d19dc0179 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -65,8 +65,10 @@ SSL_ALLOWED_TYPES, ClientRequest, ClientResponse, + ClientTimeout, Fingerprint, RequestInfo, + ResponseParams, ) from .client_ws import ( DEFAULT_WS_CLIENT_TIMEOUT, @@ -87,7 +89,6 @@ BasicAuth, TimeoutHandle, basicauth_from_netrc, - frozen_dataclass_decorator, get_env_proxy_for_url, netrc_from_env, sentinel, @@ -187,28 +188,6 @@ class _RequestOptions(TypedDict, total=False): middlewares: Sequence[ClientMiddlewareType] | None -@frozen_dataclass_decorator -class ClientTimeout: - total: float | None = None - connect: float | None = None - sock_read: float | None = None - sock_connect: float | None = None - ceil_threshold: float = 5 - - # pool_queue_timeout: Optional[float] = None - # dns_resolution_timeout: Optional[float] = None - # socket_connect_timeout: Optional[float] = None - # connection_acquiring_timeout: Optional[float] = None - # new_connection_timeout: Optional[float] = None - # http_header_timeout: Optional[float] = None - # response_body_timeout: Optional[float] = None - - # to create a timeout specific for a single request, either - # - create a completely new one to overwrite the default - # - or use https://docs.python.org/3/library/dataclasses.html#dataclasses.replace - # to overwrite the defaults - - # 5 Minute default read timeout DEFAULT_TIMEOUT: Final[ClientTimeout] = ClientTimeout(total=5 * 60, sock_connect=30) @@ -631,6 +610,18 @@ async def _request( get_env_proxy_for_url, url ) + response_params: ResponseParams = { + "timer": timer, + "skip_payload": method in EMPTY_BODY_METHODS, + "read_until_eof": read_until_eof, + "auto_decompress": auto_decompress, + "read_timeout": real_timeout.sock_read, + "read_bufsize": read_bufsize, + "timeout_ceil_threshold": self._connector._timeout_ceil_threshold, + "max_line_size": max_line_size, + "max_field_size": max_field_size, + } + req = self._request_class( method, url, @@ -648,7 +639,9 @@ async def _request( response_class=self._response_class, proxy=proxy_, proxy_auth=proxy_auth, + response_params=response_params, timer=timer, + timeout=real_timeout, session=self, ssl=ssl, server_hostname=server_hostname, @@ -664,7 +657,7 @@ async def _connect_and_send_request( assert self._connector is not None try: conn = await self._connector.connect( - req, traces=traces, timeout=real_timeout + req, traces=traces, timeout=req._timeout ) except asyncio.TimeoutError as exc: raise ConnectionTimeoutError( @@ -672,17 +665,7 @@ async def _connect_and_send_request( ) from exc assert conn.protocol is not None - conn.protocol.set_response_params( - timer=timer, - skip_payload=req.method in EMPTY_BODY_METHODS, - read_until_eof=read_until_eof, - auto_decompress=auto_decompress, - read_timeout=real_timeout.sock_read, - read_bufsize=read_bufsize, - timeout_ceil_threshold=self._connector._timeout_ceil_threshold, - max_line_size=max_line_size, - max_field_size=max_field_size, - ) + conn.protocol.set_response_params(**req._response_params) try: resp = await req._send(conn) try: diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 0d6f435b6e5..310949d3bbf 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -84,6 +84,28 @@ _CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]") +@frozen_dataclass_decorator +class ClientTimeout: + total: float | None = None + connect: float | None = None + sock_read: float | None = None + sock_connect: float | None = None + ceil_threshold: float = 5 + + # pool_queue_timeout: Optional[float] = None + # dns_resolution_timeout: Optional[float] = None + # socket_connect_timeout: Optional[float] = None + # connection_acquiring_timeout: Optional[float] = None + # new_connection_timeout: Optional[float] = None + # http_header_timeout: Optional[float] = None + # response_body_timeout: Optional[float] = None + + # to create a timeout specific for a single request, either + # - create a completely new one to overwrite the default + # - or use https://docs.python.org/3/library/dataclasses.html#dataclasses.replace + # to overwrite the defaults + + def _gen_default_accept_encoding() -> str: encodings = [ "gzip", @@ -185,6 +207,18 @@ class ConnectionKey(NamedTuple): proxy_headers_hash: int | None # hash(CIMultiDict) +class ResponseParams(TypedDict): + timer: BaseTimerContext | None + skip_payload: bool + read_until_eof: bool + auto_decompress: bool + read_timeout: float | None + read_bufsize: int + timeout_ceil_threshold: float + max_line_size: int + max_field_size: int + + class ClientResponse(HeadersMixin): # Some of these attributes are None when created, # but will be set by the start() method. @@ -946,7 +980,9 @@ class ClientRequestArgs(TypedDict, total=False): response_class: type[ClientResponse] proxy: URL | None proxy_auth: BasicAuth | None + response_params: ResponseParams timer: BaseTimerContext + timeout: ClientTimeout session: "ClientSession" ssl: SSLContext | bool | Fingerprint proxy_headers: CIMultiDict[str] | None @@ -959,6 +995,8 @@ class ClientRequest(ClientRequestBase): _EMPTY_BODY = payload.PAYLOAD_REGISTRY.get(b"", disposition=None) _body = _EMPTY_BODY _continue = None # waiter future for '100 Continue' response + _response_params: ResponseParams = None # type: ignore[assignment] + _timeout = ClientTimeout() GET_METHODS = { hdrs.METH_GET, @@ -990,7 +1028,9 @@ def __init__( response_class: type[ClientResponse], proxy: URL | None, proxy_auth: BasicAuth | None, + response_params: ResponseParams, timer: BaseTimerContext, + timeout: ClientTimeout, session: "ClientSession", ssl: SSLContext | bool | Fingerprint, proxy_headers: CIMultiDict[str] | None, @@ -1014,7 +1054,9 @@ def __init__( self._session = session self.chunked = chunked self.response_class = response_class + self._response_params = response_params self._timer = timer + self._timeout = timeout self.server_hostname = server_hostname self.version = version diff --git a/tests/conftest.py b/tests/conftest.py index e5dc79cad4d..3f68bdea399 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,7 +27,7 @@ HAS_BLOCKBUSTER = False from aiohttp import payload -from aiohttp.client import ClientSession +from aiohttp.client import ClientSession, ClientTimeout from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientRequest, ClientRequestArgs, ClientResponse from aiohttp.compression_utils import ZLibBackend, ZLibBackendProtocol, set_zlib_backend @@ -418,6 +418,8 @@ def maker( ) -> ClientRequest: nonlocal request, session session = ClientSession() + timer = TimerNoop() + timeout = ClientTimeout() default_args: ClientRequestArgs = { "loop": loop, "params": {}, @@ -433,7 +435,19 @@ def maker( "response_class": ClientResponse, "proxy": None, "proxy_auth": None, - "timer": TimerNoop(), + "response_params": { + "timer": timer, + "skip_payload": True, + "read_until_eof": True, + "auto_decompress": True, + "read_timeout": timeout.sock_read, + "read_bufsize": 2**16, + "timeout_ceil_threshold": 5, + "max_line_size": 8190, + "max_field_size": 8190, + }, + "timer": timer, + "timeout": timeout, "session": session, "ssl": True, "proxy_headers": None, diff --git a/tests/test_benchmarks_client_request.py b/tests/test_benchmarks_client_request.py index 0a37087380f..af14c2550da 100644 --- a/tests/test_benchmarks_client_request.py +++ b/tests/test_benchmarks_client_request.py @@ -10,7 +10,13 @@ from pytest_codspeed import BenchmarkFixture from yarl import URL -from aiohttp.client_reqrep import ClientRequest, ClientRequestArgs, ClientResponse +from aiohttp.client_reqrep import ( + ClientRequest, + ClientRequestArgs, + ClientResponse, + ClientTimeout, + ResponseParams, +) from aiohttp.cookiejar import CookieJar from aiohttp.helpers import TimerNoop from aiohttp.http_writer import HttpVersion11 @@ -50,8 +56,20 @@ def test_create_client_request_with_cookies( cookies = cookie_jar.filter_cookies(url) assert cookies["cookie"].value == "value" timer = TimerNoop() + timeout = ClientTimeout() traces: list[Trace] = [] headers = CIMultiDict[str]() + response_params: ResponseParams = { + "timer": timer, + "skip_payload": True, + "read_until_eof": True, + "auto_decompress": True, + "read_timeout": timeout.sock_read, + "read_bufsize": 2**16, + "timeout_ceil_threshold": 5, + "max_line_size": 8190, + "max_field_size": 8190, + } @benchmark def _run() -> None: @@ -65,7 +83,9 @@ def _run() -> None: proxy=None, proxy_auth=None, proxy_headers=None, + response_params=response_params, timer=timer, + timeout=timeout, session=None, # type: ignore[arg-type] ssl=True, traces=traces, @@ -88,9 +108,21 @@ def test_create_client_request_with_headers( ) -> None: url = URL("http://python.org") timer = TimerNoop() + timeout = ClientTimeout() traces: list[Trace] = [] headers = CIMultiDict({"header": "value", "another": "header"}) cookies = BaseCookie[str]() + response_params: ResponseParams = { + "timer": timer, + "skip_payload": True, + "read_until_eof": True, + "auto_decompress": True, + "read_timeout": timeout.sock_read, + "read_bufsize": 2**16, + "timeout_ceil_threshold": 5, + "max_line_size": 8190, + "max_field_size": 8190, + } @benchmark def _run() -> None: @@ -104,7 +136,9 @@ def _run() -> None: proxy=None, proxy_auth=None, proxy_headers=None, + response_params=response_params, timer=timer, + timeout=timeout, session=None, # type: ignore[arg-type] ssl=True, traces=traces, diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 74670cbc9f7..6937519449d 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -21,6 +21,7 @@ ClientRequest, ClientRequestArgs, ClientResponse, + ClientTimeout, Fingerprint, _gen_default_accept_encoding, ) @@ -1615,6 +1616,8 @@ def test_terminate_with_closed_loop( async def go() -> None: nonlocal req, resp, writer # Can't use make_client_request here, due to closing the loop mid-test. + timer = TimerNoop() + timeout = ClientTimeout() req = ClientRequest( "get", URL("http://python.org"), @@ -1632,7 +1635,19 @@ async def go() -> None: response_class=ClientResponse, proxy=None, proxy_auth=None, - timer=TimerNoop(), + response_params={ + "timer": timer, + "skip_payload": True, + "read_until_eof": True, + "auto_decompress": True, + "read_timeout": timeout.sock_read, + "read_bufsize": 2**16, + "timeout_ceil_threshold": 5, + "max_line_size": 8190, + "max_field_size": 8190, + }, + timer=timer, + timeout=timeout, session=None, # type: ignore[arg-type] ssl=True, proxy_headers=None, diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 21057d3fbb5..cda40a62baf 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -21,7 +21,7 @@ from aiohttp import abc, client, hdrs, tracing, web from aiohttp.client import ClientSession from aiohttp.client_proto import ResponseHandler -from aiohttp.client_reqrep import ClientRequest, ConnectionKey +from aiohttp.client_reqrep import ClientRequest, ClientTimeout, ConnectionKey from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector from aiohttp.cookiejar import CookieJar from aiohttp.http import RawResponseMessage @@ -560,6 +560,7 @@ async def test_reraise_os_error( req._send = mock.AsyncMock(side_effect=err) req._body = mock.Mock() req._body.close = mock.AsyncMock() + req._timeout = ClientTimeout() session = await create_session(request_class=req_factory) async def create_connection( @@ -592,6 +593,7 @@ class UnexpectedException(BaseException): req._send = mock.AsyncMock(side_effect=err) req._body = mock.Mock() req._body.close = mock.AsyncMock() + req._timeout = ClientTimeout() session = await create_session(request_class=req_factory) connections = [] @@ -652,6 +654,7 @@ async def test_ws_connect_allowed_protocols( # type: ignore[misc] req._body = None # No body for WebSocket upgrade requests req_factory = mock.Mock(return_value=req) req._send = mock.AsyncMock(return_value=resp) + req._timeout = ClientTimeout() # BaseConnector allows all high level protocols by default connector = BaseConnector() @@ -715,6 +718,7 @@ async def test_ws_connect_unix_socket_allowed_protocols( # type: ignore[misc] req._body = None # No body for WebSocket upgrade requests req_factory = mock.Mock(return_value=req) req._send = mock.AsyncMock(return_value=resp) + req._timeout = ClientTimeout() # UnixConnector allows all high level protocols by default and unix sockets session = await create_session( connector=UnixConnector(path=""), request_class=req_factory