diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md index f15c1cda21..5add2f4f99 100644 --- a/python/CHANGELOG.md +++ b/python/CHANGELOG.md @@ -37,6 +37,9 @@ We structure this changelog in accordance with [Keep a Changelog](https://keepac - ETL `FastAPIServer`: stream no-FQN `hpush` PUTs in constant memory, mirroring the `hpull` GET change. Transient direct-put errors surface to AIS to retry the whole PUT (request body is one-shot). +- ETL `HTTPMultiThreadedServer`: stream no-FQN `hpush` PUTs in constant + memory. Previously the full request body was read into a `BytesIO` + before being handed to `transform_stream`. Local direct-put retries are skipped on this path because the request body is one-shot; a transient direct-put error surfaces to AIS as a transform failure. - **ETL direct-put retry**: added exponential-backoff retry for transient connection errors in Flask and HTTP multi-threaded ETL servers for parity with FastAPI. `ConnectionRefused` is now treated as a permanent error that returns HTTP 502 diff --git a/python/aistore/sdk/etl/webserver/base_etl_server.py b/python/aistore/sdk/etl/webserver/base_etl_server.py index 3628f99e27..4b0de3b593 100644 --- a/python/aistore/sdk/etl/webserver/base_etl_server.py +++ b/python/aistore/sdk/etl/webserver/base_etl_server.py @@ -85,6 +85,32 @@ def _handle_direct_put_transient_error( raise ETLDirectPutTransientError(direct_put_url, exc) from exc +def _compute_replayable_retries( + fqn: str, is_get: bool, direct_put_retries: int +) -> Tuple[bool, int]: + """ + Determine whether the streaming source is replayable and the effective + retry budget. + + Sources backed by a local FQN file or a GET stream can be reopened on + retry. No-FQN PUT bodies are one-shot (consumed from the request socket) + and cannot be replayed locally, so the retry budget is forced to zero. + + Args: + fqn: Local FQN of the source object; empty for streaming PUT. + is_get: ``True`` for hpull GET, ``False`` for hpush PUT. + direct_put_retries: Configured retry count. 
+ + Returns: + ``(replayable, effective_retries)`` — `replayable` is ``True`` when + the source can be reopened; `effective_retries` equals + `direct_put_retries` when replayable, else ``0``. + """ + replayable = bool(fqn) or is_get + effective_retries = direct_put_retries if replayable else 0 + return replayable, effective_retries + + class ETLServer(ABC): # pylint: disable=too-many-instance-attributes """ Abstract base class for all ETL servers. diff --git a/python/aistore/sdk/etl/webserver/fastapi_server.py b/python/aistore/sdk/etl/webserver/fastapi_server.py index e9000dcf94..8a4b20263c 100644 --- a/python/aistore/sdk/etl/webserver/fastapi_server.py +++ b/python/aistore/sdk/etl/webserver/fastapi_server.py @@ -25,6 +25,7 @@ CountingIterator, RETRY_BACKOFF_BASE, RETRY_BACKOFF_MAX, + _compute_replayable_retries, ) from aistore.sdk.session_manager import resolve_ssl_config from aistore.sdk.etl.webserver.fastapi_streaming import ( @@ -375,12 +376,14 @@ async def _direct_put_stream_with_retry( # pylint: disable=too-many-arguments,t Raises: ETLDirectPutTransientError: if all retry attempts are exhausted. 
""" - replayable = bool(fqn) or is_get - effective_retries = self.direct_put_retries if replayable else 0 + replayable, effective_retries = _compute_replayable_retries( + fqn, is_get, self.direct_put_retries + ) if not replayable and self.direct_put_retries: self.logger.debug( "no-FQN PUT: source not replayable; " - "local retries skipped, AIS will retry" + "local retries skipped; transient direct-put error " + "will surface as transform failure" ) reader = await self._get_stream_reader(fqn, path, request, is_get) diff --git a/python/aistore/sdk/etl/webserver/http_multi_threaded_server.py b/python/aistore/sdk/etl/webserver/http_multi_threaded_server.py index ef019d1394..cdec496feb 100644 --- a/python/aistore/sdk/etl/webserver/http_multi_threaded_server.py +++ b/python/aistore/sdk/etl/webserver/http_multi_threaded_server.py @@ -2,11 +2,11 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # +import io import time from http.server import HTTPServer, BaseHTTPRequestHandler -from io import BytesIO from socketserver import ThreadingMixIn -from typing import Iterator, Type, Tuple +from typing import BinaryIO, Iterator, Type, Tuple import signal import threading from urllib.parse import urlparse, parse_qs @@ -20,6 +20,7 @@ RETRY_BACKOFF_BASE, RETRY_BACKOFF_MAX, _handle_direct_put_transient_error, + _compute_replayable_retries, ) from aistore.sdk.etl.webserver.utils import ( compose_etl_direct_put_url, @@ -40,6 +41,57 @@ ) +class _RFileLimitedReader(io.RawIOBase): + """Bound `BaseHTTPRequestHandler.rfile` to the current PUT body length. + + `self.rfile` is the raw connection stream; it has no intrinsic EOF at the + end of this request body. Passing it directly to `transform_stream` would + cause any transform that calls `reader.read()` with no size argument to + block indefinitely waiting for the client to close the connection. 
+ + This wrapper tracks `Content-Length` remaining bytes and clamps every + `read()` call accordingly, giving transforms the same EOF semantics they + get from a `BytesIO` — without buffering the full body upfront. + + The request body is one-shot; `_direct_put_stream_with_retry` sets + `effective_retries=0` on this path. `close()` drains any unread bytes + from the request body so a transform that exits early does not leave + residual data on a keep-alive connection. + """ + + def __init__(self, rfile: BinaryIO, content_length: int) -> None: + self._rfile = rfile + self._remaining = content_length + + def readable(self) -> bool: + return True + + def read(self, size: int = -1) -> bytes: + if self._remaining == 0: + return b"" + if size is None or size < 0: + data = self._rfile.read(self._remaining) + self._remaining = 0 + return data + to_read = min(size, self._remaining) + data = self._rfile.read(to_read) + self._remaining -= len(data) + return data + + def close(self) -> None: + try: + while self._remaining > 0: + try: + chunk = self._rfile.read(min(self._remaining, 65536)) + except (OSError, ValueError): + break + if not chunk: + break + self._remaining -= len(chunk) + finally: + super().close() + + class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): """ Multi-threaded HTTP server that delegates ETL logic to a provided ETLServer instance. @@ -161,11 +213,10 @@ def _get_stream_reader(self, fqn, raw_path, is_get): resp.close() raise return _ResponseRawReader(resp) - # TODO: non-FQN PUT still buffers the full request body into BytesIO so - # retries can replay it; true streaming for this path still needs to be - # implemented. + # Request body is one-shot; local retries are skipped for this path + # (see _direct_put_stream_with_retry). 
content_length = int(self.headers.get(HEADER_CONTENT_LENGTH, 0)) - return BytesIO(self.rfile.read(content_length)) + return _RFileLimitedReader(self.rfile, content_length) def _direct_put_with_retry( self, @@ -208,15 +259,18 @@ def _direct_put_stream_with_retry( # pylint: disable=too-many-arguments,too-man """ Streaming direct-put with exponential-backoff retry on transient errors. - For FQN and GET sources, the reader is closed and reopened on each retry. - For PUT requests without FQN, the body is buffered in a BytesIO by - _get_stream_reader and seeked back to 0 on retry so the same bytes are - replayed (mirrors FastAPIServer._direct_put_stream_with_retry). + Replayable sources (FQN-backed or GET) close and reopen the reader on + each retry. No-FQN PUT bodies are one-shot (request body is consumed + from the socket); effective_retries is forced to 0 and a transient + direct-put error surfaces to AIS as a transform failure. """ etl = self.server.etl_server + replayable, effective_retries = _compute_replayable_retries( + fqn, is_get, etl.direct_put_retries + ) reader = self._get_stream_reader(fqn, raw_path, is_get) try: - for attempt in range(etl.direct_put_retries + 1): + for attempt in range(effective_retries + 1): try: return self._direct_put_stream( direct_put_url, @@ -225,22 +279,25 @@ def _direct_put_stream_with_retry( # pylint: disable=too-many-arguments,too-man raw_path, ) except ETLDirectPutTransientError as exc: - if attempt >= etl.direct_put_retries: + if attempt >= effective_retries: + if not replayable and etl.direct_put_retries: + etl.logger.debug( + "no-FQN PUT: source not replayable; " + "local retries skipped; transient direct-put error " + "will surface as transform failure" + ) raise delay = min(RETRY_BACKOFF_BASE**attempt, RETRY_BACKOFF_MAX) etl.logger.warning( "direct_put_stream attempt %d/%d failed, retrying in %.1fs: %s", attempt + 1, - etl.direct_put_retries + 1, + effective_retries + 1, delay, exc, exc_info=True, ) - if isinstance(reader, 
BytesIO): - reader.seek(0) - else: - etl.close_reader(reader) - reader = self._get_stream_reader(fqn, raw_path, is_get) + etl.close_reader(reader) + reader = self._get_stream_reader(fqn, raw_path, is_get) time.sleep(delay) finally: etl.close_reader(reader) diff --git a/python/tests/unit/sdk/test_etl_webserver.py b/python/tests/unit/sdk/test_etl_webserver.py index 0cd5141596..57fe0359ac 100644 --- a/python/tests/unit/sdk/test_etl_webserver.py +++ b/python/tests/unit/sdk/test_etl_webserver.py @@ -41,7 +41,10 @@ _is_connection_refused, ) from aistore.sdk.errors import ETLDirectPutTransientError -from aistore.sdk.etl.webserver.http_multi_threaded_server import HTTPMultiThreadedServer +from aistore.sdk.etl.webserver.http_multi_threaded_server import ( + HTTPMultiThreadedServer, + _RFileLimitedReader, +) from aistore.sdk.etl.webserver.flask_server import FlaskServer from aistore.sdk.etl.webserver.fastapi_server import FastAPIServer from aistore.sdk.etl.webserver.fastapi_streaming import _RequestStreamReader @@ -1489,6 +1492,74 @@ def test_streaming_get(self): handler.server.etl_server.transform_stream.assert_called_once() handler.server.etl_server.close_reader.assert_called_once() + def test_streaming_put_does_not_prebuffer_body(self): + """_get_stream_reader returns _RFileLimitedReader and does not pre-read rfile.""" + handler = DummyRequestHandler() + handler.headers = {HEADER_CONTENT_LENGTH: "10"} + rfile = MagicMock(wraps=io.BytesIO(b"AAAAAAAAAA")) + handler.rfile = rfile + + reader = handler._get_stream_reader( # pylint: disable=protected-access + fqn="", raw_path="/test/obj", is_get=False + ) + + self.assertIsInstance(reader, _RFileLimitedReader) + rfile.read.assert_not_called() + + def test_streaming_put_reader_bounds_to_content_length(self): + """_RFileLimitedReader.read(n) yields at most content_length bytes.""" + rfile = io.BytesIO(b"AAAABBBB") + reader = _RFileLimitedReader(rfile, 4) + self.assertEqual(reader.read(8), b"AAAA") + 
self.assertEqual(reader.read(8), b"") + + def test_streaming_put_reader_content_length_zero(self): + """_RFileLimitedReader returns b'' immediately when content_length is 0.""" + rfile = MagicMock() + reader = _RFileLimitedReader(rfile, 0) + self.assertEqual(reader.read(100), b"") + rfile.read.assert_not_called() + + def test_streaming_put_reader_full_read(self): + """_RFileLimitedReader.read() with no size returns all remaining bytes.""" + rfile = io.BytesIO(b"hello world") + reader = _RFileLimitedReader(rfile, 5) + self.assertEqual(reader.read(), b"hello") + self.assertEqual(reader.read(), b"") + + def test_close_drains_remaining_bytes(self): + """close() reads all remaining bytes from rfile when transform exits early.""" + rfile = io.BytesIO(b"AAAABBBB") + reader = _RFileLimitedReader(rfile, 8) + reader.read(4) # consume half; 4 bytes remain + reader.close() + self.assertEqual(rfile.tell(), 8) + + def test_close_handles_connection_close_mid_drain(self): + """close() stops cleanly when rfile is exhausted or raises before _remaining hits zero.""" + # Case 1: rfile returns empty bytes (short read) + rfile = io.BytesIO(b"ABCDE") # only 5 bytes, but content_length=10 + reader = _RFileLimitedReader(rfile, 10) + reader.close() # should not raise or loop forever + self.assertTrue(reader.closed) + self.assertEqual(rfile.tell(), 5) + + # Case 2: rfile is already closed (e.g. 
socket reset or GC after rfile closes) + rfile2 = io.BytesIO(b"ABCDE") + reader2 = _RFileLimitedReader(rfile2, 10) + rfile2.close() + reader2.close() # must not raise ValueError + self.assertTrue(reader2.closed) + + def test_close_no_op_drain_when_already_consumed(self): + """close() does not perform extra reads when all bytes have been read.""" + rfile = MagicMock(wraps=io.BytesIO(b"ABCD")) + reader = _RFileLimitedReader(rfile, 4) + reader.read() # drain fully + rfile.read.reset_mock() + reader.close() + rfile.read.assert_not_called() + # --------------------------------------------------------------------------- # HTTP streaming lifecycle: connection release, status forwarding, 502 mapping @@ -2325,7 +2396,7 @@ def test_succeeds_on_first_attempt(self): mock_sleep.assert_not_called() def test_retries_on_transient_error_then_succeeds(self): - """Retries on ETLDirectPutTransientError and succeeds on second attempt.""" + """Replayable GET path retries on transient error and succeeds on second attempt.""" self.handler.server.etl_server.direct_put_retries = 2 ok = (204, b"", 4) call_count = 0 @@ -2346,11 +2417,11 @@ def put_side(*_args, **_kwargs): ): with patch.object(self.handler, "_direct_put_stream", side_effect=put_side): with patch("time.sleep"): - result = self._call() + result = self._call(is_get=True) self.assertEqual(result, ok) def test_raises_after_exhausting_retries(self): - """ETLDirectPutTransientError raised after all retries exhausted.""" + """Replayable GET path: ETLDirectPutTransientError raised after all retries exhausted.""" self.handler.server.etl_server.direct_put_retries = 2 err = ETLDirectPutTransientError( self._DIRECT_PUT_URL, requests.ConnectionError() @@ -2363,7 +2434,7 @@ def test_raises_after_exhausting_retries(self): with patch.object(self.handler, "_direct_put_stream", side_effect=err): with patch("time.sleep"): with self.assertRaises(ETLDirectPutTransientError): - self._call() + self._call(is_get=True) def 
test_upstream_error_propagates_from_direct_put(self): """When _get_stream_reader raises HTTPError (upstream non-200), it propagates.""" @@ -2406,35 +2477,6 @@ def reader_side_effect(*_): with self.assertRaises(requests.HTTPError): self._call(is_get=True) - def test_reader_reopened_on_each_retry(self): - """For BytesIO readers (PUT-no-FQN), seek(0) is used on retry instead of - reopening; _get_stream_reader is called only once (initial open).""" - self.handler.server.etl_server.direct_put_retries = 2 - call_count = 0 - - def put_side(*_args, **_kwargs): - nonlocal call_count - call_count += 1 - if call_count < 3: - raise ETLDirectPutTransientError( - self._DIRECT_PUT_URL, requests.ConnectionError() - ) - return (204, b"", 0) - - readers = [self._make_reader() for _ in range(3)] - reader_iter = iter(readers) - with patch.object( - self.handler, - "_get_stream_reader", - side_effect=lambda *_: next(reader_iter), - ) as mock_reader: - with patch.object(self.handler, "_direct_put_stream", side_effect=put_side): - with patch("time.sleep"): - self._call() - # BytesIO readers are rewound with seek(0) on retry; _get_stream_reader - # is called only once (initial open), not once per attempt. 
- self.assertEqual(mock_reader.call_count, 1) - def test_reader_always_closed_on_success(self): """close_reader called once (via finally) after success.""" mock_reader = self._make_reader() @@ -2446,8 +2488,8 @@ def test_reader_always_closed_on_success(self): self.handler.server.etl_server.close_reader.assert_called_once_with(mock_reader) def test_reader_always_closed_on_exhausted_retries(self): - """For BytesIO readers (PUT-no-FQN), seek(0) is used on retry instead of - close+reopen; close_reader is called once in the finally block.""" + """Replayable GET path: each failed attempt closes the reader before reopening, + and the final raise closes the reader once more via the finally block.""" self.handler.server.etl_server.direct_put_retries = 2 err = ETLDirectPutTransientError( self._DIRECT_PUT_URL, requests.ConnectionError() @@ -2460,12 +2502,12 @@ def test_reader_always_closed_on_exhausted_retries(self): with patch.object(self.handler, "_direct_put_stream", side_effect=err): with patch("time.sleep"): with self.assertRaises(ETLDirectPutTransientError): - self._call() - # BytesIO: no close in except (seek instead); 1 close in finally = 1 total - self.assertEqual(self.handler.server.etl_server.close_reader.call_count, 1) + self._call(is_get=True) + # 2 closes in except (between attempts) + 1 close in finally = 3 total + self.assertEqual(self.handler.server.etl_server.close_reader.call_count, 3) def test_exponential_backoff_delays(self): - """time.sleep is called with exponentially increasing delays.""" + """Replayable GET path: time.sleep is called with exponentially increasing delays.""" self.handler.server.etl_server.direct_put_retries = 3 call_count = 0 @@ -2485,7 +2527,7 @@ def put_side(*_args, **_kwargs): ): with patch.object(self.handler, "_direct_put_stream", side_effect=put_side): with patch("time.sleep") as mock_sleep: - self._call() + self._call(is_get=True) delays = [c.args[0] for c in mock_sleep.call_args_list] self.assertEqual(len(delays), 3) @@ 
-2493,54 +2535,81 @@ def put_side(*_args, **_kwargs): self.assertEqual(delays[1], min(2.0**1, 30.0)) self.assertEqual(delays[2], min(2.0**2, 30.0)) - def test_put_no_fqn_retry_sends_original_bytes(self): - """Regression: HTTP streaming PUT without FQN replays original body on retry. - - Before the fix, calling _get_stream_reader again on retry would drain rfile - a second time and return an empty BytesIO. After the fix, seek(0) is called on - the existing BytesIO so the retry reads the same original bytes. - """ - original_body = b"hello http regression bytes" - bytes_read_by_transform = [] - - def spy_transform_stream(reader, _path, _etl_args): - data = reader.read() - bytes_read_by_transform.append(data) - yield data + def test_no_fqn_put_skips_local_retry(self): + """No-FQN PUT body is one-shot: a transient error raises immediately.""" + self.handler.server.etl_server.direct_put_retries = 3 + err = ETLDirectPutTransientError( + self._DIRECT_PUT_URL, requests.ConnectionError() + ) + with patch.object( + self.handler, "_get_stream_reader", return_value=self._make_reader() + ): + with patch.object( + self.handler, "_direct_put_stream", side_effect=err + ) as mock_put: + with patch("time.sleep") as mock_sleep: + with self.assertRaises(ETLDirectPutTransientError): + self._call(fqn="", is_get=False) + self.assertEqual(mock_put.call_count, 1) + mock_sleep.assert_not_called() + def test_fqn_put_still_retries(self): + """FQN-backed PUT is replayable; retries proceed normally.""" + self.handler.server.etl_server.direct_put_retries = 2 + ok = (204, b"", 4) call_count = 0 - def mock_direct_put_stream(url, data_iter, *_a, **_kw): + def put_side(*_args, **_kwargs): nonlocal call_count call_count += 1 - # Consume the generator so transform_stream reads from reader. 
- b"".join(data_iter) - if call_count == 1: - raise ETLDirectPutTransientError(url, requests.ConnectionError()) - return (204, b"", len(original_body)) - - handler = DummyRequestHandler() - handler.headers = {HEADER_CONTENT_LENGTH: str(len(original_body))} - handler.rfile = io.BytesIO(original_body) - handler.server.etl_server.direct_put_retries = 1 - handler.server.etl_server.transform_stream = spy_transform_stream - handler.server.etl_server.close_reader = MagicMock() + if call_count < 2: + raise ETLDirectPutTransientError( + self._DIRECT_PUT_URL, requests.ConnectionError() + ) + return ok + readers = [self._make_reader(), self._make_reader()] + reader_iter = iter(readers) with patch.object( - handler, "_direct_put_stream", side_effect=mock_direct_put_stream + self.handler, "_get_stream_reader", side_effect=lambda *_: next(reader_iter) ): - with patch("time.sleep"): - handler._direct_put_stream_with_retry( # pylint: disable=protected-access - self._DIRECT_PUT_URL, "", "/test/obj", "", is_get=False + with patch.object( + self.handler, "_direct_put_stream", side_effect=put_side + ) as mock_put: + with patch("time.sleep") as mock_sleep: + result = self._call(fqn="some/file", is_get=False) + self.assertEqual(result, ok) + self.assertEqual(mock_put.call_count, 2) + self.assertEqual(mock_sleep.call_count, 1) + + def test_get_path_still_retries(self): + """GET path stays replayable; retries proceed normally (regression guard).""" + self.handler.server.etl_server.direct_put_retries = 2 + ok = (204, b"", 4) + call_count = 0 + + def put_side(*_args, **_kwargs): + nonlocal call_count + call_count += 1 + if call_count < 2: + raise ETLDirectPutTransientError( + self._DIRECT_PUT_URL, requests.ConnectionError() ) + return ok - self.assertEqual(len(bytes_read_by_transform), 2, "expected 2 attempts") - self.assertEqual(bytes_read_by_transform[0], original_body) - self.assertEqual( - bytes_read_by_transform[1], - original_body, - "retry must send original bytes, not empty bytes", 
- ) + readers = [self._make_reader(), self._make_reader()] + reader_iter = iter(readers) + with patch.object( + self.handler, "_get_stream_reader", side_effect=lambda *_: next(reader_iter) + ): + with patch.object( + self.handler, "_direct_put_stream", side_effect=put_side + ) as mock_put: + with patch("time.sleep") as mock_sleep: + result = self._call(is_get=True) + self.assertEqual(result, ok) + self.assertEqual(mock_put.call_count, 2) + self.assertEqual(mock_sleep.call_count, 1) # ---------------------------------------------------------------------------