From 4c4079afc9eb1a90682ead27b0a41a3ce40712e1 Mon Sep 17 00:00:00 2001 From: Pulkit Aggarwal Date: Mon, 8 Dec 2025 11:40:13 +0000 Subject: [PATCH 01/10] integrate retry logic with the MRD --- tests/unit/asyncio/test_async_multi_range_downloader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/asyncio/test_async_multi_range_downloader.py b/tests/unit/asyncio/test_async_multi_range_downloader.py index 2f0600f8d..222c13a67 100644 --- a/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -36,6 +36,7 @@ class TestAsyncMultiRangeDownloader: + def create_read_ranges(self, num_ranges): ranges = [] for i in range(num_ranges): @@ -143,6 +144,7 @@ async def test_download_ranges_via_async_gather( ) ] ), + None, _storage_v2.BidiReadObjectResponse( object_data_ranges=[ _storage_v2.ObjectRangeData( From c7f6d4664321b47e7d18aaa2060cbe2445934ce6 Mon Sep 17 00:00:00 2001 From: Pulkit Aggarwal Date: Mon, 15 Dec 2025 07:16:25 +0000 Subject: [PATCH 02/10] feat(experimental): add write resumption strategy --- tests/unit/asyncio/test_async_multi_range_downloader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/asyncio/test_async_multi_range_downloader.py b/tests/unit/asyncio/test_async_multi_range_downloader.py index 222c13a67..3e0db3626 100644 --- a/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -36,7 +36,6 @@ class TestAsyncMultiRangeDownloader: - def create_read_ranges(self, num_ranges): ranges = [] for i in range(num_ranges): From c0350f5c6acf83ea4911e11b7cc65b8b7e5c4476 Mon Sep 17 00:00:00 2001 From: Pulkit Aggarwal Date: Mon, 29 Dec 2025 06:19:01 +0000 Subject: [PATCH 03/10] address gemini bot comments --- tests/unit/asyncio/test_async_multi_range_downloader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/asyncio/test_async_multi_range_downloader.py b/tests/unit/asyncio/test_async_multi_range_downloader.py index 3e0db3626..2f0600f8d 100644 --- a/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -143,7 +143,6 @@ async def test_download_ranges_via_async_gather( ) ] ), - None, _storage_v2.BidiReadObjectResponse( object_data_ranges=[ _storage_v2.ObjectRangeData( From d1cc1ef98c44b813c4d2bb0c419ff56e13916ce5 Mon Sep 17 00:00:00 2001 From: Pulkit Aggarwal Date: Tue, 23 Dec 2025 06:50:54 +0000 Subject: [PATCH 04/10] feat(experimental): integrate writes strategy and appendable object writer --- .../asyncio/async_appendable_object_writer.py | 345 ++++++++++++++---- .../retry/writes_resumption_strategy.py | 6 + .../test_async_appendable_object_writer.py | 54 +++ 3 files changed, 342 insertions(+), 63 deletions(-) diff --git a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py index c808cb52a..7c1bec4f8 100644 --- a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py +++ b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py @@ -21,11 +21,16 @@ if you want to use these Rapid Storage APIs. """ -from io import BufferedReader -from typing import Optional, Union +from io import BufferedReader, BytesIO +import asyncio +from typing import List, Optional, Tuple, Union from google_crc32c import Checksum from google.api_core import exceptions +from google.api_core.retry_async import AsyncRetry +from google.rpc import status_pb2 +from google.cloud._storage_v2.types import BidiWriteObjectRedirectedError + from ._utils import raise_if_no_fast_crc32c from google.cloud import _storage_v2 @@ -35,10 +40,58 @@ from google.cloud.storage._experimental.asyncio.async_write_object_stream import ( _AsyncWriteObjectStream, ) +from google.cloud.storage._experimental.asyncio.retry.bidi_stream_retry_manager import ( + _BidiStreamRetryManager, +) +from google.cloud.storage._experimental.asyncio.retry.writes_resumption_strategy import ( + _WriteResumptionStrategy, + _WriteState, +) _MAX_CHUNK_SIZE_BYTES = 2 * 1024 * 1024 # 2 MiB _DEFAULT_FLUSH_INTERVAL_BYTES = 16 * 1024 * 1024 # 16 MiB +_BIDI_WRITE_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" +) + + +def _is_write_retryable(exc): + """Predicate to determine if a write operation should be retried.""" + if isinstance( + exc, + ( + exceptions.InternalServerError, + exceptions.ServiceUnavailable, + exceptions.DeadlineExceeded, + exceptions.TooManyRequests, + ), + ): + return True + + grpc_error = None + if isinstance(exc, exceptions.Aborted): + grpc_error = exc.errors[0] + trailers = grpc_error.trailing_metadata() + if not trailers: + return False + + status_details_bin = None + for key, value in trailers: + if key == "grpc-status-details-bin": + status_details_bin = value + break + + if status_details_bin: + status_proto = status_pb2.Status() + try: + status_proto.ParseFromString(status_details_bin) + for detail in status_proto.details: + if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL: + return True + except Exception: + return False + return False class AsyncAppendableObjectWriter: @@ -114,13 +167,7 @@ def __init__( self.write_handle = write_handle self.generation = generation - self.write_obj_stream = _AsyncWriteObjectStream( - client=self.client, - bucket_name=self.bucket_name, - object_name=self.object_name, - generation_number=self.generation, - write_handle=self.write_handle, - ) + self.write_obj_stream: Optional[_AsyncWriteObjectStream] = None self._is_stream_open: bool = False # `offset` is the latest size of the object without staleless. self.offset: Optional[int] = None @@ -143,6 +190,8 @@ def __init__( f"flush_interval must be a multiple of {_MAX_CHUNK_SIZE_BYTES}, but provided {self.flush_interval}" ) self.bytes_appended_since_last_flush = 0 + self._lock = asyncio.Lock() + self._routing_token: Optional[str] = None async def state_lookup(self) -> int: """Returns the persisted_size @@ -165,7 +214,55 @@ async def state_lookup(self) -> int: self.persisted_size = response.persisted_size return self.persisted_size - async def open(self) -> None: + def _on_open_error(self, exc): + """Extracts routing token and write handle on redirect error during open.""" + grpc_error = None + if isinstance(exc, exceptions.Aborted) and exc.errors: + grpc_error = exc.errors[0] + + if grpc_error: + if isinstance(grpc_error, BidiWriteObjectRedirectedError): + self._routing_token = grpc_error.routing_token + if grpc_error.write_handle: + self.write_handle = grpc_error.write_handle + return + + if hasattr(grpc_error, "trailing_metadata"): + trailers = grpc_error.trailing_metadata() + if not trailers: + return + + status_details_bin = None + for key, value in trailers: + if key == "grpc-status-details-bin": + status_details_bin = value + break + + if status_details_bin: + status_proto = status_pb2.Status() + try: + status_proto.ParseFromString(status_details_bin) + for detail in status_proto.details: + if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL: + redirect_proto = ( + BidiWriteObjectRedirectedError.deserialize( + detail.value + ) + ) + if redirect_proto.routing_token: + self._routing_token = redirect_proto.routing_token + if redirect_proto.write_handle: + self.write_handle = redirect_proto.write_handle + break + except Exception: + # Could not parse the error, ignore + pass + + async def open( + self, + retry_policy: Optional[AsyncRetry] = None, + metadata: Optional[List[Tuple[str, str]]] = None, + ) -> None: """Opens the underlying bidi-gRPC stream. :raises ValueError: If the stream is already open. @@ -174,15 +271,172 @@ async def open(self) -> None: if self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is already open") - await self.write_obj_stream.open() - self._is_stream_open = True - if self.generation is None: - self.generation = self.write_obj_stream.generation_number - self.write_handle = self.write_obj_stream.write_handle - self.persisted_size = self.write_obj_stream.persisted_size + if retry_policy is None: + retry_policy = AsyncRetry( + predicate=_is_write_retryable, on_error=self._on_open_error + ) + else: + original_on_error = retry_policy._on_error + + def combined_on_error(exc): + self._on_open_error(exc) + if original_on_error: + original_on_error(exc) + + retry_policy = retry_policy.with_predicate( + _is_write_retryable + ).with_on_error(combined_on_error) + + async def _do_open(): + current_metadata = list(metadata) if metadata else [] + + # Cleanup stream from previous failed attempt, if any. + if self.write_obj_stream: + if self._is_stream_open: + try: + await self.write_obj_stream.close() + except Exception: # ignore cleanup errors + pass + self.write_obj_stream = None + self._is_stream_open = False + + self.write_obj_stream = _AsyncWriteObjectStream( + client=self.client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation, + write_handle=self.write_handle, + ) + + if self._routing_token: + current_metadata.append( + ("x-goog-request-params", f"routing_token={self._routing_token}") + ) + self._routing_token = None + + await self.write_obj_stream.open( + metadata=current_metadata if metadata else None + ) + + if self.write_obj_stream.generation_number: + self.generation = self.write_obj_stream.generation_number + if self.write_obj_stream.write_handle: + self.write_handle = self.write_obj_stream.write_handle + if self.write_obj_stream.persisted_size is not None: + self.persisted_size = self.write_obj_stream.persisted_size + + self._is_stream_open = True + + await retry_policy(_do_open)() + + async def _upload_with_retry( + self, + data: bytes, + retry_policy: Optional[AsyncRetry] = None, + metadata: Optional[List[Tuple[str, str]]] = None, + ) -> None: + if not self._is_stream_open: + raise ValueError("Underlying bidi-gRPC stream is not open") + + if retry_policy is None: + retry_policy = AsyncRetry(predicate=_is_write_retryable) + + # Initialize Global State for Retry Strategy + spec = _storage_v2.AppendObjectSpec( + bucket=self.bucket_name, + object=self.object_name, + generation=self.generation, + ) + buffer = BytesIO(data) + write_state = _WriteState( + spec=spec, + chunk_size=_MAX_CHUNK_SIZE_BYTES, + user_buffer=buffer, + ) + write_state.write_handle = self.write_handle + + initial_state = { + "write_state": write_state, + "first_request": True, + } + + # Track attempts to manage stream reuse + attempt_count = 0 + + def stream_opener( + requests, + state, + metadata: Optional[List[Tuple[str, str]]] = None, + ): + async def generator(): + nonlocal attempt_count + attempt_count += 1 + + async with self._lock: + current_handle = state["write_state"].write_handle + current_token = state["write_state"].routing_token + + should_reopen = (attempt_count > 1) or (current_token is not None) + + if should_reopen: + if self.write_obj_stream and self.write_obj_stream._is_stream_open: + await self.write_obj_stream.close() + + self.write_obj_stream = _AsyncWriteObjectStream( + client=self.client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation, + write_handle=current_handle, + ) + + current_metadata = list(metadata) if metadata else [] + if current_token: + current_metadata.append( + ( + "x-goog-request-params", + f"routing_token={current_token}", + ) + ) + + await self.write_obj_stream.open( + metadata=current_metadata if current_metadata else None + ) + self._is_stream_open = True + + # Let the strategy generate the request sequence + async for request in requests: + await self.write_obj_stream.send(request) + + # Signal that we are done sending requests. + await self.write_obj_stream.requests.put(None) + + # Process responses + async for response in self.write_obj_stream: + yield response + + return generator() + + strategy = _WriteResumptionStrategy() + retry_manager = _BidiStreamRetryManager( + strategy, lambda r, s: stream_opener(r, s, metadata=metadata) + ) - async def append(self, data: bytes) -> None: - """Appends data to the Appendable object. + await retry_manager.execute(initial_state, retry_policy) + + # Update the writer's state from the strategy's final state + final_write_state = initial_state["write_state"] + self.persisted_size = final_write_state.persisted_size + self.write_handle = final_write_state.write_handle + self.offset = self.persisted_size + + async def append( + self, + data: bytes, + retry_policy: Optional[AsyncRetry] = None, + metadata: Optional[List[Tuple[str, str]]] = None, + ) -> None: + """Appends data to the Appendable object with automatic retries. calling `self.append` will append bytes at the end of the current size ie. `self.offset` bytes relative to the begining of the object. @@ -195,55 +449,20 @@ async def append(self, data: bytes) -> None: :type data: bytes :param data: The bytes to append to the object. - :rtype: None + :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry` + :param retry_policy: (Optional) The retry policy to use for the operation. - :raises ValueError: If the stream is not open (i.e., `open()` has not - been called). - """ + :type metadata: List[Tuple[str, str]] + :param metadata: (Optional) The metadata to be sent with the request. + :raises ValueError: If the stream is not open. + """ if not self._is_stream_open: raise ValueError("Stream is not open. Call open() before append().") - total_bytes = len(data) - if total_bytes == 0: - # TODO: add warning. - return - if self.offset is None: - assert self.persisted_size is not None - self.offset = self.persisted_size - - start_idx = 0 - while start_idx < total_bytes: - end_idx = min(start_idx + _MAX_CHUNK_SIZE_BYTES, total_bytes) - data_chunk = data[start_idx:end_idx] - is_last_chunk = end_idx == total_bytes - chunk_size = end_idx - start_idx - await self.write_obj_stream.send( - _storage_v2.BidiWriteObjectRequest( - write_offset=self.offset, - checksummed_data=_storage_v2.ChecksummedData( - content=data_chunk, - crc32c=int.from_bytes(Checksum(data_chunk).digest(), "big"), - ), - state_lookup=is_last_chunk, - flush=is_last_chunk - or ( - self.bytes_appended_since_last_flush + chunk_size - >= self.flush_interval - ), - ) - ) - self.offset += chunk_size - self.bytes_appended_since_last_flush += chunk_size - - if self.bytes_appended_since_last_flush >= self.flush_interval: - self.bytes_appended_since_last_flush = 0 - - if is_last_chunk: - response = await self.write_obj_stream.recv() - self.persisted_size = response.persisted_size - self.offset = self.persisted_size - self.bytes_appended_since_last_flush = 0 - start_idx = end_idx + if not data: + return # Do nothing for empty data + + await self._upload_with_retry(data, retry_policy, metadata) async def simple_flush(self) -> None: """Flushes the data to the server. diff --git a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py index c6ae36339..1c1d9849b 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py @@ -15,12 +15,18 @@ from typing import Any, Dict, IO, Iterable, Optional, Union import google_crc32c +from google.api_core import exceptions +from google.rpc import status_pb2 from google.cloud._storage_v2.types import storage as storage_type from google.cloud._storage_v2.types.storage import BidiWriteObjectRedirectedError from google.cloud.storage._experimental.asyncio.retry.base_strategy import ( _BaseResumptionStrategy, ) +_BIDI_WRITE_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" +) + class _WriteState: """A helper class to track the state of a single upload operation. diff --git a/tests/unit/asyncio/test_async_appendable_object_writer.py b/tests/unit/asyncio/test_async_appendable_object_writer.py index 07f7047d8..c496b5810 100644 --- a/tests/unit/asyncio/test_async_appendable_object_writer.py +++ b/tests/unit/asyncio/test_async_appendable_object_writer.py @@ -675,3 +675,57 @@ async def test_append_from_file(file_size, block_size, mock_client): else file_size // block_size + 1 ) assert writer.append.await_count == exepected_calls + + +@pytest.mark.asyncio +@mock.patch( + "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager" +) +async def test_append_with_retry_on_service_unavailable( + mock_retry_manager_class, mock_client +): + """Test that append retries on ServiceUnavailable.""" + # Arrange + writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) + writer._is_stream_open = True + writer.write_handle = WRITE_HANDLE + + mock_retry_manager = mock_retry_manager_class.return_value + mock_retry_manager.execute = mock.AsyncMock( + side_effect=[exceptions.ServiceUnavailable("testing"), None] + ) + + data_to_append = b"some data" + + # Act + await writer.append(data_to_append) + + # Assert + assert mock_retry_manager.execute.await_count == 2 + + +@pytest.mark.asyncio +@mock.patch( + "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager" +) +async def test_append_with_non_retryable_error( + mock_retry_manager_class, mock_client +): + """Test that append does not retry on non-retriable errors.""" + # Arrange + writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) + writer._is_stream_open = True + writer.write_handle = WRITE_HANDLE + + mock_retry_manager = mock_retry_manager_class.return_value + mock_retry_manager.execute = mock.AsyncMock( + side_effect=exceptions.BadRequest("testing") + ) + + data_to_append = b"some data" + + # Act & Assert + with pytest.raises(exceptions.BadRequest): + await writer.append(data_to_append) + + assert mock_retry_manager.execute.await_count == 1 From ccf667a541563a42919234f6b02479970d400a88 Mon Sep 17 00:00:00 2001 From: Pulkit Aggarwal Date: Thu, 8 Jan 2026 07:16:08 +0000 Subject: [PATCH 05/10] more changes --- .../asyncio/async_appendable_object_writer.py | 258 ++++++++-------- .../asyncio/async_write_object_stream.py | 35 ++- .../retry/bidi_stream_retry_manager.py | 4 + .../retry/writes_resumption_strategy.py | 30 +- tests/conformance/test_bidi_writes.py | 281 ++++++++++++++++++ 5 files changed, 463 insertions(+), 145 deletions(-) create mode 100644 tests/conformance/test_bidi_writes.py diff --git a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py index 7c1bec4f8..7d7cc2206 100644 --- a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py +++ b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +from __future__ import annotations + NOTE: This is _experimental module for upcoming support for Rapid Storage. (https://cloud.google.com/blog/products/storage-data-transfer/high-performance-storage-innovations-for-ai-hpc#:~:text=your%20AI%20workloads%3A-,Rapid%20Storage,-%3A%20A%20new) @@ -23,6 +25,7 @@ """ from io import BufferedReader, BytesIO import asyncio +import io from typing import List, Optional, Tuple, Union from google_crc32c import Checksum @@ -30,6 +33,7 @@ from google.api_core.retry_async import AsyncRetry from google.rpc import status_pb2 from google.cloud._storage_v2.types import BidiWriteObjectRedirectedError +from google.cloud._storage_v2.types.storage import BidiWriteObjectRequest from ._utils import raise_if_no_fast_crc32c @@ -58,6 +62,9 @@ def _is_write_retryable(exc): """Predicate to determine if a write operation should be retried.""" + + print("In _is_write_retryable method, exception:", exc) + if isinstance( exc, ( @@ -192,6 +199,17 @@ def __init__( self.bytes_appended_since_last_flush = 0 self._lock = asyncio.Lock() self._routing_token: Optional[str] = None + self.object_resource: Optional[_storage_v2.Object] = None + + def _stream_opener(self, write_handle=None): + """Helper to create a new _AsyncWriteObjectStream.""" + return _AsyncWriteObjectStream( + client=self.client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation, + write_handle=write_handle if write_handle else self.write_handle, + ) async def state_lookup(self) -> int: """Returns the persisted_size @@ -205,14 +223,15 @@ async def state_lookup(self) -> int: if not self._is_stream_open: raise ValueError("Stream is not open. Call open() before state_lookup().") - await self.write_obj_stream.send( - _storage_v2.BidiWriteObjectRequest( - state_lookup=True, + async with self._lock: + await self.write_obj_stream.send( + _storage_v2.BidiWriteObjectRequest( + state_lookup=True, + ) ) - ) - response = await self.write_obj_stream.recv() - self.persisted_size = response.persisted_size - return self.persisted_size + response = await self.write_obj_stream.recv() + self.persisted_size = response.persisted_size + return self.persisted_size def _on_open_error(self, exc): """Extracts routing token and write handle on redirect error during open.""" @@ -288,6 +307,7 @@ def combined_on_error(exc): ).with_on_error(combined_on_error) async def _do_open(): + print("In _do_open method") current_metadata = list(metadata) if metadata else [] # Cleanup stream from previous failed attempt, if any. @@ -314,6 +334,7 @@ async def _do_open(): ) self._routing_token = None + print("Current metadata in _do_open:", current_metadata) await self.write_obj_stream.open( metadata=current_metadata if metadata else None ) @@ -327,108 +348,9 @@ async def _do_open(): self._is_stream_open = True + print("In open method, before retry_policy call") await retry_policy(_do_open)() - async def _upload_with_retry( - self, - data: bytes, - retry_policy: Optional[AsyncRetry] = None, - metadata: Optional[List[Tuple[str, str]]] = None, - ) -> None: - if not self._is_stream_open: - raise ValueError("Underlying bidi-gRPC stream is not open") - - if retry_policy is None: - retry_policy = AsyncRetry(predicate=_is_write_retryable) - - # Initialize Global State for Retry Strategy - spec = _storage_v2.AppendObjectSpec( - bucket=self.bucket_name, - object=self.object_name, - generation=self.generation, - ) - buffer = BytesIO(data) - write_state = _WriteState( - spec=spec, - chunk_size=_MAX_CHUNK_SIZE_BYTES, - user_buffer=buffer, - ) - write_state.write_handle = self.write_handle - - initial_state = { - "write_state": write_state, - "first_request": True, - } - - # Track attempts to manage stream reuse - attempt_count = 0 - - def stream_opener( - requests, - state, - metadata: Optional[List[Tuple[str, str]]] = None, - ): - async def generator(): - nonlocal attempt_count - attempt_count += 1 - - async with self._lock: - current_handle = state["write_state"].write_handle - current_token = state["write_state"].routing_token - - should_reopen = (attempt_count > 1) or (current_token is not None) - - if should_reopen: - if self.write_obj_stream and self.write_obj_stream._is_stream_open: - await self.write_obj_stream.close() - - self.write_obj_stream = _AsyncWriteObjectStream( - client=self.client, - bucket_name=self.bucket_name, - object_name=self.object_name, - generation_number=self.generation, - write_handle=current_handle, - ) - - current_metadata = list(metadata) if metadata else [] - if current_token: - current_metadata.append( - ( - "x-goog-request-params", - f"routing_token={current_token}", - ) - ) - - await self.write_obj_stream.open( - metadata=current_metadata if current_metadata else None - ) - self._is_stream_open = True - - # Let the strategy generate the request sequence - async for request in requests: - await self.write_obj_stream.send(request) - - # Signal that we are done sending requests. - await self.write_obj_stream.requests.put(None) - - # Process responses - async for response in self.write_obj_stream: - yield response - - return generator() - - strategy = _WriteResumptionStrategy() - retry_manager = _BidiStreamRetryManager( - strategy, lambda r, s: stream_opener(r, s, metadata=metadata) - ) - - await retry_manager.execute(initial_state, retry_policy) - - # Update the writer's state from the strategy's final state - final_write_state = initial_state["write_state"] - self.persisted_size = final_write_state.persisted_size - self.write_handle = final_write_state.write_handle - self.offset = self.persisted_size async def append( self, @@ -460,9 +382,93 @@ async def append( if not self._is_stream_open: raise ValueError("Stream is not open. Call open() before append().") if not data: - return # Do nothing for empty data + return + + if retry_policy is None: + retry_policy = AsyncRetry(predicate=_is_write_retryable) + + buffer = io.BytesIO(data) + target_persisted_size = self.persisted_size + len(data) + attempt_count = 0 + + print("In append method") + + def send_and_recv_generator(requests: List[BidiWriteObjectRequest], state: dict[str, _WriteState], metadata: Optional[List[Tuple[str, str]]] = None): + async def generator(): + print("In send_and_recv_generator") + nonlocal attempt_count + attempt_count += 1 + resp = None + async with self._lock: + write_state = state["write_state"] + # If this is a retry or redirect, we must re-open the stream + if attempt_count > 1 or write_state.routing_token: + print("Re-opening the stream inside send_and_recv_generator with attempt_count:", attempt_count) + if self.write_obj_stream and self.write_obj_stream.is_stream_open: + await self.write_obj_stream.close() + + self.write_obj_stream = self._stream_opener(write_handle=write_state.write_handle) + current_metadata = list(metadata) if metadata else [] + if write_state.routing_token: + current_metadata.append(("x-goog-request-params", f"routing_token={write_state.routing_token}")) + await self.write_obj_stream.open(metadata=current_metadata if current_metadata else None) + + self._is_stream_open = True + write_state.persisted_size = self.persisted_size + write_state.write_handle = self.write_handle + + print("Sending requests in send_and_recv_generator") + # req_iter = iter(requests) + + print("Starting to send requests") + for i, chunk_req in enumerate(requests): + if i == len(requests) - 1: + chunk_req.state_lookup = True + print("Sending chunk request") + await self.write_obj_stream.send(chunk_req) + print("Waiting to receive response") + print("Current persisted_size:", state["write_state"].persisted_size, "Target persisted_size:", target_persisted_size) + + resp = await self.write_obj_stream.recv() + if resp: + if resp.persisted_size is not None: + self.persisted_size = resp.persisted_size + state["write_state"].persisted_size = resp.persisted_size + if resp.write_handle: + self.write_handle = resp.write_handle + state["write_state"].write_handle = resp.write_handle + print("Received response in send_and_recv_generator", resp) + + yield resp + + # while state["write_state"].persisted_size < target_persisted_size: + # print("Waiting to receive response") + # print("Current persisted_size:", state["write_state"].persisted_size, "Target persisted_size:", target_persisted_size) + # resp = await self.write_obj_stream.recv() + # print("Received response in send_and_recv_generator", resp) + # if resp is None: + # break + # yield resp + return generator() + + # State initialization + spec = _storage_v2.AppendObjectSpec( + bucket=f"projects/_/buckets/{self.bucket_name}", object=self.object_name, generation=self.generation + ) + write_state = _WriteState(spec, _MAX_CHUNK_SIZE_BYTES, buffer) + write_state.write_handle = self.write_handle + write_state.persisted_size = self.persisted_size + write_state.bytes_sent = self.persisted_size + + print("Before creating retry manager") + retry_manager = _BidiStreamRetryManager(_WriteResumptionStrategy(), + lambda r, s: send_and_recv_generator(r, s, metadata)) + await retry_manager.execute({"write_state": write_state}, retry_policy) + + # Sync local markers + self.write_obj_stream.persisted_size = write_state.persisted_size + self.write_obj_stream.write_handle = write_state.write_handle - await self._upload_with_retry(data, retry_policy, metadata) async def simple_flush(self) -> None: """Flushes the data to the server. @@ -476,11 +482,12 @@ async def simple_flush(self) -> None: if not self._is_stream_open: raise ValueError("Stream is not open. Call open() before simple_flush().") - await self.write_obj_stream.send( - _storage_v2.BidiWriteObjectRequest( - flush=True, + async with self._lock: + await self.write_obj_stream.send( + _storage_v2.BidiWriteObjectRequest( + flush=True, + ) ) - ) async def flush(self) -> int: """Flushes the data to the server. @@ -494,16 +501,17 @@ async def flush(self) -> int: if not self._is_stream_open: raise ValueError("Stream is not open. Call open() before flush().") - await self.write_obj_stream.send( - _storage_v2.BidiWriteObjectRequest( - flush=True, - state_lookup=True, + async with self._lock: + await self.write_obj_stream.send( + _storage_v2.BidiWriteObjectRequest( + flush=True, + state_lookup=True, + ) ) - ) - response = await self.write_obj_stream.recv() - self.persisted_size = response.persisted_size - self.offset = self.persisted_size - return self.persisted_size + response = await self.write_obj_stream.recv() + self.persisted_size = response.persisted_size + self.offset = self.persisted_size + return self.persisted_size async def close(self, finalize_on_close=False) -> Union[int, _storage_v2.Object]: """Closes the underlying bidi-gRPC stream. @@ -553,10 +561,16 @@ async def finalize(self) -> _storage_v2.Object: if not self._is_stream_open: raise ValueError("Stream is not open. Call open() before finalize().") + print("In finalize method") + + # async with self._lock: + print("Sending finish_write request") await self.write_obj_stream.send( _storage_v2.BidiWriteObjectRequest(finish_write=True) ) + print("Waiting to receive response for finalize") response = await self.write_obj_stream.recv() + print("Received response for finalize:") self.object_resource = response.resource self.persisted_size = self.object_resource.size await self.write_obj_stream.close() diff --git a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py index 183a8eeb1..62c230f69 100644 --- a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py @@ -21,7 +21,7 @@ if you want to use these Rapid Storage APIs. """ -from typing import Optional +from typing import List, Optional, Tuple from google.cloud import _storage_v2 from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient from google.cloud.storage._experimental.asyncio.async_abstract_object_stream import ( @@ -91,13 +91,15 @@ def __init__( self.persisted_size = 0 self.object_resource: Optional[_storage_v2.Object] = None - async def open(self) -> None: + async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: """Opening an object for write , should do it's state lookup to know what's the persisted size is. """ if self._is_stream_open: raise ValueError("Stream is already open") + write_handle = self.write_handle if self.write_handle else None + # Create a new object or overwrite existing one if generation_number # is None. This makes it consistent with GCS JSON API behavior. # Created object type would be Appendable Object. @@ -116,15 +118,31 @@ async def open(self) -> None: bucket=self._full_bucket_name, object=self.object_name, generation=self.generation_number, + write_handle=write_handle, ), ) + request_params = [f"bucket={self._full_bucket_name}"] + other_metadata = [] + if metadata: + for key, value in metadata: + if key == "x-goog-request-params": + request_params.append(value) + else: + other_metadata.append((key, value)) + + current_metadata = other_metadata + current_metadata.append(("x-goog-request-params", ",".join(request_params))) + + print("Before sending first_bidi_write_req in open:", self.first_bidi_write_req) + self.socket_like_rpc = AsyncBidiRpc( - self.rpc, initial_request=self.first_bidi_write_req, metadata=self.metadata + self.rpc, initial_request=self.first_bidi_write_req, metadata=current_metadata ) await self.socket_like_rpc.open() # this is actually 1 send response = await self.socket_like_rpc.recv() + print("Received response on open") self._is_stream_open = True if not response.resource: @@ -181,7 +199,16 @@ async def recv(self) -> _storage_v2.BidiWriteObjectResponse: """ if not self._is_stream_open: raise ValueError("Stream is not open") - return await self.socket_like_rpc.recv() + response = await self.socket_like_rpc.recv() + # Update write_handle if present in response + if response: + if response.write_handle: + self.write_handle = response.write_handle + if response.persisted_size is not None: + self.persisted_size = response.persisted_size + if response.resource and response.resource.size: + self.persisted_size = response.resource.size + return response @property def is_stream_open(self) -> bool: diff --git a/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py b/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py index a8caae4eb..09cb16850 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py +++ b/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py @@ -49,10 +49,14 @@ async def execute(self, initial_state: Any, retry_policy): """ state = initial_state + print("Starting retry manager execute") + async def attempt(): requests = self._strategy.generate_requests(state) + print("Generated requests:", len(requests)) stream = self._send_and_recv(requests, state) try: + print("Starting to receive responses in attempt") async for response in stream: self._strategy.update_state_from_response(response, state) return diff --git a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py index 1c1d9849b..eb1d32fc4 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, IO, Iterable, Optional, Union +from typing import Any, Dict, IO, Iterable, List, Optional, Union import google_crc32c from google.api_core import exceptions @@ -62,7 +62,7 @@ class _WriteResumptionStrategy(_BaseResumptionStrategy): def generate_requests( self, state: Dict[str, Any] - ) -> Iterable[storage_type.BidiWriteObjectRequest]: + ) -> List[storage_type.BidiWriteObjectRequest]: """Generates BidiWriteObjectRequests to resume or continue the upload. For Appendable Objects, every stream opening should send an @@ -72,21 +72,9 @@ def generate_requests( """ write_state: _WriteState = state["write_state"] - initial_request = storage_type.BidiWriteObjectRequest() - - # Determine if we need to send WriteObjectSpec or AppendObjectSpec - if isinstance(write_state.spec, storage_type.WriteObjectSpec): - initial_request.write_object_spec = write_state.spec - else: - if write_state.write_handle: - write_state.spec.write_handle = write_state.write_handle - - if write_state.routing_token: - write_state.spec.routing_token = write_state.routing_token - initial_request.append_object_spec = write_state.spec - - yield initial_request + print("Generating requests from write state:", write_state) + requests = [] # The buffer should already be seeked to the correct position (persisted_size) # by the `recover_state_on_failure` method before this is called. while not write_state.is_finalized: @@ -94,7 +82,8 @@ def generate_requests( # End of File detection if not chunk: - return + print("No more data to read; ending request generation") + break checksummed_data = storage_type.ChecksummedData(content=chunk) checksum = google_crc32c.Checksum(chunk) @@ -106,14 +95,17 @@ def generate_requests( ) write_state.bytes_sent += len(chunk) - yield request + print("Yielding request", len(request.checksummed_data.content)) + requests.append(request) + return requests def update_state_from_response( self, response: storage_type.BidiWriteObjectResponse, state: Dict[str, Any] ) -> None: """Processes a server response and updates the write state.""" write_state: _WriteState = state["write_state"] - + if response is None: + return if response.persisted_size: write_state.persisted_size = response.persisted_size diff --git a/tests/conformance/test_bidi_writes.py b/tests/conformance/test_bidi_writes.py new file mode 100644 index 000000000..cc03e5a08 --- /dev/null +++ b/tests/conformance/test_bidi_writes.py @@ -0,0 +1,281 @@ +import asyncio +import uuid +import grpc +import requests + +from google.api_core import exceptions +from google.auth import credentials as auth_credentials +from google.cloud import _storage_v2 as storage_v2 + +from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( + AsyncAppendableObjectWriter, +) + +# --- Configuration --- +PROJECT_NUMBER = "12345" # A dummy project number is fine for the testbench. +GRPC_ENDPOINT = "localhost:8888" +HTTP_ENDPOINT = "http://localhost:9000" +CONTENT = b"A" * 1024 * 10 # 10 KB + + +async def run_test_scenario( + gapic_client, + http_client, + bucket_name, + object_name, + scenario, +): + """Runs a single fault-injection test scenario.""" + print(f"\n--- RUNNING SCENARIO: {scenario['name']} ---") + + retry_test_id = None + try: + # 1. Create a Retry Test resource on the testbench. + retry_test_config = { + "instructions": {scenario["method"]: [scenario["instruction"]]}, + "transport": "GRPC", + } + resp = http_client.post(f"{HTTP_ENDPOINT}/retry_test", json=retry_test_config) + resp.raise_for_status() + retry_test_id = resp.json()["id"] + + # 2. Set up writer and metadata for fault injection. + writer = AsyncAppendableObjectWriter( + gapic_client, + bucket_name, + object_name, + ) + fault_injection_metadata = (("x-retry-test-id", retry_test_id),) + + # 3. Execute the write and assert the outcome. + try: + print("Before calling open()") + await writer.open(metadata=fault_injection_metadata) + print("After calling open()") + await writer.append(CONTENT, metadata=fault_injection_metadata) + await writer.finalize() + await writer.close() + + # If an exception was expected, this line should not be reached. + if scenario["expected_error"] is not None: + raise AssertionError( + f"Expected exception {scenario['expected_error']} was not raised." + ) + + # 4. Verify the object content. + read_request = storage_v2.ReadObjectRequest( + bucket=f"projects/_/buckets/{bucket_name}", + object=object_name, + ) + read_stream = await gapic_client.read_object(request=read_request) + data = b"" + async for chunk in read_stream: + data += chunk.checksummed_data.content + assert data == CONTENT + + except Exception as e: + if scenario["expected_error"] is None or not isinstance( + e, scenario["expected_error"] + ): + raise + print(f"Caught expected exception for {scenario['name']}: {e}") + + finally: + # 5. Clean up the Retry Test resource. + if retry_test_id: + http_client.delete(f"{HTTP_ENDPOINT}/retry_test/{retry_test_id}") + + +async def main(): + """Main function to set up resources and run all test scenarios.""" + channel = grpc.aio.insecure_channel(GRPC_ENDPOINT) + creds = auth_credentials.AnonymousCredentials() + transport = storage_v2.services.storage.transports.StorageGrpcAsyncIOTransport( + channel=channel, + credentials=creds, + ) + gapic_client = storage_v2.StorageAsyncClient(transport=transport) + http_client = requests.Session() + + bucket_name = f"grpc-test-bucket-{uuid.uuid4().hex[:8]}" + object_name_prefix = "retry-test-object-" + + # Define all test scenarios + test_scenarios = [ + # { + # "name": "Retry on Service Unavailable (503)", + # "method": "storage.objects.insert", + # "instruction": "return-503", + # "expected_error": None, + # }, + # { + # "name": "Retry on 500", + # "method": "storage.objects.insert", + # "instruction": "return-500", + # "expected_error": None, + # }, + # { + # "name": "Retry on 504", + # "method": "storage.objects.insert", + # "instruction": "return-504", + # "expected_error": None, + # }, + # { + # "name": "Retry on 429", + # "method": "storage.objects.insert", + # "instruction": "return-429", + # "expected_error": None, + # }, + { + "name": "Smarter Resumption: Retry 503 after partial data", + "method": "storage.objects.insert", + "instruction": "return-broken-stream-after-2K", + "expected_error": None, + }, + { + "name": "Retry on BidiWriteObjectRedirectedError", + "method": "storage.objects.insert", + "instruction": "redirect-send-handle-and-token-tokenval", + "expected_error": None, + }, + { + "name": "Fail on 401", + "method": "storage.objects.insert", + "instruction": "return-401", + "expected_error": exceptions.Unauthorized, + }, + ] + + try: + bucket_resource = storage_v2.Bucket(project=f"projects/{PROJECT_NUMBER}") + create_bucket_request = storage_v2.CreateBucketRequest( + parent="projects/_", bucket_id=bucket_name, bucket=bucket_resource + ) + await gapic_client.create_bucket(request=create_bucket_request) + + for i, scenario in enumerate(test_scenarios): + object_name = f"{object_name_prefix}{i}" + await run_test_scenario( + gapic_client, + http_client, + bucket_name, + object_name, + scenario, + ) + + # Define and run test scenarios specifically for the open() method + # open_test_scenarios = [ + # { + # "name": "Open: Retry on 503", + # "method": "storage.objects.insert", + # "instruction": "return-503", + # "expected_error": None, + # }, + # { + # "name": "Open: Retry on BidiWriteObjectRedirectedError", + # "method": "storage.objects.insert", + # "instruction": "redirect-send-handle-and-token-tokenval", + # "expected_error": None, + # }, + # { + # "name": "Open: Fail Fast on 401", + # "method": "storage.objects.insert", + # "instruction": "return-401", + # "expected_error": exceptions.Unauthorized, + # }, + # ] + # for i, scenario in enumerate(open_test_scenarios): + # object_name = f"{object_name_prefix}-open-{i}" + # await run_open_test_scenario( + # gapic_client, + # http_client, + # bucket_name, + # object_name, + # scenario, + # ) + + except Exception: + import traceback + + traceback.print_exc() + finally: + # Clean up the test bucket. + try: + list_objects_req = storage_v2.ListObjectsRequest( + parent=f"projects/_/buckets/{bucket_name}", + ) + list_objects_res = await gapic_client.list_objects(request=list_objects_req) + async for obj in list_objects_res: + delete_object_req = storage_v2.DeleteObjectRequest( + bucket=f"projects/_/buckets/{bucket_name}", object=obj.name + ) + await gapic_client.delete_object(request=delete_object_req) + + delete_bucket_req = storage_v2.DeleteBucketRequest( + name=f"projects/_/buckets/{bucket_name}" + ) + await gapic_client.delete_bucket(request=delete_bucket_req) + except Exception as e: + print(f"Warning: Cleanup failed: {e}") + + +async def run_open_test_scenario( + gapic_client, + http_client, + bucket_name, + object_name, + scenario, +): + """Runs a fault-injection test scenario specifically for the open() method.""" + print(f"\n--- RUNNING SCENARIO: {scenario['name']} ---") + + retry_test_id = None + try: + # 1. Create a Retry Test resource on the testbench. + retry_test_config = { + "instructions": {scenario["method"]: [scenario["instruction"]]}, + "transport": "GRPC", + } + resp = http_client.post(f"{HTTP_ENDPOINT}/retry_test", json=retry_test_config) + resp.raise_for_status() + retry_test_id = resp.json()["id"] + print(f"Retry Test created with ID: {retry_test_id}") + + # 2. Set up metadata for fault injection. + fault_injection_metadata = (("x-retry-test-id", retry_test_id),) + + # 3. Execute the open and assert the outcome. + try: + writer = AsyncAppendableObjectWriter( + gapic_client, + bucket_name, + object_name, + ) + await writer.open(metadata=fault_injection_metadata) + + # If open was successful, perform a simple write to ensure the stream is usable. + await writer.append(CONTENT) + await writer.finalize() + await writer.close() + + # If an exception was expected, this line should not be reached. + if scenario["expected_error"] is not None: + raise AssertionError( + f"Expected exception {scenario['expected_error']} was not raised." + ) + + except Exception as e: + if scenario["expected_error"] is None or not isinstance( + e, scenario["expected_error"] + ): + raise + print(f"Caught expected exception for {scenario['name']}: {e}") + + finally: + # 4. Clean up the Retry Test resource. + if retry_test_id: + http_client.delete(f"{HTTP_ENDPOINT}/retry_test/{retry_test_id}") + + +if __name__ == "__main__": + asyncio.run(main()) From c13158830d0637fbb12cf3a2f8084a9c12d1e50f Mon Sep 17 00:00:00 2001 From: Pulkit Aggarwal Date: Mon, 12 Jan 2026 09:34:41 +0000 Subject: [PATCH 06/10] adding unit tests --- .../asyncio/async_appendable_object_writer.py | 128 +- .../asyncio/async_write_object_stream.py | 54 +- .../retry/bidi_stream_retry_manager.py | 4 - .../retry/writes_resumption_strategy.py | 82 +- tests/conformance/test_bidi_writes.py | 224 ++-- .../retry/test_writes_resumption_strategy.py | 445 ++++--- .../test_async_appendable_object_writer.py | 1170 +++++++---------- .../asyncio/test_async_write_object_stream.py | 792 +++++------ 8 files changed, 1423 insertions(+), 1476 deletions(-) diff --git a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py index 7d7cc2206..b68475c30 100644 --- a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py +++ b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py @@ -23,12 +23,12 @@ if you want to use these Rapid Storage APIs. """ -from io import BufferedReader, BytesIO +from io import BufferedReader import asyncio import io -from typing import List, Optional, Tuple, Union +import logging +from typing import List, Optional, Tuple -from google_crc32c import Checksum from google.api_core import exceptions from google.api_core.retry_async import AsyncRetry from google.rpc import status_pb2 @@ -58,13 +58,12 @@ _BIDI_WRITE_REDIRECTED_TYPE_URL = ( "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" ) +logger = logging.getLogger(__name__) def _is_write_retryable(exc): """Predicate to determine if a write operation should be retried.""" - print("In _is_write_retryable method, exception:", exc) - if isinstance( exc, ( @@ -74,6 +73,7 @@ def _is_write_retryable(exc): exceptions.TooManyRequests, ), ): + logger.info(f"Retryable write exception encountered: {exc}") return True grpc_error = None @@ -97,6 +97,7 @@ def _is_write_retryable(exc): if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL: return True except Exception: + logger.error("Error unpacking redirect details from gRPC error.") return False return False @@ -201,16 +202,6 @@ def __init__( self._routing_token: Optional[str] = None self.object_resource: Optional[_storage_v2.Object] = None - def _stream_opener(self, write_handle=None): - """Helper to create a new _AsyncWriteObjectStream.""" - return _AsyncWriteObjectStream( - client=self.client, - bucket_name=self.bucket_name, - object_name=self.object_name, - generation_number=self.generation, - write_handle=write_handle if write_handle else self.write_handle, - ) - async def state_lookup(self) -> int: """Returns the persisted_size @@ -244,6 +235,8 @@ def _on_open_error(self, exc): self._routing_token = grpc_error.routing_token if grpc_error.write_handle: self.write_handle = grpc_error.write_handle + if grpc_error.generation: + self.generation = grpc_error.generation return if hasattr(grpc_error, "trailing_metadata"): @@ -272,9 +265,13 @@ def _on_open_error(self, exc): self._routing_token = redirect_proto.routing_token if redirect_proto.write_handle: self.write_handle = redirect_proto.write_handle + if redirect_proto.generation: + self.generation = redirect_proto.generation break except Exception: - # Could not parse the error, ignore + logger.error( + "Error unpacking redirect details from gRPC error." + ) pass async def open( @@ -302,12 +299,16 @@ def combined_on_error(exc): if original_on_error: original_on_error(exc) - retry_policy = retry_policy.with_predicate( - _is_write_retryable - ).with_on_error(combined_on_error) + retry_policy = AsyncRetry( + predicate=_is_write_retryable, + initial=retry_policy._initial, + maximum=retry_policy._maximum, + multiplier=retry_policy._multiplier, + deadline=retry_policy._deadline, + on_error=combined_on_error, + ) async def _do_open(): - print("In _do_open method") current_metadata = list(metadata) if metadata else [] # Cleanup stream from previous failed attempt, if any. @@ -326,15 +327,14 @@ async def _do_open(): object_name=self.object_name, generation_number=self.generation, write_handle=self.write_handle, + routing_token=self._routing_token, ) if self._routing_token: current_metadata.append( ("x-goog-request-params", f"routing_token={self._routing_token}") ) - self._routing_token = None - print("Current metadata in _do_open:", current_metadata) await self.write_obj_stream.open( metadata=current_metadata if metadata else None ) @@ -347,11 +347,10 @@ async def _do_open(): self.persisted_size = self.write_obj_stream.persisted_size self._is_stream_open = True + self._routing_token = None - print("In open method, before retry_policy call") await retry_policy(_do_open)() - async def append( self, data: bytes, @@ -387,47 +386,62 @@ async def append( if retry_policy is None: retry_policy = AsyncRetry(predicate=_is_write_retryable) + strategy = _WriteResumptionStrategy() buffer = io.BytesIO(data) - target_persisted_size = self.persisted_size + len(data) attempt_count = 0 - print("In append method") - - def send_and_recv_generator(requests: List[BidiWriteObjectRequest], state: dict[str, _WriteState], metadata: Optional[List[Tuple[str, str]]] = None): + def send_and_recv_generator( + requests: List[BidiWriteObjectRequest], + state: dict[str, _WriteState], + metadata: Optional[List[Tuple[str, str]]] = None, + ): async def generator(): - print("In send_and_recv_generator") nonlocal attempt_count + nonlocal requests attempt_count += 1 resp = None async with self._lock: write_state = state["write_state"] # If this is a retry or redirect, we must re-open the stream if attempt_count > 1 or write_state.routing_token: - print("Re-opening the stream inside send_and_recv_generator with attempt_count:", attempt_count) - if self.write_obj_stream and self.write_obj_stream.is_stream_open: + logger.info( + f"Re-opening the stream with attempt_count: {attempt_count}" + ) + if ( + self.write_obj_stream + and self.write_obj_stream.is_stream_open + ): await self.write_obj_stream.close() - self.write_obj_stream = self._stream_opener(write_handle=write_state.write_handle) current_metadata = list(metadata) if metadata else [] if write_state.routing_token: - current_metadata.append(("x-goog-request-params", f"routing_token={write_state.routing_token}")) - await self.write_obj_stream.open(metadata=current_metadata if current_metadata else None) + current_metadata.append( + ( + "x-goog-request-params", + f"routing_token={write_state.routing_token}", + ) + ) + self._routing_token = write_state.routing_token + + self._is_stream_open = False + await self.open(metadata=current_metadata) - self._is_stream_open = True write_state.persisted_size = self.persisted_size write_state.write_handle = self.write_handle + write_state.routing_token = None - print("Sending requests in send_and_recv_generator") - # req_iter = iter(requests) + write_state.user_buffer.seek(write_state.persisted_size) + write_state.bytes_sent = write_state.persisted_size + write_state.bytes_since_last_flush = 0 - print("Starting to send requests") + requests = strategy.generate_requests(state) + + num_requests = len(requests) for i, chunk_req in enumerate(requests): - if i == len(requests) - 1: + if i == num_requests - 1: chunk_req.state_lookup = True - print("Sending chunk request") + chunk_req.flush = True await self.write_obj_stream.send(chunk_req) - print("Waiting to receive response") - print("Current persisted_size:", state["write_state"].persisted_size, "Target persisted_size:", target_persisted_size) resp = await self.write_obj_stream.recv() if resp: @@ -437,38 +451,28 @@ async def generator(): if resp.write_handle: self.write_handle = resp.write_handle state["write_state"].write_handle = resp.write_handle - print("Received response in send_and_recv_generator", resp) yield resp - # while state["write_state"].persisted_size < target_persisted_size: - # print("Waiting to receive response") - # print("Current persisted_size:", state["write_state"].persisted_size, "Target persisted_size:", target_persisted_size) - # resp = await self.write_obj_stream.recv() - # print("Received response in send_and_recv_generator", resp) - # if resp is None: - # break - # yield resp return generator() # State initialization - spec = _storage_v2.AppendObjectSpec( - bucket=f"projects/_/buckets/{self.bucket_name}", object=self.object_name, generation=self.generation - ) - write_state = _WriteState(spec, _MAX_CHUNK_SIZE_BYTES, buffer) + write_state = _WriteState(_MAX_CHUNK_SIZE_BYTES, buffer, self.flush_interval) write_state.write_handle = self.write_handle write_state.persisted_size = self.persisted_size write_state.bytes_sent = self.persisted_size + write_state.bytes_since_last_flush = self.bytes_appended_since_last_flush - print("Before creating retry manager") - retry_manager = _BidiStreamRetryManager(_WriteResumptionStrategy(), - lambda r, s: send_and_recv_generator(r, s, metadata)) + retry_manager = _BidiStreamRetryManager( + _WriteResumptionStrategy(), + lambda r, s: send_and_recv_generator(r, s, metadata), + ) await retry_manager.execute({"write_state": write_state}, retry_policy) # Sync local markers self.write_obj_stream.persisted_size = write_state.persisted_size self.write_obj_stream.write_handle = write_state.write_handle - + self.bytes_appended_since_last_flush = write_state.bytes_since_last_flush async def simple_flush(self) -> None: """Flushes the data to the server. @@ -561,16 +565,10 @@ async def finalize(self) -> _storage_v2.Object: if not self._is_stream_open: raise ValueError("Stream is not open. Call open() before finalize().") - print("In finalize method") - - # async with self._lock: - print("Sending finish_write request") await self.write_obj_stream.send( _storage_v2.BidiWriteObjectRequest(finish_write=True) ) - print("Waiting to receive response for finalize") response = await self.write_obj_stream.recv() - print("Received response for finalize:") self.object_resource = response.resource self.persisted_size = self.object_resource.size await self.write_obj_stream.close() diff --git a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py index 62c230f69..45a4cf072 100644 --- a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py @@ -62,6 +62,7 @@ def __init__( object_name: str, generation_number: Optional[int] = None, # None means new object write_handle: Optional[bytes] = None, + routing_token: Optional[str] = None, ) -> None: if client is None: raise ValueError("client must be provided") @@ -77,6 +78,7 @@ def __init__( ) self.client: AsyncGrpcClient.grpc_client = client self.write_handle: Optional[bytes] = write_handle + self.routing_token: Optional[str] = routing_token self._full_bucket_name = f"projects/_/buckets/{self.bucket_name}" @@ -119,6 +121,7 @@ async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: object=self.object_name, generation=self.generation_number, write_handle=write_handle, + routing_token=self.routing_token if self.routing_token else None, ), ) @@ -134,37 +137,44 @@ async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: current_metadata = other_metadata current_metadata.append(("x-goog-request-params", ",".join(request_params))) - print("Before sending first_bidi_write_req in open:", self.first_bidi_write_req) - self.socket_like_rpc = AsyncBidiRpc( - self.rpc, initial_request=self.first_bidi_write_req, metadata=current_metadata + self.rpc, + initial_request=self.first_bidi_write_req, + metadata=current_metadata, ) await self.socket_like_rpc.open() # this is actually 1 send response = await self.socket_like_rpc.recv() - print("Received response on open") self._is_stream_open = True - if not response.resource: - raise ValueError( - "Failed to obtain object resource after opening the stream" - ) - if not response.resource.generation: - raise ValueError( - "Failed to obtain object generation after opening the stream" - ) - - if not response.write_handle: - raise ValueError("Failed to obtain write_handle after opening the stream") + if response.persisted_size >= 0: + self.persisted_size = response.persisted_size - if not response.resource.size: - # Appending to a 0 byte appendable object. - self.persisted_size = 0 - else: - self.persisted_size = response.resource.size + if response.write_handle: + self.write_handle = response.write_handle + # return + + # if not response.resource: + # raise ValueError( + # "Failed to obtain object resource after opening the stream" + # ) + # if not response.resource.generation: + # raise ValueError( + # "Failed to obtain object generation after opening the stream" + # ) + + # if not response.write_handle: + # raise ValueError("Failed to obtain write_handle after opening the stream") + + if response.resource: + if not response.resource.size: + # Appending to a 0 byte appendable object. + self.persisted_size = 0 + else: + self.persisted_size = response.resource.size - self.generation_number = response.resource.generation - self.write_handle = response.write_handle + self.generation_number = response.resource.generation + self.write_handle = response.write_handle async def close(self) -> None: """Closes the bidi-gRPC connection.""" diff --git a/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py b/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py index 09cb16850..a8caae4eb 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py +++ b/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py @@ -49,14 +49,10 @@ async def execute(self, initial_state: Any, retry_policy): """ state = initial_state - print("Starting retry manager execute") - async def attempt(): requests = self._strategy.generate_requests(state) - print("Generated requests:", len(requests)) stream = self._send_and_recv(requests, state) try: - print("Starting to receive responses in attempt") async for response in stream: self._strategy.update_state_from_response(response, state) return diff --git a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py index eb1d32fc4..7a2a84d16 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, IO, Iterable, List, Optional, Union +from typing import Any, Dict, IO, List, Optional, Union import google_crc32c from google.api_core import exceptions @@ -31,27 +31,28 @@ class _WriteState: """A helper class to track the state of a single upload operation. - :type spec: :class:`google.cloud.storage_v2.types.AppendObjectSpec` - :param spec: The specification for the object to write. - :type chunk_size: int :param chunk_size: The size of chunks to write to the server. :type user_buffer: IO[bytes] :param user_buffer: The data source. + + :type flush_interval: Optional[int] + :param flush_interval: The flush interval at which the data is flushed. """ def __init__( self, - spec: Union[storage_type.AppendObjectSpec, storage_type.WriteObjectSpec], chunk_size: int, user_buffer: IO[bytes], + flush_interval: Optional[int] = None, ): - self.spec = spec self.chunk_size = chunk_size self.user_buffer = user_buffer self.persisted_size: int = 0 self.bytes_sent: int = 0 + self.bytes_since_last_flush: int = 0 + self.flush_interval: Optional[int] = flush_interval self.write_handle: Union[bytes, storage_type.BidiWriteHandle, None] = None self.routing_token: Optional[str] = None self.is_finalized: bool = False @@ -65,15 +66,10 @@ def generate_requests( ) -> List[storage_type.BidiWriteObjectRequest]: """Generates BidiWriteObjectRequests to resume or continue the upload. - For Appendable Objects, every stream opening should send an - AppendObjectSpec. If resuming, the `write_handle` is added to that spec. - This method is not applicable for `open` methods. """ write_state: _WriteState = state["write_state"] - print("Generating requests from write state:", write_state) - requests = [] # The buffer should already be seeked to the correct position (persisted_size) # by the `recover_state_on_failure` method before this is called. @@ -82,7 +78,6 @@ def generate_requests( # End of File detection if not chunk: - print("No more data to read; ending request generation") break checksummed_data = storage_type.ChecksummedData(content=chunk) @@ -93,9 +88,18 @@ def generate_requests( write_offset=write_state.bytes_sent, checksummed_data=checksummed_data, ) - write_state.bytes_sent += len(chunk) + chunk_len = len(chunk) + write_state.bytes_sent += chunk_len + write_state.bytes_since_last_flush += chunk_len + + if ( + write_state.flush_interval + and write_state.bytes_since_last_flush >= write_state.flush_interval + ): + request.flush = True + # reset counter after marking flush + write_state.bytes_since_last_flush = 0 - print("Yielding request", len(request.checksummed_data.content)) requests.append(request) return requests @@ -127,18 +131,50 @@ async def recover_state_on_failure( last confirmed 'persisted_size' from the server. """ write_state: _WriteState = state["write_state"] - cause = getattr(error, "cause", error) - - # Extract routing token and potentially a new write handle for redirection. - if isinstance(cause, BidiWriteObjectRedirectedError): - if cause.routing_token: - write_state.routing_token = cause.routing_token - redirect_handle = getattr(cause, "write_handle", None) - if redirect_handle: - write_state.write_handle = redirect_handle + grpc_error = None + if isinstance(error, exceptions.Aborted) and error.errors: + grpc_error = error.errors[0] + + if grpc_error: + # Extract routing token and potentially a new write handle for redirection. + if isinstance(grpc_error, BidiWriteObjectRedirectedError): + self._routing_token = grpc_error.routing_token + if grpc_error.write_handle: + self.write_handle = grpc_error.write_handle + return + if hasattr(grpc_error, "trailing_metadata"): + trailers = grpc_error.trailing_metadata() + if not trailers: + return + + status_details_bin = None + for key, value in trailers: + if key == "grpc-status-details-bin": + status_details_bin = value + break + + if status_details_bin: + status_proto = status_pb2.Status() + try: + status_proto.ParseFromString(status_details_bin) + for detail in status_proto.details: + if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL: + redirect_proto = ( + BidiWriteObjectRedirectedError.deserialize( + detail.value + ) + ) + if redirect_proto.routing_token: + write_state._routing_token = redirect_proto.routing_token + if redirect_proto.write_handle: + write_state.write_handle = redirect_proto.write_handle + break + except Exception: + pass # We must assume any data sent beyond 'persisted_size' was lost. # Reset the user buffer to the last known good byte confirmed by the server. write_state.user_buffer.seek(write_state.persisted_size) write_state.bytes_sent = write_state.persisted_size + write_state.bytes_since_last_flush = 0 diff --git a/tests/conformance/test_bidi_writes.py b/tests/conformance/test_bidi_writes.py index cc03e5a08..4adfd266a 100644 --- a/tests/conformance/test_bidi_writes.py +++ b/tests/conformance/test_bidi_writes.py @@ -7,6 +7,7 @@ from google.auth import credentials as auth_credentials from google.cloud import _storage_v2 as storage_v2 +from google.api_core.retry_async import AsyncRetry from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( AsyncAppendableObjectWriter, ) @@ -18,6 +19,19 @@ CONTENT = b"A" * 1024 * 10 # 10 KB +def _is_retryable(exc): + return isinstance( + exc, + ( + exceptions.InternalServerError, + exceptions.ServiceUnavailable, + exceptions.DeadlineExceeded, + exceptions.TooManyRequests, + exceptions.Aborted, # For Redirects + ), + ) + + async def run_test_scenario( gapic_client, http_client, @@ -27,6 +41,22 @@ async def run_test_scenario( ): """Runs a single fault-injection test scenario.""" print(f"\n--- RUNNING SCENARIO: {scenario['name']} ---") + retry_count = 0 + + def on_retry_error(exc): + nonlocal retry_count + retry_count += 1 + print(f"Retry attempt {retry_count} triggered by: {type(exc).__name__}") + + custom_retry = AsyncRetry( + predicate=_is_retryable, + on_error=on_retry_error, + initial=0.1, # Short backoff for fast tests + multiplier=1.0, + ) + + use_default = scenario.get("use_default_policy", False) + policy_to_pass = None if use_default else custom_retry retry_test_id = None try: @@ -49,10 +79,12 @@ async def run_test_scenario( # 3. Execute the write and assert the outcome. try: - print("Before calling open()") - await writer.open(metadata=fault_injection_metadata) - print("After calling open()") - await writer.append(CONTENT, metadata=fault_injection_metadata) + await writer.open( + metadata=fault_injection_metadata, retry_policy=policy_to_pass + ) + await writer.append( + CONTENT, metadata=fault_injection_metadata, retry_policy=policy_to_pass + ) await writer.finalize() await writer.close() @@ -72,13 +104,28 @@ async def run_test_scenario( async for chunk in read_stream: data += chunk.checksummed_data.content assert data == CONTENT + if scenario["expected_error"] is None: + # Scenarios like 503, 500, smarter resumption, and redirects + # SHOULD trigger at least one retry attempt. + if not use_default: + assert ( + retry_count > 0 + ), f"Test passed but no retry was actually triggered for {scenario['name']}!" + else: + print("Successfully recovered using library's default policy.") + print(f"Success: {scenario['name']}") except Exception as e: if scenario["expected_error"] is None or not isinstance( e, scenario["expected_error"] ): raise - print(f"Caught expected exception for {scenario['name']}: {e}") + + if not use_default: + assert ( + retry_count == 0 + ), f"Retry was incorrectly triggered for non-retriable error in {scenario['name']}!" + print(f"Success: caught expected exception for {scenario['name']}: {e}") finally: # 5. Clean up the Retry Test resource. @@ -102,34 +149,34 @@ async def main(): # Define all test scenarios test_scenarios = [ - # { - # "name": "Retry on Service Unavailable (503)", - # "method": "storage.objects.insert", - # "instruction": "return-503", - # "expected_error": None, - # }, - # { - # "name": "Retry on 500", - # "method": "storage.objects.insert", - # "instruction": "return-500", - # "expected_error": None, - # }, - # { - # "name": "Retry on 504", - # "method": "storage.objects.insert", - # "instruction": "return-504", - # "expected_error": None, - # }, - # { - # "name": "Retry on 429", - # "method": "storage.objects.insert", - # "instruction": "return-429", - # "expected_error": None, - # }, + { + "name": "Retry on Service Unavailable (503)", + "method": "storage.objects.insert", + "instruction": "return-503", + "expected_error": None, + }, + { + "name": "Retry on 500", + "method": "storage.objects.insert", + "instruction": "return-500", + "expected_error": None, + }, + { + "name": "Retry on 504", + "method": "storage.objects.insert", + "instruction": "return-504", + "expected_error": None, + }, + { + "name": "Retry on 429", + "method": "storage.objects.insert", + "instruction": "return-429", + "expected_error": None, + }, { "name": "Smarter Resumption: Retry 503 after partial data", "method": "storage.objects.insert", - "instruction": "return-broken-stream-after-2K", + "instruction": "return-503-after-2K", "expected_error": None, }, { @@ -144,6 +191,34 @@ async def main(): "instruction": "return-401", "expected_error": exceptions.Unauthorized, }, + { + "name": "Default Policy: Retry on 503", + "method": "storage.objects.insert", + "instruction": "return-503", + "expected_error": None, + "use_default_policy": True, + }, + { + "name": "Default Policy: Retry on 503", + "method": "storage.objects.insert", + "instruction": "return-500", + "expected_error": None, + "use_default_policy": True, + }, + { + "name": "Default Policy: Retry on BidiWriteObjectRedirectedError", + "method": "storage.objects.insert", + "instruction": "redirect-send-handle-and-token-tokenval", + "expected_error": None, + "use_default_policy": True, + }, + { + "name": "Default Policy: Smarter Ressumption", + "method": "storage.objects.insert", + "instruction": "return-503-after-2K", + "expected_error": None, + "use_default_policy": True, + }, ] try: @@ -163,37 +238,6 @@ async def main(): scenario, ) - # Define and run test scenarios specifically for the open() method - # open_test_scenarios = [ - # { - # "name": "Open: Retry on 503", - # "method": "storage.objects.insert", - # "instruction": "return-503", - # "expected_error": None, - # }, - # { - # "name": "Open: Retry on BidiWriteObjectRedirectedError", - # "method": "storage.objects.insert", - # "instruction": "redirect-send-handle-and-token-tokenval", - # "expected_error": None, - # }, - # { - # "name": "Open: Fail Fast on 401", - # "method": "storage.objects.insert", - # "instruction": "return-401", - # "expected_error": exceptions.Unauthorized, - # }, - # ] - # for i, scenario in enumerate(open_test_scenarios): - # object_name = f"{object_name_prefix}-open-{i}" - # await run_open_test_scenario( - # gapic_client, - # http_client, - # bucket_name, - # object_name, - # scenario, - # ) - except Exception: import traceback @@ -219,63 +263,5 @@ async def main(): print(f"Warning: Cleanup failed: {e}") -async def run_open_test_scenario( - gapic_client, - http_client, - bucket_name, - object_name, - scenario, -): - """Runs a fault-injection test scenario specifically for the open() method.""" - print(f"\n--- RUNNING SCENARIO: {scenario['name']} ---") - - retry_test_id = None - try: - # 1. Create a Retry Test resource on the testbench. - retry_test_config = { - "instructions": {scenario["method"]: [scenario["instruction"]]}, - "transport": "GRPC", - } - resp = http_client.post(f"{HTTP_ENDPOINT}/retry_test", json=retry_test_config) - resp.raise_for_status() - retry_test_id = resp.json()["id"] - print(f"Retry Test created with ID: {retry_test_id}") - - # 2. Set up metadata for fault injection. - fault_injection_metadata = (("x-retry-test-id", retry_test_id),) - - # 3. Execute the open and assert the outcome. - try: - writer = AsyncAppendableObjectWriter( - gapic_client, - bucket_name, - object_name, - ) - await writer.open(metadata=fault_injection_metadata) - - # If open was successful, perform a simple write to ensure the stream is usable. - await writer.append(CONTENT) - await writer.finalize() - await writer.close() - - # If an exception was expected, this line should not be reached. - if scenario["expected_error"] is not None: - raise AssertionError( - f"Expected exception {scenario['expected_error']} was not raised." - ) - - except Exception as e: - if scenario["expected_error"] is None or not isinstance( - e, scenario["expected_error"] - ): - raise - print(f"Caught expected exception for {scenario['name']}: {e}") - - finally: - # 4. Clean up the Retry Test resource. - if retry_test_id: - http_client.delete(f"{HTTP_ENDPOINT}/retry_test/{retry_test_id}") - - if __name__ == "__main__": asyncio.run(main()) diff --git a/tests/unit/asyncio/retry/test_writes_resumption_strategy.py b/tests/unit/asyncio/retry/test_writes_resumption_strategy.py index 7d8b7934e..556920eba 100644 --- a/tests/unit/asyncio/retry/test_writes_resumption_strategy.py +++ b/tests/unit/asyncio/retry/test_writes_resumption_strategy.py @@ -19,6 +19,8 @@ import pytest import google_crc32c +from google.rpc import status_pb2 +from google.api_core import exceptions from google.cloud._storage_v2.types import storage as storage_type from google.cloud.storage._experimental.asyncio.retry.writes_resumption_strategy import ( @@ -29,135 +31,184 @@ class TestWriteResumptionStrategy(unittest.TestCase): - def _get_target_class(self): - return _WriteResumptionStrategy + def _make_one(self): + return _WriteResumptionStrategy() - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) + # ------------------------------------------------------------------------- + # Tests for generate_requests + # ------------------------------------------------------------------------- - def test_ctor(self): + def test_generate_requests_initial_chunking(self): + """Verify initial data generation starts at offset 0 and chunks correctly.""" strategy = self._make_one() - self.assertIsInstance(strategy, self._get_target_class()) + mock_buffer = io.BytesIO(b"abcdefghij") + write_state = _WriteState(chunk_size=3, user_buffer=mock_buffer) + state = {"write_state": write_state} - def test_generate_requests_initial_new_object(self): - """ - Verify the initial request sequence for a new upload (WriteObjectSpec). - """ - strategy = self._make_one() - mock_buffer = io.BytesIO(b"0123456789") - # Use WriteObjectSpec for new objects - mock_spec = storage_type.WriteObjectSpec( - resource=storage_type.Object(name="test-object") - ) - state = { - "write_state": _WriteState( - mock_spec, chunk_size=4, user_buffer=mock_buffer - ), - } + requests = strategy.generate_requests(state) - requests = list(strategy.generate_requests(state)) + # Expected: 4 requests (3, 3, 3, 1) + self.assertEqual(len(requests), 4) - # Check first request (Spec) - self.assertEqual(requests[0].write_object_spec, mock_spec) - self.assertFalse(requests[0].state_lookup) + # Verify Request 1 + self.assertEqual(requests[0].write_offset, 0) + self.assertEqual(requests[0].checksummed_data.content, b"abc") - # Check data chunks - self.assertEqual(requests[1].write_offset, 0) - self.assertEqual(requests[1].checksummed_data.content, b"0123") - self.assertEqual(requests[2].write_offset, 4) - self.assertEqual(requests[2].checksummed_data.content, b"4567") - self.assertEqual(requests[3].write_offset, 8) - self.assertEqual(requests[3].checksummed_data.content, b"89") + # Verify Request 2 + self.assertEqual(requests[1].write_offset, 3) + self.assertEqual(requests[1].checksummed_data.content, b"def") - # Total requests: 1 Spec + 3 Chunks - self.assertEqual(len(requests), 4) + # Verify Request 3 + self.assertEqual(requests[2].write_offset, 6) + self.assertEqual(requests[2].checksummed_data.content, b"ghi") - def test_generate_requests_initial_existing_object(self): + # Verify Request 4 + self.assertEqual(requests[3].write_offset, 9) + self.assertEqual(requests[3].checksummed_data.content, b"j") + + def test_generate_requests_resumption(self): """ - Verify the initial request sequence for appending to an existing object (AppendObjectSpec). + Verify request generation when resuming. + The strategy should generate chunks starting from the current 'bytes_sent'. """ strategy = self._make_one() - mock_buffer = io.BytesIO(b"0123") - # Use AppendObjectSpec for existing objects - mock_spec = storage_type.AppendObjectSpec( - object_="test-object", bucket="test-bucket" - ) - state = { - "write_state": _WriteState( - mock_spec, chunk_size=4, user_buffer=mock_buffer - ), - } + mock_buffer = io.BytesIO(b"0123456789") + write_state = _WriteState(chunk_size=4, user_buffer=mock_buffer) - requests = list(strategy.generate_requests(state)) + # Simulate resumption state: 4 bytes already sent/persisted + write_state.persisted_size = 4 + write_state.bytes_sent = 4 + # Buffer must be seeked to 4 before calling generate + mock_buffer.seek(4) + + state = {"write_state": write_state} + + requests = strategy.generate_requests(state) + + # Since 4 bytes are done, we expect remaining 6 bytes: [4 bytes, 2 bytes] + self.assertEqual(len(requests), 2) - # Check first request (Spec) - self.assertEqual(requests[0].append_object_spec, mock_spec) - self.assertFalse(requests[0].state_lookup) + # Check first generated request starts at offset 4 + self.assertEqual(requests[0].write_offset, 4) + self.assertEqual(requests[0].checksummed_data.content, b"4567") - # Check data chunk - self.assertEqual(requests[1].write_offset, 0) - self.assertEqual(requests[1].checksummed_data.content, b"0123") + # Check second generated request starts at offset 8 + self.assertEqual(requests[1].write_offset, 8) + self.assertEqual(requests[1].checksummed_data.content, b"89") def test_generate_requests_empty_file(self): - """ - Verify request sequence for an empty file. Should just be the Spec. - """ + """Verify request sequence for an empty file.""" strategy = self._make_one() mock_buffer = io.BytesIO(b"") - mock_spec = storage_type.AppendObjectSpec(object_="test-object") - state = { - "write_state": _WriteState( - mock_spec, chunk_size=4, user_buffer=mock_buffer - ), - } + write_state = _WriteState(chunk_size=4, user_buffer=mock_buffer) + state = {"write_state": write_state} - requests = list(strategy.generate_requests(state)) + requests = strategy.generate_requests(state) - self.assertEqual(len(requests), 1) - self.assertEqual(requests[0].append_object_spec, mock_spec) + self.assertEqual(len(requests), 0) - def test_generate_requests_resumption(self): - """ - Verify request sequence when resuming an upload. - """ + def test_generate_requests_checksum_verification(self): + """Verify CRC32C is calculated correctly for each chunk.""" strategy = self._make_one() - mock_buffer = io.BytesIO(b"0123456789") - mock_spec = storage_type.AppendObjectSpec(object_="test-object") + chunk_data = b"test_data" + mock_buffer = io.BytesIO(chunk_data) + write_state = _WriteState( + chunk_size=10, user_buffer=mock_buffer + ) + state = {"write_state": write_state} - write_state = _WriteState(mock_spec, chunk_size=4, user_buffer=mock_buffer) - write_state.persisted_size = 4 - write_state.bytes_sent = 4 - write_state.write_handle = storage_type.BidiWriteHandle(handle=b"test-handle") - mock_buffer.seek(4) + requests = strategy.generate_requests(state) + expected_crc = google_crc32c.Checksum(chunk_data).digest() + expected_int = int.from_bytes(expected_crc, "big") + self.assertEqual(requests[0].checksummed_data.crc32c, expected_int) + + def test_generate_requests_flush_logic_exact_interval(self): + """Verify the flush bit is set exactly when the interval is reached.""" + strategy = self._make_one() + mock_buffer = io.BytesIO(b"A" * 12) + # 2 byte chunks, flush every 4 bytes + write_state = _WriteState( + chunk_size=2, + user_buffer=mock_buffer, + flush_interval=4 + ) state = {"write_state": write_state} - requests = list(strategy.generate_requests(state)) + requests = strategy.generate_requests(state) + + # Request index 1 (4 bytes total) should have flush=True + self.assertFalse(requests[0].flush) + self.assertTrue(requests[1].flush) + + # Request index 3 (8 bytes total) should have flush=True + self.assertFalse(requests[2].flush) + self.assertTrue(requests[3].flush) - # Check first request has handle and lookup - self.assertEqual( - requests[0].append_object_spec.write_handle.handle, b"test-handle" + # Verify counter reset in state + self.assertEqual(write_state.bytes_since_last_flush, 0) + + def test_generate_requests_flush_logic_none_interval(self): + """Verify flush is never set if interval is None.""" + strategy = self._make_one() + mock_buffer = io.BytesIO(b"A" * 10) + write_state = _WriteState( + chunk_size=2, + user_buffer=mock_buffer, + flush_interval=None ) + state = {"write_state": write_state} + + requests = strategy.generate_requests(state) - # Check data starts from offset 4 - self.assertEqual(requests[1].write_offset, 4) - self.assertEqual(requests[1].checksummed_data.content, b"4567") - self.assertEqual(requests[2].write_offset, 8) - self.assertEqual(requests[2].checksummed_data.content, b"89") + for req in requests: + self.assertFalse(req.flush) + + def test_generate_requests_flush_logic_data_less_than_interval(self): + """Verify flush is not set if data sent is less than interval.""" + strategy = self._make_one() + mock_buffer = io.BytesIO(b"A" * 5) + # Flush every 10 bytes + write_state = _WriteState( + chunk_size=2, + user_buffer=mock_buffer, + flush_interval=10 + ) + state = {"write_state": write_state} + + requests = strategy.generate_requests(state) + + # Total 5 bytes < 10 bytes interval + for req in requests: + self.assertFalse(req.flush) + + self.assertEqual(write_state.bytes_since_last_flush, 5) + + def test_generate_requests_honors_finalized_state(self): + """If state is already finalized, no requests should be generated.""" + strategy = self._make_one() + mock_buffer = io.BytesIO(b"data") + write_state = _WriteState( + chunk_size=4, user_buffer=mock_buffer + ) + write_state.is_finalized = True + state = {"write_state": write_state} + + requests = strategy.generate_requests(state) + self.assertEqual(len(requests), 0) @pytest.mark.asyncio async def test_generate_requests_after_failure_and_recovery(self): """ - Verify recovery and resumption flow. + Verify recovery and resumption flow (Integration of recover + generate). """ strategy = self._make_one() - mock_buffer = io.BytesIO(b"0123456789abcdef") + mock_buffer = io.BytesIO(b"0123456789abcdef") # 16 bytes mock_spec = storage_type.AppendObjectSpec(object_="test-object") - state = { - "write_state": _WriteState(mock_spec, chunk_size=4, user_buffer=mock_buffer) - } - write_state = state["write_state"] + write_state = _WriteState(mock_spec, chunk_size=4, user_buffer=mock_buffer) + state = {"write_state": write_state} + # Simulate initial progress: sent 8 bytes write_state.bytes_sent = 8 mock_buffer.seek(8) @@ -169,122 +220,168 @@ async def test_generate_requests_after_failure_and_recovery(self): state, ) + # Simulate Failure Triggering Recovery await strategy.recover_state_on_failure(Exception("network error"), state) + # Assertions after recovery + # 1. Buffer should rewind to persisted_size (4) self.assertEqual(mock_buffer.tell(), 4) + # 2. bytes_sent should track persisted_size (4) self.assertEqual(write_state.bytes_sent, 4) - requests = list(strategy.generate_requests(state)) + requests = strategy.generate_requests(state) - self.assertTrue(requests[0].state_lookup) - self.assertEqual( - requests[0].append_object_spec.write_handle.handle, b"handle-1" - ) + # Remaining data from offset 4 to 16 (12 bytes total) + # Chunks: [4-8], [8-12], [12-16] + self.assertEqual(len(requests), 3) + + # Verify resumption offset + self.assertEqual(requests[0].write_offset, 4) + self.assertEqual(requests[0].checksummed_data.content, b"4567") - self.assertEqual(requests[1].write_offset, 4) - self.assertEqual(requests[1].checksummed_data.content, b"4567") + # ------------------------------------------------------------------------- + # Tests for update_state_from_response + # ------------------------------------------------------------------------- - def test_update_state_from_response(self): - """Verify state updates from server responses.""" + def test_update_state_from_response_all_fields(self): + """Verify all fields from a BidiWriteObjectResponse update the state.""" strategy = self._make_one() - mock_buffer = io.BytesIO(b"0123456789") - mock_spec = storage_type.AppendObjectSpec(object_="test-object") - state = { - "write_state": _WriteState( - mock_spec, chunk_size=4, user_buffer=mock_buffer - ), - } - write_state = state["write_state"] + write_state = _WriteState( + chunk_size=4, user_buffer=io.BytesIO() + ) + state = {"write_state": write_state} - response1 = storage_type.BidiWriteObjectResponse( - write_handle=storage_type.BidiWriteHandle(handle=b"handle-1") + # 1. Update persisted_size + strategy.update_state_from_response( + storage_type.BidiWriteObjectResponse(persisted_size=123), state ) - strategy.update_state_from_response(response1, state) - self.assertEqual(write_state.write_handle.handle, b"handle-1") + self.assertEqual(write_state.persisted_size, 123) - response2 = storage_type.BidiWriteObjectResponse(persisted_size=1024) - strategy.update_state_from_response(response2, state) - self.assertEqual(write_state.persisted_size, 1024) + # 2. Update write_handle + handle = storage_type.BidiWriteHandle(handle=b"new-handle") + strategy.update_state_from_response( + storage_type.BidiWriteObjectResponse(write_handle=handle), state + ) + self.assertEqual(write_state.write_handle, handle) - final_resource = storage_type.Object( - name="test-object", bucket="b", size=2048, finalize_time=datetime.now() + # 3. Update from Resource (finalization) + resource = storage_type.Object(size=1000, finalize_time=datetime.now()) + strategy.update_state_from_response( + storage_type.BidiWriteObjectResponse(resource=resource), state ) - response3 = storage_type.BidiWriteObjectResponse(resource=final_resource) - strategy.update_state_from_response(response3, state) - self.assertEqual(write_state.persisted_size, 2048) + self.assertEqual(write_state.persisted_size, 1000) self.assertTrue(write_state.is_finalized) - @pytest.mark.asyncio - async def test_recover_state_on_failure_handles_redirect(self): - """ - Verify redirection error handling. - """ + def test_update_state_from_response_none(self): + """Verify None response doesn't crash.""" strategy = self._make_one() - mock_buffer = mock.MagicMock(spec=io.BytesIO) - mock_spec = storage_type.AppendObjectSpec(object_="test-object") - - write_state = _WriteState(mock_spec, chunk_size=4, user_buffer=mock_buffer) + write_state = _WriteState( + chunk_size=4, user_buffer=io.BytesIO() + ) state = {"write_state": write_state} + strategy.update_state_from_response(None, state) + self.assertEqual(write_state.persisted_size, 0) - redirect_error = BidiWriteObjectRedirectedError(routing_token="new-token-123") - wrapped_error = Exception("RPC error") - wrapped_error.cause = redirect_error - - await strategy.recover_state_on_failure(wrapped_error, state) - - self.assertEqual(write_state.routing_token, "new-token-123") - mock_buffer.seek.assert_called_with(0) + # ------------------------------------------------------------------------- + # Tests for recover_state_on_failure + # ------------------------------------------------------------------------- @pytest.mark.asyncio - async def test_recover_state_on_failure_handles_redirect_with_handle(self): - """Verify redirection that includes a write handle.""" + async def test_recover_state_on_failure_rewind_logic(self): + """Verify buffer seek and counter resets on generic failure (Non-redirect).""" strategy = self._make_one() - mock_buffer = mock.MagicMock(spec=io.BytesIO) - mock_spec = storage_type.AppendObjectSpec(object_="test-object") - - write_state = _WriteState(mock_spec, chunk_size=4, user_buffer=mock_buffer) - state = {"write_state": write_state} - - redirect_error = BidiWriteObjectRedirectedError( - routing_token="new-token-456", write_handle=b"redirect-handle" + mock_buffer = io.BytesIO(b"0123456789") + write_state = _WriteState( + chunk_size=2, user_buffer=mock_buffer ) - wrapped_error = Exception("RPC error") - wrapped_error.cause = redirect_error - await strategy.recover_state_on_failure(wrapped_error, state) + # Simulate progress: sent 8 bytes, but server only persisted 4 + write_state.bytes_sent = 8 + write_state.persisted_size = 4 + write_state.bytes_since_last_flush = 2 + mock_buffer.seek(8) - self.assertEqual(write_state.routing_token, "new-token-456") - self.assertEqual(write_state.write_handle, b"redirect-handle") + # Simulate generic 503 error without trailers + await strategy.recover_state_on_failure(exceptions.ServiceUnavailable("busy"), {"write_state": write_state}) - mock_buffer.seek.assert_called_with(0) + # Buffer must be seeked back to 4 + self.assertEqual(mock_buffer.tell(), 4) + self.assertEqual(write_state.bytes_sent, 4) + # Flush counter must be reset to avoid incorrect firing after resume + self.assertEqual(write_state.bytes_since_last_flush, 0) - def test_generate_requests_sends_crc32c_checksum(self): + @pytest.mark.asyncio + async def test_recover_state_on_failure_direct_redirect(self): + """Verify handling when the error is a BidiWriteObjectRedirectedError.""" strategy = self._make_one() - mock_buffer = io.BytesIO(b"0123") - mock_spec = storage_type.AppendObjectSpec(object_="test-object") - state = { - "write_state": _WriteState( - mock_spec, chunk_size=4, user_buffer=mock_buffer - ), - } + write_state = _WriteState( + chunk_size=4, user_buffer=io.BytesIO() + ) + state = {"write_state": write_state} - requests = list(strategy.generate_requests(state)) + redirect = BidiWriteObjectRedirectedError(routing_token="tok-1", write_handle=b"h-1") - expected_crc = google_crc32c.Checksum(b"0123") - expected_crc_int = int.from_bytes(expected_crc.digest(), "big") - self.assertEqual(requests[1].checksummed_data.crc32c, expected_crc_int) + await strategy.recover_state_on_failure(redirect, state) - def test_generate_requests_with_routing_token(self): + self.assertEqual(write_state.routing_token, "tok-1") + self.assertEqual(write_state.write_handle, b"h-1") + + @pytest.mark.asyncio + async def test_recover_state_on_failure_wrapped_redirect(self): + """Verify handling when RedirectedError is inside Aborted.errors.""" strategy = self._make_one() - mock_buffer = io.BytesIO(b"") - mock_spec = storage_type.AppendObjectSpec(object_="test-object") + write_state = _WriteState( + chunk_size=4, user_buffer=io.BytesIO() + ) - write_state = _WriteState(mock_spec, chunk_size=4, user_buffer=mock_buffer) - write_state.routing_token = "redirected-token" - state = {"write_state": write_state} + redirect = BidiWriteObjectRedirectedError(routing_token="tok-wrapped") + # google-api-core Aborted often wraps multiple errors + error = exceptions.Aborted("conflict", errors=[redirect]) - requests = list(strategy.generate_requests(state)) + await strategy.recover_state_on_failure(error, {"write_state": write_state}) - self.assertEqual( - requests[0].append_object_spec.routing_token, "redirected-token" + self.assertEqual(write_state.routing_token, "tok-wrapped") + + @pytest.mark.asyncio + async def test_recover_state_on_failure_trailer_metadata_redirect(self): + """Verify complex parsing from 'grpc-status-details-bin' in trailers.""" + strategy = self._make_one() + write_state = _WriteState( + chunk_size=4, user_buffer=io.BytesIO() ) + + # 1. Setup Redirect Proto + redirect_proto = BidiWriteObjectRedirectedError(routing_token="metadata-token") + + # 2. Setup Status Proto Detail + status = status_pb2.Status() + detail = status.details.add() + detail.type_url = "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + # In a real environment, detail.value is the serialized proto + detail.value = BidiWriteObjectRedirectedError.to_json(redirect_proto).encode() + + # 3. Create Mock Error with Trailers + mock_error = mock.MagicMock(spec=exceptions.Aborted) + mock_error.errors = [] # No direct errors + mock_error.trailing_metadata.return_value = [ + ("grpc-status-details-bin", status.SerializeToString()) + ] + + # 4. Patch deserialize to handle the binary value + with mock.patch("google.cloud._storage_v2.types.storage.BidiWriteObjectRedirectedError.deserialize", return_value=redirect_proto): + await strategy.recover_state_on_failure(mock_error, {"write_state": write_state}) + + self.assertEqual(write_state.routing_token, "metadata-token") + + def test_write_state_initialization(self): + """Verify WriteState starts with clean counters.""" + buffer = io.BytesIO(b"test") + ws = _WriteState( + chunk_size=10, user_buffer=buffer, flush_interval=100 + ) + + self.assertEqual(ws.persisted_size, 0) + self.assertEqual(ws.bytes_sent, 0) + self.assertEqual(ws.bytes_since_last_flush, 0) + self.assertEqual(ws.flush_interval, 100) + self.assertFalse(ws.is_finalized) diff --git a/tests/unit/asyncio/test_async_appendable_object_writer.py b/tests/unit/asyncio/test_async_appendable_object_writer.py index c496b5810..ac6716fbf 100644 --- a/tests/unit/asyncio/test_async_appendable_object_writer.py +++ b/tests/unit/asyncio/test_async_appendable_object_writer.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from io import BytesIO +import io +import unittest +import unittest.mock as mock import pytest -from unittest import mock - -from google_crc32c import Checksum - from google.api_core import exceptions +from google.rpc import status_pb2 +from google.cloud._storage_v2.types import storage as storage_type +from google.cloud._storage_v2.types.storage import BidiWriteObjectRedirectedError from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( AsyncAppendableObjectWriter, -) -from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( + _is_write_retryable, _MAX_CHUNK_SIZE_BYTES, _DEFAULT_FLUSH_INTERVAL_BYTES, ) -from google.cloud import _storage_v2 - BUCKET = "test-bucket" OBJECT = "test-object" @@ -37,695 +35,521 @@ EIGHT_MIB = 8 * 1024 * 1024 -@pytest.fixture -def mock_client(): - """Mock the async gRPC client.""" - return mock.AsyncMock() - - -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -def test_init(mock_write_object_stream, mock_client): - """Test the constructor of AsyncAppendableObjectWriter.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - - assert writer.client == mock_client - assert writer.bucket_name == BUCKET - assert writer.object_name == OBJECT - assert writer.generation is None - assert writer.write_handle is None - assert not writer._is_stream_open - assert writer.offset is None - assert writer.persisted_size is None - assert writer.bytes_appended_since_last_flush == 0 - - mock_write_object_stream.assert_called_once_with( - client=mock_client, - bucket_name=BUCKET, - object_name=OBJECT, - generation_number=None, - write_handle=None, - ) - assert writer.write_obj_stream == mock_write_object_stream.return_value - - -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -def test_init_with_optional_args(mock_write_object_stream, mock_client): - """Test the constructor with optional arguments.""" - writer = AsyncAppendableObjectWriter( - mock_client, - BUCKET, - OBJECT, - generation=GENERATION, - write_handle=WRITE_HANDLE, - ) - - assert writer.generation == GENERATION - assert writer.write_handle == WRITE_HANDLE - assert writer.bytes_appended_since_last_flush == 0 - - mock_write_object_stream.assert_called_once_with( - client=mock_client, - bucket_name=BUCKET, - object_name=OBJECT, - generation_number=GENERATION, - write_handle=WRITE_HANDLE, - ) - - -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -def test_init_with_writer_options(mock_write_object_stream, mock_client): - """Test the constructor with optional arguments.""" - writer = AsyncAppendableObjectWriter( - mock_client, - BUCKET, - OBJECT, - writer_options={"FLUSH_INTERVAL_BYTES": EIGHT_MIB}, - ) - - assert writer.flush_interval == EIGHT_MIB - assert writer.bytes_appended_since_last_flush == 0 - - mock_write_object_stream.assert_called_once_with( - client=mock_client, - bucket_name=BUCKET, - object_name=OBJECT, - generation_number=None, - write_handle=None, - ) - - -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -def test_init_with_flush_interval_less_than_chunk_size_raises_error(mock_client): - """Test that an OutOfRange error is raised if flush_interval is less than the chunk size.""" - - with pytest.raises(exceptions.OutOfRange): - AsyncAppendableObjectWriter( - mock_client, - BUCKET, - OBJECT, - writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES - 1}, +class TestIsWriteRetryable(unittest.TestCase): + def test_transient_errors(self): + for exc_type in [ + exceptions.InternalServerError, + exceptions.ServiceUnavailable, + exceptions.DeadlineExceeded, + exceptions.TooManyRequests, + ]: + self.assertTrue(_is_write_retryable(exc_type("error"))) + + def test_aborted_with_redirect_proto(self): + # Direct redirect error wrapped in Aborted + redirect = BidiWriteObjectRedirectedError(routing_token="token") + exc = exceptions.Aborted("aborted", errors=[redirect]) + self.assertTrue(_is_write_retryable(exc)) + + def test_aborted_with_trailers(self): + # Redirect hidden in trailers + status = status_pb2.Status() + detail = status.details.add() + detail.type_url = "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + + # Correctly serialize the proto message to bytes for the detail value + redirect_proto = BidiWriteObjectRedirectedError(routing_token="rt2") + detail.value = BidiWriteObjectRedirectedError.serialize(redirect_proto) + + exc = exceptions.Aborted("aborted") + exc.trailing_metadata = [("grpc-status-details-bin", status.SerializeToString())] + self.assertTrue(_is_write_retryable(exc)) + + def test_non_retryable(self): + self.assertFalse(_is_write_retryable(exceptions.BadRequest("bad"))) + self.assertFalse(_is_write_retryable(exceptions.Aborted("just aborted"))) + + +class TestAsyncAppendableObjectWriter(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.mock_client = mock.AsyncMock() + # Patch the stream class used internally + self.mock_stream_cls = mock.patch( + "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" + ).start() + self.mock_stream = self.mock_stream_cls.return_value + + # Default mock stream state + self.mock_stream.is_stream_open = False + self.mock_stream.persisted_size = 0 + self.mock_stream.generation_number = GENERATION + self.mock_stream.write_handle = WRITE_HANDLE + + def tearDown(self): + mock.patch.stopall() + + def _make_one(self, **kwargs): + return AsyncAppendableObjectWriter( + self.mock_client, BUCKET, OBJECT, **kwargs ) - -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -def test_init_with_flush_interval_not_multiple_of_chunk_size_raises_error(mock_client): - """Test that an OutOfRange error is raised if flush_interval is not a multiple of the chunk size.""" - - with pytest.raises(exceptions.OutOfRange): - AsyncAppendableObjectWriter( - mock_client, - BUCKET, - OBJECT, - writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES + 1}, + # ------------------------------------------------------------------------- + # Initialization Tests + # ------------------------------------------------------------------------- + + def test_init_defaults(self): + writer = self._make_one() + self.assertEqual(writer.client, self.mock_client) + self.assertEqual(writer.bucket_name, BUCKET) + self.assertEqual(writer.object_name, OBJECT) + self.assertIsNone(writer.generation) + self.assertIsNone(writer.write_handle) + self.assertFalse(writer._is_stream_open) + self.assertIsNone(writer.persisted_size) + self.assertEqual(writer.bytes_appended_since_last_flush, 0) + + def test_init_with_optional_args(self): + writer = self._make_one( + generation=GENERATION, + write_handle=WRITE_HANDLE, + ) + self.assertEqual(writer.generation, GENERATION) + self.assertEqual(writer.write_handle, WRITE_HANDLE) + + def test_init_with_writer_options(self): + writer = self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": EIGHT_MIB}) + self.assertEqual(writer.flush_interval, EIGHT_MIB) + + def test_init_validation_chunk_size(self): + with self.assertRaises(exceptions.OutOfRange): + self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES - 1}) + + def test_init_validation_chunk_multiple(self): + with self.assertRaises(exceptions.OutOfRange): + self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES + 1}) + + def test_init_raises_if_crc32c_c_extension_is_missing(self): + with mock.patch("google.cloud.storage._experimental.asyncio._utils.google_crc32c") as mock_crc: + mock_crc.implementation = "python" + with self.assertRaisesRegex(exceptions.FailedPrecondition, "google-crc32c package is not installed"): + self._make_one() + + # ------------------------------------------------------------------------- + # Helper Method Tests + # ------------------------------------------------------------------------- + + async def test_state_lookup(self): + writer = self._make_one() + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + + self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( + persisted_size=PERSISTED_SIZE ) + resp = await writer.state_lookup() -@mock.patch("google.cloud.storage._experimental.asyncio._utils.google_crc32c") -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" -) -def test_init_raises_if_crc32c_c_extension_is_missing( - mock_grpc_client, mock_google_crc32c -): - mock_google_crc32c.implementation = "python" - - with pytest.raises(exceptions.FailedPrecondition) as exc_info: - AsyncAppendableObjectWriter(mock_grpc_client, "bucket", "object") - - assert "The google-crc32c package is not installed with C support" in str( - exc_info.value - ) - + self.mock_stream.send.assert_awaited_once_with( + storage_type.BidiWriteObjectRequest(state_lookup=True) + ) + self.assertEqual(resp, PERSISTED_SIZE) + self.assertEqual(writer.persisted_size, PERSISTED_SIZE) + + async def test_state_lookup_not_open_raises(self): + writer = self._make_one() + with self.assertRaisesRegex(ValueError, "Stream is not open"): + await writer.state_lookup() + + async def test_unimplemented_methods(self): + writer = self._make_one() + with self.assertRaises(NotImplementedError): + await writer.append_from_string("data") + with self.assertRaises(NotImplementedError): + await writer.append_from_stream(mock.Mock()) + + # ------------------------------------------------------------------------- + # Open & Error Handling Tests + # ------------------------------------------------------------------------- + + async def test_open_success(self): + writer = self._make_one() + self.mock_stream.generation_number = GENERATION + self.mock_stream.write_handle = WRITE_HANDLE + self.mock_stream.persisted_size = 0 -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_state_lookup(mock_write_object_stream, mock_client): - """Test state_lookup method.""" - # Arrange - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse(persisted_size=PERSISTED_SIZE) - ) - - expected_request = _storage_v2.BidiWriteObjectRequest(state_lookup=True) - - # Act - response = await writer.state_lookup() - - # Assert - mock_stream.send.assert_awaited_once_with(expected_request) - mock_stream.recv.assert_awaited_once() - assert writer.persisted_size == PERSISTED_SIZE - assert response == PERSISTED_SIZE - - -@pytest.mark.asyncio -async def test_state_lookup_without_open_raises_value_error(mock_client): - """Test that state_lookup raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, - match="Stream is not open. Call open\\(\\) before state_lookup\\(\\).", - ): - await writer.state_lookup() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_open_appendable_object_writer(mock_write_object_stream, mock_client): - """Test the open method.""" - # Arrange - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - mock_stream = mock_write_object_stream.return_value - mock_stream.open = mock.AsyncMock() - - mock_stream.generation_number = GENERATION - mock_stream.write_handle = WRITE_HANDLE - mock_stream.persisted_size = 0 - - # Act - await writer.open() - - # Assert - mock_stream.open.assert_awaited_once() - assert writer._is_stream_open - assert writer.generation == GENERATION - assert writer.write_handle == WRITE_HANDLE - assert writer.persisted_size == 0 - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_open_appendable_object_writer_existing_object( - mock_write_object_stream, mock_client -): - """Test the open method.""" - # Arrange - writer = AsyncAppendableObjectWriter( - mock_client, BUCKET, OBJECT, generation=GENERATION - ) - mock_stream = mock_write_object_stream.return_value - mock_stream.open = mock.AsyncMock() - - mock_stream.generation_number = GENERATION - mock_stream.write_handle = WRITE_HANDLE - mock_stream.persisted_size = PERSISTED_SIZE - - # Act - await writer.open() - - # Assert - mock_stream.open.assert_awaited_once() - assert writer._is_stream_open - assert writer.generation == GENERATION - assert writer.write_handle == WRITE_HANDLE - assert writer.persisted_size == PERSISTED_SIZE - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_open_when_already_open_raises_error( - mock_write_object_stream, mock_client -): - """Test that opening an already open writer raises a ValueError.""" - # Arrange - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True # Manually set to open - - # Act & Assert - with pytest.raises(ValueError, match="Underlying bidi-gRPC stream is already open"): await writer.open() + self.assertTrue(writer._is_stream_open) + self.assertEqual(writer.generation, GENERATION) + self.assertEqual(writer.write_handle, WRITE_HANDLE) + self.assertEqual(writer.persisted_size, 0) + + self.mock_stream_cls.assert_called_with( + client=self.mock_client, + bucket_name=BUCKET, + object_name=OBJECT, + generation_number=None, + write_handle=None, + routing_token=None, + ) + self.mock_stream.open.assert_awaited_once() -@pytest.mark.asyncio -async def test_unimplemented_methods_raise_error(mock_client): - """Test that all currently unimplemented methods raise NotImplementedError.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - - with pytest.raises(NotImplementedError): - await writer.append_from_string("data") + async def test_open_appendable_object_writer_existing_object(self): + # Verify opening with existing generation uses AppendObjectSpec implicitly via stream init + writer = self._make_one(generation=GENERATION, write_handle=WRITE_HANDLE) + self.mock_stream.generation_number = GENERATION + self.mock_stream.write_handle = WRITE_HANDLE + self.mock_stream.persisted_size = PERSISTED_SIZE - with pytest.raises(NotImplementedError): - await writer.append_from_stream(mock.Mock()) + await writer.open() + # Check constructor was called with generation/handle + self.mock_stream_cls.assert_called_with( + client=self.mock_client, + bucket_name=BUCKET, + object_name=OBJECT, + generation_number=GENERATION, + write_handle=WRITE_HANDLE, + routing_token=None, + ) + self.assertEqual(writer.persisted_size, PERSISTED_SIZE) + + async def test_open_with_routing_token_and_metadata(self): + writer = self._make_one() + writer._routing_token = "prev-token" + metadata = [("key", "val")] + + await writer.open(metadata=metadata) + + self.mock_stream_cls.assert_called_with( + client=self.mock_client, + bucket_name=BUCKET, + object_name=OBJECT, + generation_number=None, + write_handle=None, + routing_token="prev-token", + ) + call_kwargs = self.mock_stream.open.call_args[1] + passed_metadata = call_kwargs['metadata'] + self.assertIn(("x-goog-request-params", "routing_token=prev-token"), passed_metadata) + self.assertIsNone(writer._routing_token) + + async def test_open_when_already_open_raises(self): + writer = self._make_one() + writer._is_stream_open = True + with self.assertRaisesRegex(ValueError, "Underlying bidi-gRPC stream is already open"): + await writer.open() + + def test_on_open_error_extraction(self): + writer = self._make_one() + + # 1. Direct Redirect Error + redirect = BidiWriteObjectRedirectedError( + routing_token="rt", + write_handle=storage_type.BidiWriteHandle(handle=b"wh"), + generation=999 + ) + writer._on_open_error(exceptions.Aborted("e", errors=[redirect])) + + self.assertEqual(writer._routing_token, "rt") + self.assertEqual(writer.write_handle.handle, b"wh") + self.assertEqual(writer.generation, 999) + + # 2. Trailer Error + status = status_pb2.Status() + detail = status.details.add() + detail.type_url = "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + detail.value = BidiWriteObjectRedirectedError.serialize( + BidiWriteObjectRedirectedError(routing_token="rt2") + ) -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_flush(mock_write_object_stream, mock_client): - """Test that flush sends the correct request and updates state.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse(persisted_size=1024) - ) - - persisted_size = await writer.flush() - - expected_request = _storage_v2.BidiWriteObjectRequest(flush=True, state_lookup=True) - mock_stream.send.assert_awaited_once_with(expected_request) - mock_stream.recv.assert_awaited_once() - assert writer.persisted_size == 1024 - assert writer.offset == 1024 - assert persisted_size == 1024 - - -@pytest.mark.asyncio -async def test_flush_without_open_raises_value_error(mock_client): - """Test that flush raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Stream is not open. Call open\\(\\) before flush\\(\\)." - ): - await writer.flush() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_simple_flush(mock_write_object_stream, mock_client): - """Test that flush sends the correct request and updates state.""" - # Arrange - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - - # Act - await writer.simple_flush() - - # Assert - mock_stream.send.assert_awaited_once_with( - _storage_v2.BidiWriteObjectRequest(flush=True) - ) - - -@pytest.mark.asyncio -async def test_simple_flush_without_open_raises_value_error(mock_client): - """Test that flush raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, - match="Stream is not open. Call open\\(\\) before simple_flush\\(\\).", - ): - await writer.simple_flush() + exc = exceptions.Aborted("e") + exc.trailing_metadata = [("grpc-status-details-bin", status.SerializeToString())] + + writer._on_open_error(exc) + self.assertEqual(writer._routing_token, "rt2") + + # ------------------------------------------------------------------------- + # Append Tests + # ------------------------------------------------------------------------- + + async def test_append_not_open_raises(self): + writer = self._make_one() + with self.assertRaisesRegex(ValueError, "Stream is not open"): + await writer.append(b"data") + + async def test_append_empty_data_does_nothing(self): + writer = self._make_one() + writer._is_stream_open = True + with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: + await writer.append(b"") + MockManager.assert_not_called() + + async def test_append_propagates_non_retryable_errors(self): + """Verify non-retryable errors bubble up.""" + writer = self._make_one() + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + + with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: + # Simulate RetryManager raising a hard error + MockManager.return_value.execute.side_effect = exceptions.BadRequest("bad") + + with self.assertRaises(exceptions.BadRequest): + await writer.append(b"data") + + async def test_append_basic_flow_integration(self): + """Verify append sets up RetryManager and orchestrates chunks.""" + writer = self._make_one(write_handle=b"h1", generation=1) + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + writer.persisted_size = 0 + + data = b"a" * (_MAX_CHUNK_SIZE_BYTES + 10) # 2MB + 10 bytes + + with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: + mock_manager_instance = MockManager.return_value + + async def mock_execute(state, policy): + generator_factory = MockManager.call_args[0][1] + # Strategy generates chunks (we use dummy ones for integration check) + dummy_requests = [ + storage_type.BidiWriteObjectRequest(write_offset=0), + storage_type.BidiWriteObjectRequest(write_offset=_MAX_CHUNK_SIZE_BYTES) + ] + gen = generator_factory(dummy_requests, state) + + self.mock_stream.recv.side_effect = [ + storage_type.BidiWriteObjectResponse(persisted_size=100, write_handle=b"h2"), + None + ] + async for _ in gen: pass + + mock_manager_instance.execute.side_effect = mock_execute + + await writer.append(data) + + self.assertEqual(writer.persisted_size, 100) + self.assertEqual(writer.write_handle, b"h2") + self.assertEqual(self.mock_stream.send.await_count, 2) + # Last chunk should have state_lookup=True + self.assertTrue(self.mock_stream.send.await_args_list[-1][0][0].state_lookup) + + async def test_append_flushes_when_interval_reached(self): + """Verify generator respects flush flag from strategy.""" + # Flush interval matches 2 chunks + flush_interval = _MAX_CHUNK_SIZE_BYTES * 2 + writer = self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": flush_interval}) + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + + data = b"a" * flush_interval + + with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: + mock_manager_instance = MockManager.return_value + + async def mock_execute(state, policy): + generator_factory = MockManager.call_args[0][1] + + # Simulate strategy identifying a flush point + req_with_flush = storage_type.BidiWriteObjectRequest(flush=True) + gen = generator_factory([req_with_flush], state) + + self.mock_stream.recv.return_value = None + async for _ in gen: pass + + mock_manager_instance.execute.side_effect = mock_execute + await writer.append(data) + + # Verify sent request had flush=True + sent_request = self.mock_stream.send.call_args[0][0] + self.assertTrue(sent_request.flush) + + async def test_append_sequential_calls_update_state(self): + """Test state carry-over between two append calls.""" + writer = self._make_one(write_handle=b"h1", generation=1) + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + writer.persisted_size = 0 + + with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: + # 1. First Append + async def execute_1(state, policy): + # Simulate server acknowledging 100 bytes + state["write_state"].persisted_size = 100 + writer.write_obj_stream.persisted_size = 100 + + MockManager.return_value.execute.side_effect = execute_1 + await writer.append(b"a" * 100) + + self.assertEqual(writer.persisted_size, 100) + + # 2. Second Append + async def execute_2(state, policy): + # Verify state passed to manager starts where we left off + assert state["write_state"].persisted_size == 100 + state["write_state"].persisted_size = 200 + + MockManager.return_value.execute.side_effect = execute_2 + await writer.append(b"b" * 100) + + self.assertEqual(writer.persisted_size, 200) + + async def test_append_recovery_flow(self): + """Test internal generator logic when a retry occurs (Attempt > 1).""" + writer = self._make_one(write_handle=b"h1", generation=1) + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + writer.persisted_size = 0 + + async def mock_aaow_open(metadata=None): + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + writer.persisted_size = 4 # Server says 4 bytes persisted + writer.write_handle = b"h_new" + + with mock.patch.object(writer, "open", side_effect=mock_aaow_open) as mock_writer_open: + with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: + + async def mock_execute(state, policy): + factory = MockManager.call_args[0][1] + + # --- SIMULATE ATTEMPT 1 (Fail) --- + gen1 = factory([], state) + try: await gen1.__anext__() + except: pass + + # --- SIMULATE ATTEMPT 2 (Recovery) --- + # Logic should: close old stream, open new, rewind buffer, generate new requests + gen2 = factory([], state) + self.mock_stream.is_stream_open = True + self.mock_stream.recv.side_effect = [None] + async for _ in gen2: pass + + MockManager.return_value.execute.side_effect = mock_execute + + await writer.append(b"1234567890") + + # Recovery Assertions + self.mock_stream.close.assert_awaited() + mock_writer_open.assert_awaited() + self.assertEqual(writer.write_handle, b"h_new") + self.assertEqual(writer.persisted_size, 4) + + async def test_append_metadata_injection(self): + """Verify providing metadata forces a restart (Attempt 1 logic).""" + writer = self._make_one() + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + custom_meta = [("x-test", "true")] + + with mock.patch.object(writer, "open", new_callable=mock.AsyncMock) as mock_writer_open: + with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: + async def mock_execute(state, policy): + factory = MockManager.call_args[0][1] + gen = factory([], state) + self.mock_stream.recv.return_value = None + async for _ in gen: pass + + MockManager.return_value.execute.side_effect = mock_execute + await writer.append(b"data", metadata=custom_meta) + + self.mock_stream.close.assert_awaited() + mock_writer_open.assert_awaited_with(metadata=custom_meta) + + # ------------------------------------------------------------------------- + # Flush, Close, Finalize Tests + # ------------------------------------------------------------------------- + + async def test_flush(self): + writer = self._make_one() + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + writer.bytes_appended_since_last_flush = 50 + + self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( + persisted_size=100 + ) + res = await writer.flush() -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_close(mock_write_object_stream, mock_client): - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.offset = 1024 - writer.persisted_size = 1024 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse(persisted_size=1024) - ) - mock_stream.close = mock.AsyncMock() - writer.finalize = mock.AsyncMock() - - persisted_size = await writer.close() - - writer.finalize.assert_not_awaited() - mock_stream.close.assert_awaited_once() - assert writer.offset is None - assert persisted_size == 1024 - assert not writer._is_stream_open - - -@pytest.mark.asyncio -async def test_close_without_open_raises_value_error(mock_client): - """Test that close raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Stream is not open. Call open\\(\\) before close\\(\\)." - ): - await writer.close() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_finalize_on_close(mock_write_object_stream, mock_client): - """Test close with finalizing.""" - # Arrange - mock_resource = _storage_v2.Object(name=OBJECT, bucket=BUCKET, size=2048) - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.offset = 1024 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse(resource=mock_resource) - ) - mock_stream.close = mock.AsyncMock() - - # Act - result = await writer.close(finalize_on_close=True) - - # Assert - mock_stream.close.assert_awaited_once() - assert not writer._is_stream_open - assert writer.offset is None - assert writer.object_resource == mock_resource - assert writer.persisted_size == 2048 - assert result == mock_resource - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_finalize(mock_write_object_stream, mock_client): - """Test that finalize sends the correct request and updates state.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - mock_resource = _storage_v2.Object(name=OBJECT, bucket=BUCKET, size=123) - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse(resource=mock_resource) - ) - mock_stream.close = mock.AsyncMock() - - gcs_object = await writer.finalize() - - mock_stream.send.assert_awaited_once_with( - _storage_v2.BidiWriteObjectRequest(finish_write=True) - ) - mock_stream.recv.assert_awaited_once() - mock_stream.close.assert_awaited_once() - assert writer.object_resource == mock_resource - assert writer.persisted_size == 123 - assert gcs_object == mock_resource - assert writer._is_stream_open is False - assert writer.offset is None - - -@pytest.mark.asyncio -async def test_finalize_without_open_raises_value_error(mock_client): - """Test that finalize raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Stream is not open. Call open\\(\\) before finalize\\(\\)." - ): - await writer.finalize() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_raises_error_if_not_open(mock_write_object_stream, mock_client): - """Test that append raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Stream is not open. Call open\\(\\) before append\\(\\)." - ): - await writer.append(b"some data") - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_with_empty_data(mock_write_object_stream, mock_client): - """Test that append does nothing if data is empty.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() + self.mock_stream.send.assert_awaited_with( + storage_type.BidiWriteObjectRequest(flush=True, state_lookup=True) + ) + self.assertEqual(res, 100) + self.assertEqual(writer.bytes_appended_since_last_flush, 0) - await writer.append(b"") + async def test_flush_not_open_raises(self): + writer = self._make_one() + with self.assertRaisesRegex(ValueError, "Stream is not open"): + await writer.flush() - mock_stream.send.assert_not_awaited() + async def test_simple_flush(self): + writer = self._make_one() + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + await writer.simple_flush() -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_sends_data_in_chunks(mock_write_object_stream, mock_client): - """Test that append sends data in chunks and updates offset.""" - from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( - _MAX_CHUNK_SIZE_BYTES, - ) - - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.persisted_size = 100 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - - data = b"a" * (_MAX_CHUNK_SIZE_BYTES + 1) - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse( - persisted_size=100 + len(data) + self.mock_stream.send.assert_awaited_with( + storage_type.BidiWriteObjectRequest(flush=True) ) - ) - - await writer.append(data) - - assert mock_stream.send.await_count == 2 - first_request = mock_stream.send.await_args_list[0].args[0] - second_request = mock_stream.send.await_args_list[1].args[0] - - # First chunk - assert first_request.write_offset == 100 - assert len(first_request.checksummed_data.content) == _MAX_CHUNK_SIZE_BYTES - assert first_request.checksummed_data.crc32c == int.from_bytes( - Checksum(data[:_MAX_CHUNK_SIZE_BYTES]).digest(), byteorder="big" - ) - assert not first_request.flush - assert not first_request.state_lookup - - # Second chunk (last chunk) - assert second_request.write_offset == 100 + _MAX_CHUNK_SIZE_BYTES - assert len(second_request.checksummed_data.content) == 1 - assert second_request.checksummed_data.crc32c == int.from_bytes( - Checksum(data[_MAX_CHUNK_SIZE_BYTES:]).digest(), byteorder="big" - ) - assert second_request.flush - assert second_request.state_lookup - - assert writer.offset == 100 + len(data) - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_flushes_when_buffer_is_full( - mock_write_object_stream, mock_client -): - """Test that append flushes the stream when the buffer size is reached.""" - - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.persisted_size = 0 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock() - - data = b"a" * _DEFAULT_FLUSH_INTERVAL_BYTES - await writer.append(data) - - num_chunks = _DEFAULT_FLUSH_INTERVAL_BYTES // _MAX_CHUNK_SIZE_BYTES - assert mock_stream.send.await_count == num_chunks - - # All but the last request should not have flush or state_lookup set. - for i in range(num_chunks - 1): - request = mock_stream.send.await_args_list[i].args[0] - assert not request.flush - assert not request.state_lookup - - # The last request should have flush and state_lookup set. - last_request = mock_stream.send.await_args_list[-1].args[0] - assert last_request.flush - assert last_request.state_lookup - assert writer.bytes_appended_since_last_flush == 0 - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_handles_large_data(mock_write_object_stream, mock_client): - """Test that append handles data larger than the buffer size.""" - - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.persisted_size = 0 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock() - data = b"a" * (_DEFAULT_FLUSH_INTERVAL_BYTES * 2 + 1) - await writer.append(data) - - flushed_requests = [ - call.args[0] for call in mock_stream.send.await_args_list if call.args[0].flush - ] - assert len(flushed_requests) == 3 - - last_request = mock_stream.send.await_args_list[-1].args[0] - assert last_request.state_lookup + async def test_simple_flush_not_open_raises(self): + writer = self._make_one() + with self.assertRaisesRegex(ValueError, "Stream is not open"): + await writer.simple_flush() + + async def test_close(self): + writer = self._make_one() + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + writer.persisted_size = 50 + + res = await writer.close() + + self.mock_stream.close.assert_awaited() + self.assertFalse(writer._is_stream_open) + self.assertEqual(res, 50) + + async def test_close_not_open_raises(self): + writer = self._make_one() + with self.assertRaisesRegex(ValueError, "Stream is not open"): + await writer.close() + + async def test_finalize(self): + writer = self._make_one() + writer._is_stream_open = True + writer.write_obj_stream = self.mock_stream + resource = storage_type.Object(size=999) + self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( + resource=resource + ) + res = await writer.finalize() -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_data_two_times(mock_write_object_stream, mock_client): - """Test that append sends data correctly when called multiple times.""" - from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( - _MAX_CHUNK_SIZE_BYTES, - ) - - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.persisted_size = 0 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - - data1 = b"a" * (_MAX_CHUNK_SIZE_BYTES + 10) - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse( - persisted_size= len(data1) - ) - ) - await writer.append(data1) - - assert mock_stream.send.await_count == 2 - last_request_data1 = mock_stream.send.await_args_list[-1].args[0] - assert last_request_data1.flush - assert last_request_data1.state_lookup - - data2 = b"b" * (_MAX_CHUNK_SIZE_BYTES + 20) - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse( - persisted_size= len(data2) + len(data1) + self.mock_stream.send.assert_awaited_with( + storage_type.BidiWriteObjectRequest(finish_write=True) ) - ) - await writer.append(data2) - - assert mock_stream.send.await_count == 4 - last_request_data2 = mock_stream.send.await_args_list[-1].args[0] - assert last_request_data2.flush - assert last_request_data2.state_lookup - - total_data_length = len(data1) + len(data2) - assert writer.offset == total_data_length - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "file_size, block_size", - [ - (10, 4 * 1024), - (0, _DEFAULT_FLUSH_INTERVAL_BYTES), - (20 * 1024 * 1024, _DEFAULT_FLUSH_INTERVAL_BYTES), - (16 * 1024 * 1024, _DEFAULT_FLUSH_INTERVAL_BYTES), - ], -) -async def test_append_from_file(file_size, block_size, mock_client): - # arrange - fp = BytesIO(b"a" * file_size) - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.append = mock.AsyncMock() - - # act - await writer.append_from_file(fp, block_size=block_size) - - # assert - exepected_calls = ( - file_size // block_size - if file_size % block_size == 0 - else file_size // block_size + 1 - ) - assert writer.append.await_count == exepected_calls - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager" -) -async def test_append_with_retry_on_service_unavailable( - mock_retry_manager_class, mock_client -): - """Test that append retries on ServiceUnavailable.""" - # Arrange - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.write_handle = WRITE_HANDLE - - mock_retry_manager = mock_retry_manager_class.return_value - mock_retry_manager.execute = mock.AsyncMock( - side_effect=[exceptions.ServiceUnavailable("testing"), None] - ) + self.assertEqual(writer.object_resource, resource) + self.assertEqual(writer.persisted_size, 999) + self.assertEqual(res, resource) - data_to_append = b"some data" + async def test_finalize_not_open_raises(self): + writer = self._make_one() + with self.assertRaisesRegex(ValueError, "Stream is not open"): + await writer.finalize() - # Act - await writer.append(data_to_append) + # ------------------------------------------------------------------------- + # Append From File Tests + # ------------------------------------------------------------------------- - # Assert - assert mock_retry_manager.execute.await_count == 2 + async def test_append_from_file(self): + writer = self._make_one() + writer._is_stream_open = True + writer.append = mock.AsyncMock() + fp = io.BytesIO(b"1234567890") + await writer.append_from_file(fp, block_size=4) -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager" -) -async def test_append_with_non_retryable_error( - mock_retry_manager_class, mock_client -): - """Test that append does not retry on non-retriable errors.""" - # Arrange - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.write_handle = WRITE_HANDLE - - mock_retry_manager = mock_retry_manager_class.return_value - mock_retry_manager.execute = mock.AsyncMock( - side_effect=exceptions.BadRequest("testing") - ) - - data_to_append = b"some data" - - # Act & Assert - with pytest.raises(exceptions.BadRequest): - await writer.append(data_to_append) - - assert mock_retry_manager.execute.await_count == 1 + self.assertEqual(writer.append.await_count, 3) diff --git a/tests/unit/asyncio/test_async_write_object_stream.py b/tests/unit/asyncio/test_async_write_object_stream.py index c6ea8a8ff..4ce1526b8 100644 --- a/tests/unit/asyncio/test_async_write_object_stream.py +++ b/tests/unit/asyncio/test_async_write_object_stream.py @@ -1,396 +1,396 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -from unittest import mock - -from unittest.mock import AsyncMock -from google.cloud.storage._experimental.asyncio.async_write_object_stream import ( - _AsyncWriteObjectStream, -) -from google.cloud import _storage_v2 - -BUCKET = "my-bucket" -OBJECT = "my-object" -GENERATION = 12345 -WRITE_HANDLE = b"test-handle" - - -@pytest.fixture -def mock_client(): - """Mock the async gRPC client.""" - mock_transport = mock.AsyncMock() - mock_transport.bidi_write_object = mock.sentinel.bidi_write_object - mock_transport._wrapped_methods = { - mock.sentinel.bidi_write_object: mock.sentinel.wrapped_bidi_write_object - } - - mock_gapic_client = mock.AsyncMock() - mock_gapic_client._transport = mock_transport - - client = mock.AsyncMock() - client._client = mock_gapic_client - return client - - -async def instantiate_write_obj_stream(mock_client, mock_cls_async_bidi_rpc, open=True): - """Helper to create an instance of _AsyncWriteObjectStream and open it by default.""" - socket_like_rpc = AsyncMock() - mock_cls_async_bidi_rpc.return_value = socket_like_rpc - socket_like_rpc.open = AsyncMock() - socket_like_rpc.send = AsyncMock() - socket_like_rpc.close = AsyncMock() - - mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) - mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) - mock_response.resource.generation = GENERATION - mock_response.resource.size = 0 - mock_response.write_handle = WRITE_HANDLE - socket_like_rpc.recv = AsyncMock(return_value=mock_response) - - write_obj_stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - - if open: - await write_obj_stream.open() - - return write_obj_stream - - -def test_async_write_object_stream_init(mock_client): - """Test the constructor of _AsyncWriteObjectStream.""" - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - - assert stream.client == mock_client - assert stream.bucket_name == BUCKET - assert stream.object_name == OBJECT - assert stream.generation_number is None - assert stream.write_handle is None - assert stream._full_bucket_name == f"projects/_/buckets/{BUCKET}" - assert stream.rpc == mock.sentinel.wrapped_bidi_write_object - assert stream.metadata == ( - ("x-goog-request-params", f"bucket=projects/_/buckets/{BUCKET}"), - ) - assert stream.socket_like_rpc is None - assert not stream._is_stream_open - assert stream.first_bidi_write_req is None - assert stream.persisted_size == 0 - assert stream.object_resource is None - - -def test_async_write_object_stream_init_with_generation_and_handle(mock_client): - """Test the constructor with optional arguments.""" - generation = 12345 - write_handle = b"test-handle" - stream = _AsyncWriteObjectStream( - mock_client, - BUCKET, - OBJECT, - generation_number=generation, - write_handle=write_handle, - ) - - assert stream.generation_number == generation - assert stream.write_handle == write_handle - - -def test_async_write_object_stream_init_raises_value_error(): - """Test that the constructor raises ValueError for missing arguments.""" - with pytest.raises(ValueError, match="client must be provided"): - _AsyncWriteObjectStream(None, BUCKET, OBJECT) - - with pytest.raises(ValueError, match="bucket_name must be provided"): - _AsyncWriteObjectStream(mock.Mock(), None, OBJECT) - - with pytest.raises(ValueError, match="object_name must be provided"): - _AsyncWriteObjectStream(mock.Mock(), BUCKET, None) - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_for_new_object(mock_async_bidi_rpc, mock_client): - """Test opening a stream for a new object.""" - # Arrange - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - socket_like_rpc.open = mock.AsyncMock() - - mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) - mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) - mock_response.resource.generation = GENERATION - mock_response.resource.size = 0 - mock_response.write_handle = WRITE_HANDLE - socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) - - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - - # Act - await stream.open() - - # Assert - assert stream._is_stream_open - socket_like_rpc.open.assert_called_once() - socket_like_rpc.recv.assert_called_once() - assert stream.generation_number == GENERATION - assert stream.write_handle == WRITE_HANDLE - assert stream.persisted_size == 0 - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_for_existing_object(mock_async_bidi_rpc, mock_client): - """Test opening a stream for an existing object.""" - # Arrange - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - socket_like_rpc.open = mock.AsyncMock() - - mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) - mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) - mock_response.resource.size = 1024 - mock_response.resource.generation = GENERATION - mock_response.write_handle = WRITE_HANDLE - socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) - - stream = _AsyncWriteObjectStream( - mock_client, BUCKET, OBJECT, generation_number=GENERATION - ) - - # Act - await stream.open() - - # Assert - assert stream._is_stream_open - socket_like_rpc.open.assert_called_once() - socket_like_rpc.recv.assert_called_once() - assert stream.generation_number == GENERATION - assert stream.write_handle == WRITE_HANDLE - assert stream.persisted_size == 1024 - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_when_already_open_raises_error(mock_async_bidi_rpc, mock_client): - """Test that opening an already open stream raises a ValueError.""" - # Arrange - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - socket_like_rpc.open = mock.AsyncMock() - - mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) - mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) - mock_response.resource.generation = GENERATION - mock_response.resource.size = 0 - mock_response.write_handle = WRITE_HANDLE - socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) - - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - await stream.open() - - # Act & Assert - with pytest.raises(ValueError, match="Stream is already open"): - await stream.open() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_raises_error_on_missing_object_resource( - mock_async_bidi_rpc, mock_client -): - """Test that open raises ValueError if object_resource is not in the response.""" - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - - mock_reponse = mock.AsyncMock() - type(mock_reponse).resource = mock.PropertyMock(return_value=None) - socket_like_rpc.recv.return_value = mock_reponse - - # Note: Don't use below code as unittest library automatically assigns an - # `AsyncMock` object to an attribute, if not set. - # socket_like_rpc.recv.return_value = mock.AsyncMock( - # return_value=_storage_v2.BidiWriteObjectResponse(resource=None) - # ) - - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Failed to obtain object resource after opening the stream" - ): - await stream.open() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_raises_error_on_missing_generation( - mock_async_bidi_rpc, mock_client -): - """Test that open raises ValueError if generation is not in the response.""" - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - - # Configure the mock response object - mock_response = mock.AsyncMock() - type(mock_response.resource).generation = mock.PropertyMock(return_value=None) - socket_like_rpc.recv.return_value = mock_response - - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Failed to obtain object generation after opening the stream" - ): - await stream.open() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_raises_error_on_missing_write_handle( - mock_async_bidi_rpc, mock_client -): - """Test that open raises ValueError if write_handle is not in the response.""" - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - socket_like_rpc.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse( - resource=_storage_v2.Object(generation=GENERATION), write_handle=None - ) - ) - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - with pytest.raises(ValueError, match="Failed to obtain write_handle"): - await stream.open() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_close(mock_cls_async_bidi_rpc, mock_client): - """Test that close successfully closes the stream.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=True - ) - - # Act - await write_obj_stream.close() - - # Assert - write_obj_stream.socket_like_rpc.close.assert_called_once() - assert not write_obj_stream.is_stream_open - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_close_without_open_should_raise_error( - mock_cls_async_bidi_rpc, mock_client -): - """Test that closing a stream that is not open raises a ValueError.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=False - ) - - # Act & Assert - with pytest.raises(ValueError, match="Stream is not open"): - await write_obj_stream.close() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_send(mock_cls_async_bidi_rpc, mock_client): - """Test that send calls the underlying rpc's send method.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=True - ) - - # Act - bidi_write_object_request = _storage_v2.BidiWriteObjectRequest() - await write_obj_stream.send(bidi_write_object_request) - - # Assert - write_obj_stream.socket_like_rpc.send.assert_called_once_with( - bidi_write_object_request - ) - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_send_without_open_should_raise_error( - mock_cls_async_bidi_rpc, mock_client -): - """Test that sending on a stream that is not open raises a ValueError.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=False - ) - - # Act & Assert - with pytest.raises(ValueError, match="Stream is not open"): - await write_obj_stream.send(_storage_v2.BidiWriteObjectRequest()) - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_recv(mock_cls_async_bidi_rpc, mock_client): - """Test that recv calls the underlying rpc's recv method.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=True - ) - bidi_write_object_response = _storage_v2.BidiWriteObjectResponse() - write_obj_stream.socket_like_rpc.recv = AsyncMock( - return_value=bidi_write_object_response - ) - - # Act - response = await write_obj_stream.recv() - - # Assert - write_obj_stream.socket_like_rpc.recv.assert_called_once() - assert response == bidi_write_object_response - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_recv_without_open_should_raise_error( - mock_cls_async_bidi_rpc, mock_client -): - """Test that receiving on a stream that is not open raises a ValueError.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=False - ) - - # Act & Assert - with pytest.raises(ValueError, match="Stream is not open"): - await write_obj_stream.recv() +# # Copyright 2025 Google LLC +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. + +# import pytest +# from unittest import mock + +# from unittest.mock import AsyncMock +# from google.cloud.storage._experimental.asyncio.async_write_object_stream import ( +# _AsyncWriteObjectStream, +# ) +# from google.cloud import _storage_v2 + +# BUCKET = "my-bucket" +# OBJECT = "my-object" +# GENERATION = 12345 +# WRITE_HANDLE = b"test-handle" + + +# @pytest.fixture +# def mock_client(): +# """Mock the async gRPC client.""" +# mock_transport = mock.AsyncMock() +# mock_transport.bidi_write_object = mock.sentinel.bidi_write_object +# mock_transport._wrapped_methods = { +# mock.sentinel.bidi_write_object: mock.sentinel.wrapped_bidi_write_object +# } + +# mock_gapic_client = mock.AsyncMock() +# mock_gapic_client._transport = mock_transport + +# client = mock.AsyncMock() +# client._client = mock_gapic_client +# return client + + +# async def instantiate_write_obj_stream(mock_client, mock_cls_async_bidi_rpc, open=True): +# """Helper to create an instance of _AsyncWriteObjectStream and open it by default.""" +# socket_like_rpc = AsyncMock() +# mock_cls_async_bidi_rpc.return_value = socket_like_rpc +# socket_like_rpc.open = AsyncMock() +# socket_like_rpc.send = AsyncMock() +# socket_like_rpc.close = AsyncMock() + +# mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) +# mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) +# mock_response.resource.generation = GENERATION +# mock_response.resource.size = 0 +# mock_response.write_handle = WRITE_HANDLE +# socket_like_rpc.recv = AsyncMock(return_value=mock_response) + +# write_obj_stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + +# if open: +# await write_obj_stream.open() + +# return write_obj_stream + + +# def test_async_write_object_stream_init(mock_client): +# """Test the constructor of _AsyncWriteObjectStream.""" +# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + +# assert stream.client == mock_client +# assert stream.bucket_name == BUCKET +# assert stream.object_name == OBJECT +# assert stream.generation_number is None +# assert stream.write_handle is None +# assert stream._full_bucket_name == f"projects/_/buckets/{BUCKET}" +# assert stream.rpc == mock.sentinel.wrapped_bidi_write_object +# assert stream.metadata == ( +# ("x-goog-request-params", f"bucket=projects/_/buckets/{BUCKET}"), +# ) +# assert stream.socket_like_rpc is None +# assert not stream._is_stream_open +# assert stream.first_bidi_write_req is None +# assert stream.persisted_size == 0 +# assert stream.object_resource is None + + +# def test_async_write_object_stream_init_with_generation_and_handle(mock_client): +# """Test the constructor with optional arguments.""" +# generation = 12345 +# write_handle = b"test-handle" +# stream = _AsyncWriteObjectStream( +# mock_client, +# BUCKET, +# OBJECT, +# generation_number=generation, +# write_handle=write_handle, +# ) + +# assert stream.generation_number == generation +# assert stream.write_handle == write_handle + + +# def test_async_write_object_stream_init_raises_value_error(): +# """Test that the constructor raises ValueError for missing arguments.""" +# with pytest.raises(ValueError, match="client must be provided"): +# _AsyncWriteObjectStream(None, BUCKET, OBJECT) + +# with pytest.raises(ValueError, match="bucket_name must be provided"): +# _AsyncWriteObjectStream(mock.Mock(), None, OBJECT) + +# with pytest.raises(ValueError, match="object_name must be provided"): +# _AsyncWriteObjectStream(mock.Mock(), BUCKET, None) + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_open_for_new_object(mock_async_bidi_rpc, mock_client): +# """Test opening a stream for a new object.""" +# # Arrange +# socket_like_rpc = mock.AsyncMock() +# mock_async_bidi_rpc.return_value = socket_like_rpc +# socket_like_rpc.open = mock.AsyncMock() + +# mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) +# mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) +# mock_response.resource.generation = GENERATION +# mock_response.resource.size = 0 +# mock_response.write_handle = WRITE_HANDLE +# socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) + +# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + +# # Act +# await stream.open() + +# # Assert +# assert stream._is_stream_open +# socket_like_rpc.open.assert_called_once() +# socket_like_rpc.recv.assert_called_once() +# assert stream.generation_number == GENERATION +# assert stream.write_handle == WRITE_HANDLE +# assert stream.persisted_size == 0 + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_open_for_existing_object(mock_async_bidi_rpc, mock_client): +# """Test opening a stream for an existing object.""" +# # Arrange +# socket_like_rpc = mock.AsyncMock() +# mock_async_bidi_rpc.return_value = socket_like_rpc +# socket_like_rpc.open = mock.AsyncMock() + +# mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) +# mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) +# mock_response.resource.size = 1024 +# mock_response.resource.generation = GENERATION +# mock_response.write_handle = WRITE_HANDLE +# socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) + +# stream = _AsyncWriteObjectStream( +# mock_client, BUCKET, OBJECT, generation_number=GENERATION +# ) + +# # Act +# await stream.open() + +# # Assert +# assert stream._is_stream_open +# socket_like_rpc.open.assert_called_once() +# socket_like_rpc.recv.assert_called_once() +# assert stream.generation_number == GENERATION +# assert stream.write_handle == WRITE_HANDLE +# assert stream.persisted_size == 1024 + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_open_when_already_open_raises_error(mock_async_bidi_rpc, mock_client): +# """Test that opening an already open stream raises a ValueError.""" +# # Arrange +# socket_like_rpc = mock.AsyncMock() +# mock_async_bidi_rpc.return_value = socket_like_rpc +# socket_like_rpc.open = mock.AsyncMock() + +# mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) +# mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) +# mock_response.resource.generation = GENERATION +# mock_response.resource.size = 0 +# mock_response.write_handle = WRITE_HANDLE +# socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) + +# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) +# await stream.open() + +# # Act & Assert +# with pytest.raises(ValueError, match="Stream is already open"): +# await stream.open() + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_open_raises_error_on_missing_object_resource( +# mock_async_bidi_rpc, mock_client +# ): +# """Test that open raises ValueError if object_resource is not in the response.""" +# socket_like_rpc = mock.AsyncMock() +# mock_async_bidi_rpc.return_value = socket_like_rpc + +# mock_reponse = mock.AsyncMock() +# type(mock_reponse).resource = mock.PropertyMock(return_value=None) +# socket_like_rpc.recv.return_value = mock_reponse + +# # Note: Don't use below code as unittest library automatically assigns an +# # `AsyncMock` object to an attribute, if not set. +# # socket_like_rpc.recv.return_value = mock.AsyncMock( +# # return_value=_storage_v2.BidiWriteObjectResponse(resource=None) +# # ) + +# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) +# with pytest.raises( +# ValueError, match="Failed to obtain object resource after opening the stream" +# ): +# await stream.open() + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_open_raises_error_on_missing_generation( +# mock_async_bidi_rpc, mock_client +# ): +# """Test that open raises ValueError if generation is not in the response.""" +# socket_like_rpc = mock.AsyncMock() +# mock_async_bidi_rpc.return_value = socket_like_rpc + +# # Configure the mock response object +# mock_response = mock.AsyncMock() +# type(mock_response.resource).generation = mock.PropertyMock(return_value=None) +# socket_like_rpc.recv.return_value = mock_response + +# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) +# with pytest.raises( +# ValueError, match="Failed to obtain object generation after opening the stream" +# ): +# await stream.open() + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_open_raises_error_on_missing_write_handle( +# mock_async_bidi_rpc, mock_client +# ): +# """Test that open raises ValueError if write_handle is not in the response.""" +# socket_like_rpc = mock.AsyncMock() +# mock_async_bidi_rpc.return_value = socket_like_rpc +# socket_like_rpc.recv = mock.AsyncMock( +# return_value=_storage_v2.BidiWriteObjectResponse( +# resource=_storage_v2.Object(generation=GENERATION), write_handle=None +# ) +# ) +# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) +# with pytest.raises(ValueError, match="Failed to obtain write_handle"): +# await stream.open() + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_close(mock_cls_async_bidi_rpc, mock_client): +# """Test that close successfully closes the stream.""" +# # Arrange +# write_obj_stream = await instantiate_write_obj_stream( +# mock_client, mock_cls_async_bidi_rpc, open=True +# ) + +# # Act +# await write_obj_stream.close() + +# # Assert +# write_obj_stream.socket_like_rpc.close.assert_called_once() +# assert not write_obj_stream.is_stream_open + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_close_without_open_should_raise_error( +# mock_cls_async_bidi_rpc, mock_client +# ): +# """Test that closing a stream that is not open raises a ValueError.""" +# # Arrange +# write_obj_stream = await instantiate_write_obj_stream( +# mock_client, mock_cls_async_bidi_rpc, open=False +# ) + +# # Act & Assert +# with pytest.raises(ValueError, match="Stream is not open"): +# await write_obj_stream.close() + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_send(mock_cls_async_bidi_rpc, mock_client): +# """Test that send calls the underlying rpc's send method.""" +# # Arrange +# write_obj_stream = await instantiate_write_obj_stream( +# mock_client, mock_cls_async_bidi_rpc, open=True +# ) + +# # Act +# bidi_write_object_request = _storage_v2.BidiWriteObjectRequest() +# await write_obj_stream.send(bidi_write_object_request) + +# # Assert +# write_obj_stream.socket_like_rpc.send.assert_called_once_with( +# bidi_write_object_request +# ) + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_send_without_open_should_raise_error( +# mock_cls_async_bidi_rpc, mock_client +# ): +# """Test that sending on a stream that is not open raises a ValueError.""" +# # Arrange +# write_obj_stream = await instantiate_write_obj_stream( +# mock_client, mock_cls_async_bidi_rpc, open=False +# ) + +# # Act & Assert +# with pytest.raises(ValueError, match="Stream is not open"): +# await write_obj_stream.send(_storage_v2.BidiWriteObjectRequest()) + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_recv(mock_cls_async_bidi_rpc, mock_client): +# """Test that recv calls the underlying rpc's recv method.""" +# # Arrange +# write_obj_stream = await instantiate_write_obj_stream( +# mock_client, mock_cls_async_bidi_rpc, open=True +# ) +# bidi_write_object_response = _storage_v2.BidiWriteObjectResponse() +# write_obj_stream.socket_like_rpc.recv = AsyncMock( +# return_value=bidi_write_object_response +# ) + +# # Act +# response = await write_obj_stream.recv() + +# # Assert +# write_obj_stream.socket_like_rpc.recv.assert_called_once() +# assert response == bidi_write_object_response + + +# @pytest.mark.asyncio +# @mock.patch( +# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +# ) +# async def test_recv_without_open_should_raise_error( +# mock_cls_async_bidi_rpc, mock_client +# ): +# """Test that receiving on a stream that is not open raises a ValueError.""" +# # Arrange +# write_obj_stream = await instantiate_write_obj_stream( +# mock_client, mock_cls_async_bidi_rpc, open=False +# ) + +# # Act & Assert +# with pytest.raises(ValueError, match="Stream is not open"): +# await write_obj_stream.recv() From 5d39296c7be1cd74fd0e8b24e84811cc5f22ee3d Mon Sep 17 00:00:00 2001 From: Pulkit Aggarwal Date: Tue, 13 Jan 2026 06:53:15 +0000 Subject: [PATCH 07/10] adding more unit tests --- .../asyncio/async_appendable_object_writer.py | 72 +-- .../asyncio/async_write_object_stream.py | 20 +- .../_experimental/asyncio/retry/_helpers.py | 46 +- .../retry/writes_resumption_strategy.py | 44 +- tests/conformance/test_bidi_writes.py | 4 +- .../test_async_appendable_object_writer.py | 483 +++++--------- .../asyncio/test_async_write_object_stream.py | 591 ++++++------------ 7 files changed, 428 insertions(+), 832 deletions(-) diff --git a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py index b68475c30..99cf391b5 100644 --- a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py +++ b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py @@ -27,7 +27,7 @@ import asyncio import io import logging -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from google.api_core import exceptions from google.api_core.retry_async import AsyncRetry @@ -51,6 +51,9 @@ _WriteResumptionStrategy, _WriteState, ) +from google.cloud.storage._experimental.asyncio.retry._helpers import ( + _extract_bidi_writes_redirect_proto, +) _MAX_CHUNK_SIZE_BYTES = 2 * 1024 * 1024 # 2 MiB @@ -71,14 +74,18 @@ def _is_write_retryable(exc): exceptions.ServiceUnavailable, exceptions.DeadlineExceeded, exceptions.TooManyRequests, + BidiWriteObjectRedirectedError ), ): logger.info(f"Retryable write exception encountered: {exc}") return True grpc_error = None - if isinstance(exc, exceptions.Aborted): + if isinstance(exc, exceptions.Aborted) and exc.errors: grpc_error = exc.errors[0] + if isinstance(grpc_error, BidiWriteObjectRedirectedError): + return True + trailers = grpc_error.trailing_metadata() if not trailers: return False @@ -226,53 +233,14 @@ async def state_lookup(self) -> int: def _on_open_error(self, exc): """Extracts routing token and write handle on redirect error during open.""" - grpc_error = None - if isinstance(exc, exceptions.Aborted) and exc.errors: - grpc_error = exc.errors[0] - - if grpc_error: - if isinstance(grpc_error, BidiWriteObjectRedirectedError): - self._routing_token = grpc_error.routing_token - if grpc_error.write_handle: - self.write_handle = grpc_error.write_handle - if grpc_error.generation: - self.generation = grpc_error.generation - return - - if hasattr(grpc_error, "trailing_metadata"): - trailers = grpc_error.trailing_metadata() - if not trailers: - return - - status_details_bin = None - for key, value in trailers: - if key == "grpc-status-details-bin": - status_details_bin = value - break - - if status_details_bin: - status_proto = status_pb2.Status() - try: - status_proto.ParseFromString(status_details_bin) - for detail in status_proto.details: - if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL: - redirect_proto = ( - BidiWriteObjectRedirectedError.deserialize( - detail.value - ) - ) - if redirect_proto.routing_token: - self._routing_token = redirect_proto.routing_token - if redirect_proto.write_handle: - self.write_handle = redirect_proto.write_handle - if redirect_proto.generation: - self.generation = redirect_proto.generation - break - except Exception: - logger.error( - "Error unpacking redirect details from gRPC error." - ) - pass + redirect_proto = _extract_bidi_writes_redirect_proto(exc) + if redirect_proto: + if redirect_proto.routing_token: + self._routing_token = redirect_proto.routing_token + if redirect_proto.write_handle: + self.write_handle = redirect_proto.write_handle + if redirect_proto.generation: + self.generation = redirect_proto.generation async def open( self, @@ -448,9 +416,11 @@ async def generator(): if resp.persisted_size is not None: self.persisted_size = resp.persisted_size state["write_state"].persisted_size = resp.persisted_size + self.offset = self.persisted_size if resp.write_handle: self.write_handle = resp.write_handle state["write_state"].write_handle = resp.write_handle + self.bytes_appended_since_last_flush = 0 yield resp @@ -473,6 +443,8 @@ async def generator(): self.write_obj_stream.persisted_size = write_state.persisted_size self.write_obj_stream.write_handle = write_state.write_handle self.bytes_appended_since_last_flush = write_state.bytes_since_last_flush + self.persisted_size = write_state.persisted_size + self.offset = write_state.persisted_size async def simple_flush(self) -> None: """Flushes the data to the server. @@ -492,6 +464,7 @@ async def simple_flush(self) -> None: flush=True, ) ) + self.bytes_appended_since_last_flush = 0 async def flush(self) -> int: """Flushes the data to the server. @@ -515,6 +488,7 @@ async def flush(self) -> int: response = await self.write_obj_stream.recv() self.persisted_size = response.persisted_size self.offset = self.persisted_size + self.bytes_appended_since_last_flush = 0 return self.persisted_size async def close(self, finalize_on_close=False) -> Union[int, _storage_v2.Object]: diff --git a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py index 45a4cf072..256911073 100644 --- a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py @@ -147,25 +147,9 @@ async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: response = await self.socket_like_rpc.recv() self._is_stream_open = True - if response.persisted_size >= 0: + if response.persisted_size: self.persisted_size = response.persisted_size - if response.write_handle: - self.write_handle = response.write_handle - # return - - # if not response.resource: - # raise ValueError( - # "Failed to obtain object resource after opening the stream" - # ) - # if not response.resource.generation: - # raise ValueError( - # "Failed to obtain object generation after opening the stream" - # ) - - # if not response.write_handle: - # raise ValueError("Failed to obtain write_handle after opening the stream") - if response.resource: if not response.resource.size: # Appending to a 0 byte appendable object. @@ -174,6 +158,8 @@ async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: self.persisted_size = response.resource.size self.generation_number = response.resource.generation + + if response.write_handle: self.write_handle = response.write_handle async def close(self) -> None: diff --git a/google/cloud/storage/_experimental/asyncio/retry/_helpers.py b/google/cloud/storage/_experimental/asyncio/retry/_helpers.py index 627bf5944..dcc830cc5 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/_helpers.py +++ b/google/cloud/storage/_experimental/asyncio/retry/_helpers.py @@ -18,12 +18,16 @@ from typing import Tuple, Optional from google.api_core import exceptions -from google.cloud._storage_v2.types import BidiReadObjectRedirectedError +from google.cloud._storage_v2.types import BidiReadObjectRedirectedError, BidiWriteObjectRedirectedError from google.rpc import status_pb2 _BIDI_READ_REDIRECTED_TYPE_URL = ( "type.googleapis.com/google.storage.v2.BidiReadObjectRedirectedError" ) +_BIDI_WRITE_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" +) +logger = logging.getLogger(__name__) def _handle_redirect( @@ -78,6 +82,44 @@ def _handle_redirect( read_handle = redirect_proto.read_handle break except Exception as e: - logging.ERROR(f"Error unpacking redirect: {e}") + logger.error(f"Error unpacking redirect: {e}") return routing_token, read_handle + +def _extract_bidi_writes_redirect_proto(exc: Exception): + grpc_error = None + if isinstance(exc, exceptions.Aborted) and exc.errors: + grpc_error = exc.errors[0] + + if grpc_error: + if isinstance(grpc_error, BidiWriteObjectRedirectedError): + return grpc_error + + if hasattr(grpc_error, "trailing_metadata"): + trailers = grpc_error.trailing_metadata() + if not trailers: + return + + status_details_bin = None + for key, value in trailers: + if key == "grpc-status-details-bin": + status_details_bin = value + break + + if status_details_bin: + status_proto = status_pb2.Status() + try: + status_proto.ParseFromString(status_details_bin) + for detail in status_proto.details: + if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL: + redirect_proto = ( + BidiWriteObjectRedirectedError.deserialize( + detail.value + ) + ) + return redirect_proto + except Exception: + logger.error( + "Error unpacking redirect details from gRPC error." + ) + pass diff --git a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py index 7a2a84d16..4ad20662b 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py @@ -22,6 +22,10 @@ from google.cloud.storage._experimental.asyncio.retry.base_strategy import ( _BaseResumptionStrategy, ) +from google.cloud.storage._experimental.asyncio.retry._helpers import ( + _extract_bidi_writes_redirect_proto, +) + _BIDI_WRITE_REDIRECTED_TYPE_URL = ( "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" @@ -139,39 +143,17 @@ async def recover_state_on_failure( if grpc_error: # Extract routing token and potentially a new write handle for redirection. if isinstance(grpc_error, BidiWriteObjectRedirectedError): - self._routing_token = grpc_error.routing_token + write_state.routing_token = grpc_error.routing_token if grpc_error.write_handle: - self.write_handle = grpc_error.write_handle + write_state.write_handle = grpc_error.write_handle return - if hasattr(grpc_error, "trailing_metadata"): - trailers = grpc_error.trailing_metadata() - if not trailers: - return - - status_details_bin = None - for key, value in trailers: - if key == "grpc-status-details-bin": - status_details_bin = value - break - - if status_details_bin: - status_proto = status_pb2.Status() - try: - status_proto.ParseFromString(status_details_bin) - for detail in status_proto.details: - if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL: - redirect_proto = ( - BidiWriteObjectRedirectedError.deserialize( - detail.value - ) - ) - if redirect_proto.routing_token: - write_state._routing_token = redirect_proto.routing_token - if redirect_proto.write_handle: - write_state.write_handle = redirect_proto.write_handle - break - except Exception: - pass + + redirect_proto = _extract_bidi_writes_redirect_proto(error) + if redirect_proto: + if redirect_proto.routing_token: + write_state.routing_token = redirect_proto.routing_token + if redirect_proto.write_handle: + write_state.write_handle = redirect_proto.write_handle # We must assume any data sent beyond 'persisted_size' was lost. # Reset the user buffer to the last known good byte confirmed by the server. diff --git a/tests/conformance/test_bidi_writes.py b/tests/conformance/test_bidi_writes.py index 4adfd266a..90dfaf5f8 100644 --- a/tests/conformance/test_bidi_writes.py +++ b/tests/conformance/test_bidi_writes.py @@ -85,8 +85,8 @@ def on_retry_error(exc): await writer.append( CONTENT, metadata=fault_injection_metadata, retry_policy=policy_to_pass ) - await writer.finalize() - await writer.close() + # await writer.finalize() + await writer.close(finalize_on_close=True) # If an exception was expected, this line should not be reached. if scenario["expected_error"] is not None: diff --git a/tests/unit/asyncio/test_async_appendable_object_writer.py b/tests/unit/asyncio/test_async_appendable_object_writer.py index ac6716fbf..2b8680e6e 100644 --- a/tests/unit/asyncio/test_async_appendable_object_writer.py +++ b/tests/unit/asyncio/test_async_appendable_object_writer.py @@ -15,7 +15,9 @@ import io import unittest import unittest.mock as mock +from unittest.mock import AsyncMock, MagicMock import pytest + from google.api_core import exceptions from google.rpc import status_pb2 from google.cloud._storage_v2.types import storage as storage_type @@ -27,6 +29,7 @@ _DEFAULT_FLUSH_INTERVAL_BYTES, ) +# Constants BUCKET = "test-bucket" OBJECT = "test-object" GENERATION = 123 @@ -36,14 +39,16 @@ class TestIsWriteRetryable(unittest.TestCase): - def test_transient_errors(self): - for exc_type in [ - exceptions.InternalServerError, - exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, - exceptions.TooManyRequests, + """Exhaustive tests for retry predicate logic.""" + + def test_standard_transient_errors(self): + for exc in [ + exceptions.InternalServerError("500"), + exceptions.ServiceUnavailable("503"), + exceptions.DeadlineExceeded("timeout"), + exceptions.TooManyRequests("429"), ]: - self.assertTrue(_is_write_retryable(exc_type("error"))) + self.assertTrue(_is_write_retryable(exc)) def test_aborted_with_redirect_proto(self): # Direct redirect error wrapped in Aborted @@ -52,41 +57,56 @@ def test_aborted_with_redirect_proto(self): self.assertTrue(_is_write_retryable(exc)) def test_aborted_with_trailers(self): - # Redirect hidden in trailers + # Setup Status with Redirect Detail status = status_pb2.Status() detail = status.details.add() detail.type_url = "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" - # Correctly serialize the proto message to bytes for the detail value - redirect_proto = BidiWriteObjectRedirectedError(routing_token="rt2") - detail.value = BidiWriteObjectRedirectedError.serialize(redirect_proto) + # Mock error with trailing_metadata method + mock_grpc_error = MagicMock() + mock_grpc_error.trailing_metadata.return_value = [ + ("grpc-status-details-bin", status.SerializeToString()) + ] - exc = exceptions.Aborted("aborted") - exc.trailing_metadata = [("grpc-status-details-bin", status.SerializeToString())] + # Aborted wraps the grpc error + exc = exceptions.Aborted("aborted", errors=[mock_grpc_error]) self.assertTrue(_is_write_retryable(exc)) - def test_non_retryable(self): - self.assertFalse(_is_write_retryable(exceptions.BadRequest("bad"))) - self.assertFalse(_is_write_retryable(exceptions.Aborted("just aborted"))) + def test_aborted_without_metadata(self): + mock_grpc_error = MagicMock() + mock_grpc_error.trailing_metadata.return_value = [] + exc = exceptions.Aborted("bare aborted", errors=[mock_grpc_error]) + self.assertFalse(_is_write_retryable(exc)) + + def test_non_retryable_errors(self): + self.assertFalse(_is_write_retryable(exceptions.BadRequest("400"))) + self.assertFalse(_is_write_retryable(exceptions.NotFound("404"))) class TestAsyncAppendableObjectWriter(unittest.IsolatedAsyncioTestCase): def setUp(self): self.mock_client = mock.AsyncMock() - # Patch the stream class used internally - self.mock_stream_cls = mock.patch( + # Internal stream class patch + self.mock_stream_patcher = mock.patch( "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" - ).start() + ) + self.mock_stream_cls = self.mock_stream_patcher.start() self.mock_stream = self.mock_stream_cls.return_value - # Default mock stream state + # Configure all async methods explicitly + self.mock_stream.open = AsyncMock() + self.mock_stream.close = AsyncMock() + self.mock_stream.send = AsyncMock() + self.mock_stream.recv = AsyncMock() + + # Default mock properties self.mock_stream.is_stream_open = False self.mock_stream.persisted_size = 0 self.mock_stream.generation_number = GENERATION self.mock_stream.write_handle = WRITE_HANDLE def tearDown(self): - mock.patch.stopall() + self.mock_stream_patcher.stop() def _make_one(self, **kwargs): return AsyncAppendableObjectWriter( @@ -94,462 +114,255 @@ def _make_one(self, **kwargs): ) # ------------------------------------------------------------------------- - # Initialization Tests + # Initialization & Configuration Tests # ------------------------------------------------------------------------- def test_init_defaults(self): writer = self._make_one() - self.assertEqual(writer.client, self.mock_client) self.assertEqual(writer.bucket_name, BUCKET) self.assertEqual(writer.object_name, OBJECT) - self.assertIsNone(writer.generation) - self.assertIsNone(writer.write_handle) - self.assertFalse(writer._is_stream_open) self.assertIsNone(writer.persisted_size) self.assertEqual(writer.bytes_appended_since_last_flush, 0) - - def test_init_with_optional_args(self): - writer = self._make_one( - generation=GENERATION, - write_handle=WRITE_HANDLE, - ) - self.assertEqual(writer.generation, GENERATION) - self.assertEqual(writer.write_handle, WRITE_HANDLE) + self.assertEqual(writer.flush_interval, _DEFAULT_FLUSH_INTERVAL_BYTES) def test_init_with_writer_options(self): writer = self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": EIGHT_MIB}) self.assertEqual(writer.flush_interval, EIGHT_MIB) - def test_init_validation_chunk_size(self): + def test_init_validation_chunk_size_raises(self): with self.assertRaises(exceptions.OutOfRange): self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES - 1}) - def test_init_validation_chunk_multiple(self): + def test_init_validation_multiple_raises(self): with self.assertRaises(exceptions.OutOfRange): self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES + 1}) - def test_init_raises_if_crc32c_c_extension_is_missing(self): + def test_init_raises_if_crc32c_missing(self): with mock.patch("google.cloud.storage._experimental.asyncio._utils.google_crc32c") as mock_crc: mock_crc.implementation = "python" - with self.assertRaisesRegex(exceptions.FailedPrecondition, "google-crc32c package is not installed"): + with self.assertRaises(exceptions.FailedPrecondition): self._make_one() # ------------------------------------------------------------------------- - # Helper Method Tests + # Stream Lifecycle Tests # ------------------------------------------------------------------------- - async def test_state_lookup(self): + async def test_state_lookup_success(self): writer = self._make_one() writer._is_stream_open = True writer.write_obj_stream = self.mock_stream - self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( - persisted_size=PERSISTED_SIZE - ) + self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse(persisted_size=100) - resp = await writer.state_lookup() + size = await writer.state_lookup() - self.mock_stream.send.assert_awaited_once_with( - storage_type.BidiWriteObjectRequest(state_lookup=True) - ) - self.assertEqual(resp, PERSISTED_SIZE) - self.assertEqual(writer.persisted_size, PERSISTED_SIZE) + self.mock_stream.send.assert_awaited_once() + self.assertEqual(size, 100) + self.assertEqual(writer.persisted_size, 100) - async def test_state_lookup_not_open_raises(self): + async def test_state_lookup_raises_if_not_open(self): writer = self._make_one() with self.assertRaisesRegex(ValueError, "Stream is not open"): await writer.state_lookup() - async def test_unimplemented_methods(self): - writer = self._make_one() - with self.assertRaises(NotImplementedError): - await writer.append_from_string("data") - with self.assertRaises(NotImplementedError): - await writer.append_from_stream(mock.Mock()) - - # ------------------------------------------------------------------------- - # Open & Error Handling Tests - # ------------------------------------------------------------------------- - async def test_open_success(self): writer = self._make_one() - self.mock_stream.generation_number = GENERATION - self.mock_stream.write_handle = WRITE_HANDLE + self.mock_stream.generation_number = 456 + self.mock_stream.write_handle = b"new-h" self.mock_stream.persisted_size = 0 await writer.open() self.assertTrue(writer._is_stream_open) - self.assertEqual(writer.generation, GENERATION) - self.assertEqual(writer.write_handle, WRITE_HANDLE) - self.assertEqual(writer.persisted_size, 0) - - self.mock_stream_cls.assert_called_with( - client=self.mock_client, - bucket_name=BUCKET, - object_name=OBJECT, - generation_number=None, - write_handle=None, - routing_token=None, - ) + self.assertEqual(writer.generation, 456) + self.assertEqual(writer.write_handle, b"new-h") self.mock_stream.open.assert_awaited_once() - async def test_open_appendable_object_writer_existing_object(self): - # Verify opening with existing generation uses AppendObjectSpec implicitly via stream init - writer = self._make_one(generation=GENERATION, write_handle=WRITE_HANDLE) - self.mock_stream.generation_number = GENERATION - self.mock_stream.write_handle = WRITE_HANDLE - self.mock_stream.persisted_size = PERSISTED_SIZE - - await writer.open() - - # Check constructor was called with generation/handle - self.mock_stream_cls.assert_called_with( - client=self.mock_client, - bucket_name=BUCKET, - object_name=OBJECT, - generation_number=GENERATION, - write_handle=WRITE_HANDLE, - routing_token=None, - ) - self.assertEqual(writer.persisted_size, PERSISTED_SIZE) - - async def test_open_with_routing_token_and_metadata(self): - writer = self._make_one() - writer._routing_token = "prev-token" - metadata = [("key", "val")] - - await writer.open(metadata=metadata) - - self.mock_stream_cls.assert_called_with( - client=self.mock_client, - bucket_name=BUCKET, - object_name=OBJECT, - generation_number=None, - write_handle=None, - routing_token="prev-token", - ) - call_kwargs = self.mock_stream.open.call_args[1] - passed_metadata = call_kwargs['metadata'] - self.assertIn(("x-goog-request-params", "routing_token=prev-token"), passed_metadata) - self.assertIsNone(writer._routing_token) - - async def test_open_when_already_open_raises(self): + async def test_open_already_open_raises(self): writer = self._make_one() writer._is_stream_open = True - with self.assertRaisesRegex(ValueError, "Underlying bidi-gRPC stream is already open"): + with self.assertRaisesRegex(ValueError, "already open"): await writer.open() - def test_on_open_error_extraction(self): + def test_on_open_error_redirection(self): + """Verify redirect info is extracted from helper.""" writer = self._make_one() - - # 1. Direct Redirect Error redirect = BidiWriteObjectRedirectedError( - routing_token="rt", - write_handle=storage_type.BidiWriteHandle(handle=b"wh"), - generation=999 - ) - writer._on_open_error(exceptions.Aborted("e", errors=[redirect])) - - self.assertEqual(writer._routing_token, "rt") - self.assertEqual(writer.write_handle.handle, b"wh") - self.assertEqual(writer.generation, 999) - - # 2. Trailer Error - status = status_pb2.Status() - detail = status.details.add() - detail.type_url = "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" - detail.value = BidiWriteObjectRedirectedError.serialize( - BidiWriteObjectRedirectedError(routing_token="rt2") + routing_token="rt1", + write_handle=storage_type.BidiWriteHandle(handle=b"h1"), + generation=777 ) - exc = exceptions.Aborted("e") - exc.trailing_metadata = [("grpc-status-details-bin", status.SerializeToString())] + with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._extract_bidi_writes_redirect_proto", return_value=redirect): + writer._on_open_error(exceptions.Aborted("redirect")) - writer._on_open_error(exc) - self.assertEqual(writer._routing_token, "rt2") + self.assertEqual(writer._routing_token, "rt1") + self.assertEqual(writer.write_handle.handle, b"h1") + self.assertEqual(writer.generation, 777) # ------------------------------------------------------------------------- - # Append Tests + # Append & Integration Tests # ------------------------------------------------------------------------- - async def test_append_not_open_raises(self): - writer = self._make_one() - with self.assertRaisesRegex(ValueError, "Stream is not open"): - await writer.append(b"data") - - async def test_append_empty_data_does_nothing(self): - writer = self._make_one() - writer._is_stream_open = True - with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: - await writer.append(b"") - MockManager.assert_not_called() - - async def test_append_propagates_non_retryable_errors(self): - """Verify non-retryable errors bubble up.""" + async def test_append_integration_basic(self): + """Verify append orchestrates manager and drives the internal generator.""" writer = self._make_one() writer._is_stream_open = True writer.write_obj_stream = self.mock_stream - - with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: - # Simulate RetryManager raising a hard error - MockManager.return_value.execute.side_effect = exceptions.BadRequest("bad") - - with self.assertRaises(exceptions.BadRequest): - await writer.append(b"data") - - async def test_append_basic_flow_integration(self): - """Verify append sets up RetryManager and orchestrates chunks.""" - writer = self._make_one(write_handle=b"h1", generation=1) - writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream writer.persisted_size = 0 - data = b"a" * (_MAX_CHUNK_SIZE_BYTES + 10) # 2MB + 10 bytes + data = b"test-data" with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: - mock_manager_instance = MockManager.return_value - async def mock_execute(state, policy): - generator_factory = MockManager.call_args[0][1] - # Strategy generates chunks (we use dummy ones for integration check) - dummy_requests = [ - storage_type.BidiWriteObjectRequest(write_offset=0), - storage_type.BidiWriteObjectRequest(write_offset=_MAX_CHUNK_SIZE_BYTES) - ] - gen = generator_factory(dummy_requests, state) + factory = MockManager.call_args[0][1] + dummy_reqs = [storage_type.BidiWriteObjectRequest()] + gen = factory(dummy_reqs, state) self.mock_stream.recv.side_effect = [ - storage_type.BidiWriteObjectResponse(persisted_size=100, write_handle=b"h2"), + storage_type.BidiWriteObjectResponse( + persisted_size=len(data), + write_handle=storage_type.BidiWriteHandle(handle=b"h2") + ), None ] async for _ in gen: pass - mock_manager_instance.execute.side_effect = mock_execute - + MockManager.return_value.execute.side_effect = mock_execute await writer.append(data) - self.assertEqual(writer.persisted_size, 100) - self.assertEqual(writer.write_handle, b"h2") - self.assertEqual(self.mock_stream.send.await_count, 2) - # Last chunk should have state_lookup=True - self.assertTrue(self.mock_stream.send.await_args_list[-1][0][0].state_lookup) - - async def test_append_flushes_when_interval_reached(self): - """Verify generator respects flush flag from strategy.""" - # Flush interval matches 2 chunks - flush_interval = _MAX_CHUNK_SIZE_BYTES * 2 - writer = self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": flush_interval}) - writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream - - data = b"a" * flush_interval - - with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: - mock_manager_instance = MockManager.return_value - - async def mock_execute(state, policy): - generator_factory = MockManager.call_args[0][1] - - # Simulate strategy identifying a flush point - req_with_flush = storage_type.BidiWriteObjectRequest(flush=True) - gen = generator_factory([req_with_flush], state) - - self.mock_stream.recv.return_value = None - async for _ in gen: pass + self.assertEqual(writer.persisted_size, len(data)) + sent_req = self.mock_stream.send.call_args[0][0] + self.assertTrue(sent_req.state_lookup) + self.assertTrue(sent_req.flush) - mock_manager_instance.execute.side_effect = mock_execute - await writer.append(data) - - # Verify sent request had flush=True - sent_request = self.mock_stream.send.call_args[0][0] - self.assertTrue(sent_request.flush) - - async def test_append_sequential_calls_update_state(self): - """Test state carry-over between two append calls.""" - writer = self._make_one(write_handle=b"h1", generation=1) - writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream - writer.persisted_size = 0 - - with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: - # 1. First Append - async def execute_1(state, policy): - # Simulate server acknowledging 100 bytes - state["write_state"].persisted_size = 100 - writer.write_obj_stream.persisted_size = 100 - - MockManager.return_value.execute.side_effect = execute_1 - await writer.append(b"a" * 100) - - self.assertEqual(writer.persisted_size, 100) - - # 2. Second Append - async def execute_2(state, policy): - # Verify state passed to manager starts where we left off - assert state["write_state"].persisted_size == 100 - state["write_state"].persisted_size = 200 - - MockManager.return_value.execute.side_effect = execute_2 - await writer.append(b"b" * 100) - - self.assertEqual(writer.persisted_size, 200) - - async def test_append_recovery_flow(self): - """Test internal generator logic when a retry occurs (Attempt > 1).""" - writer = self._make_one(write_handle=b"h1", generation=1) + async def test_append_recovery_reopens_stream(self): + """Verifies re-opening logic on retry.""" + writer = self._make_one(write_handle=b"h1") writer._is_stream_open = True writer.write_obj_stream = self.mock_stream - writer.persisted_size = 0 + # Setup mock to allow close() call + self.mock_stream.is_stream_open = True - async def mock_aaow_open(metadata=None): - writer._is_stream_open = True + async def mock_open(metadata=None): writer.write_obj_stream = self.mock_stream - writer.persisted_size = 4 # Server says 4 bytes persisted - writer.write_handle = b"h_new" + writer._is_stream_open = True + writer.persisted_size = 5 + writer.write_handle = b"h_recovered" - with mock.patch.object(writer, "open", side_effect=mock_aaow_open) as mock_writer_open: + with mock.patch.object(writer, "open", side_effect=mock_open) as mock_writer_open: with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: - async def mock_execute(state, policy): factory = MockManager.call_args[0][1] - - # --- SIMULATE ATTEMPT 1 (Fail) --- + # Simulate Attempt 1 fail gen1 = factory([], state) try: await gen1.__anext__() except: pass - - # --- SIMULATE ATTEMPT 2 (Recovery) --- - # Logic should: close old stream, open new, rewind buffer, generate new requests + # Simulate Attempt 2 gen2 = factory([], state) - self.mock_stream.is_stream_open = True - self.mock_stream.recv.side_effect = [None] + self.mock_stream.recv.return_value = None async for _ in gen2: pass MockManager.return_value.execute.side_effect = mock_execute + await writer.append(b"0123456789") - await writer.append(b"1234567890") - - # Recovery Assertions self.mock_stream.close.assert_awaited() mock_writer_open.assert_awaited() - self.assertEqual(writer.write_handle, b"h_new") - self.assertEqual(writer.persisted_size, 4) + self.assertEqual(writer.persisted_size, 5) - async def test_append_metadata_injection(self): - """Verify providing metadata forces a restart (Attempt 1 logic).""" + async def test_append_unimplemented_string_raises(self): writer = self._make_one() - writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream - custom_meta = [("x-test", "true")] - - with mock.patch.object(writer, "open", new_callable=mock.AsyncMock) as mock_writer_open: - with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: - async def mock_execute(state, policy): - factory = MockManager.call_args[0][1] - gen = factory([], state) - self.mock_stream.recv.return_value = None - async for _ in gen: pass - - MockManager.return_value.execute.side_effect = mock_execute - await writer.append(b"data", metadata=custom_meta) - - self.mock_stream.close.assert_awaited() - mock_writer_open.assert_awaited_with(metadata=custom_meta) + with self.assertRaises(NotImplementedError): + await writer.append_from_string("test") # ------------------------------------------------------------------------- - # Flush, Close, Finalize Tests + # Flush, Close, Finalize # ------------------------------------------------------------------------- - async def test_flush(self): + async def test_flush_resets_counters(self): writer = self._make_one() writer._is_stream_open = True writer.write_obj_stream = self.mock_stream - writer.bytes_appended_since_last_flush = 50 + writer.bytes_appended_since_last_flush = 100 - self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( - persisted_size=100 - ) + self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse(persisted_size=200) - res = await writer.flush() + await writer.flush() - self.mock_stream.send.assert_awaited_with( - storage_type.BidiWriteObjectRequest(flush=True, state_lookup=True) - ) - self.assertEqual(res, 100) self.assertEqual(writer.bytes_appended_since_last_flush, 0) - - async def test_flush_not_open_raises(self): - writer = self._make_one() - with self.assertRaisesRegex(ValueError, "Stream is not open"): - await writer.flush() + self.assertEqual(writer.persisted_size, 200) async def test_simple_flush(self): writer = self._make_one() writer._is_stream_open = True writer.write_obj_stream = self.mock_stream + writer.bytes_appended_since_last_flush = 50 await writer.simple_flush() - self.mock_stream.send.assert_awaited_with( - storage_type.BidiWriteObjectRequest(flush=True) - ) - - async def test_simple_flush_not_open_raises(self): - writer = self._make_one() - with self.assertRaisesRegex(ValueError, "Stream is not open"): - await writer.simple_flush() + self.mock_stream.send.assert_awaited_with(storage_type.BidiWriteObjectRequest(flush=True)) + self.assertEqual(writer.bytes_appended_since_last_flush, 0) - async def test_close(self): + async def test_close_without_finalize(self): writer = self._make_one() writer._is_stream_open = True writer.write_obj_stream = self.mock_stream writer.persisted_size = 50 - res = await writer.close() + size = await writer.close() self.mock_stream.close.assert_awaited() self.assertFalse(writer._is_stream_open) - self.assertEqual(res, 50) + self.assertEqual(size, 50) - async def test_close_not_open_raises(self): - writer = self._make_one() - with self.assertRaisesRegex(ValueError, "Stream is not open"): - await writer.close() - - async def test_finalize(self): + async def test_finalize_lifecycle(self): writer = self._make_one() writer._is_stream_open = True writer.write_obj_stream = self.mock_stream + resource = storage_type.Object(size=999) - self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( - resource=resource - ) + self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse(resource=resource) res = await writer.finalize() - self.mock_stream.send.assert_awaited_with( - storage_type.BidiWriteObjectRequest(finish_write=True) - ) - self.assertEqual(writer.object_resource, resource) - self.assertEqual(writer.persisted_size, 999) self.assertEqual(res, resource) + self.assertEqual(writer.persisted_size, 999) + self.mock_stream.send.assert_awaited_with(storage_type.BidiWriteObjectRequest(finish_write=True)) + self.mock_stream.close.assert_awaited() + self.assertFalse(writer._is_stream_open) - async def test_finalize_not_open_raises(self): + async def test_close_with_finalize_on_close(self): writer = self._make_one() - with self.assertRaisesRegex(ValueError, "Stream is not open"): - await writer.finalize() + writer._is_stream_open = True + writer.finalize = AsyncMock() + + await writer.close(finalize_on_close=True) + writer.finalize.assert_awaited_once() # ------------------------------------------------------------------------- - # Append From File Tests + # Helper Integration Tests # ------------------------------------------------------------------------- - async def test_append_from_file(self): + async def test_append_from_file_integration(self): writer = self._make_one() writer._is_stream_open = True - writer.append = mock.AsyncMock() + writer.append = AsyncMock() - fp = io.BytesIO(b"1234567890") + fp = io.BytesIO(b"a" * 12) await writer.append_from_file(fp, block_size=4) self.assertEqual(writer.append.await_count, 3) + + async def test_methods_require_open_stream_raises(self): + writer = self._make_one() + methods = [ + writer.append(b"data"), + writer.flush(), + writer.simple_flush(), + writer.close(), + writer.finalize(), + writer.state_lookup() + ] + for coro in methods: + with self.assertRaisesRegex(ValueError, "Stream is not open"): + await coro diff --git a/tests/unit/asyncio/test_async_write_object_stream.py b/tests/unit/asyncio/test_async_write_object_stream.py index 4ce1526b8..283bf9b42 100644 --- a/tests/unit/asyncio/test_async_write_object_stream.py +++ b/tests/unit/asyncio/test_async_write_object_stream.py @@ -1,396 +1,195 @@ -# # Copyright 2025 Google LLC -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. - -# import pytest -# from unittest import mock - -# from unittest.mock import AsyncMock -# from google.cloud.storage._experimental.asyncio.async_write_object_stream import ( -# _AsyncWriteObjectStream, -# ) -# from google.cloud import _storage_v2 - -# BUCKET = "my-bucket" -# OBJECT = "my-object" -# GENERATION = 12345 -# WRITE_HANDLE = b"test-handle" - - -# @pytest.fixture -# def mock_client(): -# """Mock the async gRPC client.""" -# mock_transport = mock.AsyncMock() -# mock_transport.bidi_write_object = mock.sentinel.bidi_write_object -# mock_transport._wrapped_methods = { -# mock.sentinel.bidi_write_object: mock.sentinel.wrapped_bidi_write_object -# } - -# mock_gapic_client = mock.AsyncMock() -# mock_gapic_client._transport = mock_transport - -# client = mock.AsyncMock() -# client._client = mock_gapic_client -# return client - - -# async def instantiate_write_obj_stream(mock_client, mock_cls_async_bidi_rpc, open=True): -# """Helper to create an instance of _AsyncWriteObjectStream and open it by default.""" -# socket_like_rpc = AsyncMock() -# mock_cls_async_bidi_rpc.return_value = socket_like_rpc -# socket_like_rpc.open = AsyncMock() -# socket_like_rpc.send = AsyncMock() -# socket_like_rpc.close = AsyncMock() - -# mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) -# mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) -# mock_response.resource.generation = GENERATION -# mock_response.resource.size = 0 -# mock_response.write_handle = WRITE_HANDLE -# socket_like_rpc.recv = AsyncMock(return_value=mock_response) - -# write_obj_stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - -# if open: -# await write_obj_stream.open() - -# return write_obj_stream - - -# def test_async_write_object_stream_init(mock_client): -# """Test the constructor of _AsyncWriteObjectStream.""" -# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - -# assert stream.client == mock_client -# assert stream.bucket_name == BUCKET -# assert stream.object_name == OBJECT -# assert stream.generation_number is None -# assert stream.write_handle is None -# assert stream._full_bucket_name == f"projects/_/buckets/{BUCKET}" -# assert stream.rpc == mock.sentinel.wrapped_bidi_write_object -# assert stream.metadata == ( -# ("x-goog-request-params", f"bucket=projects/_/buckets/{BUCKET}"), -# ) -# assert stream.socket_like_rpc is None -# assert not stream._is_stream_open -# assert stream.first_bidi_write_req is None -# assert stream.persisted_size == 0 -# assert stream.object_resource is None - - -# def test_async_write_object_stream_init_with_generation_and_handle(mock_client): -# """Test the constructor with optional arguments.""" -# generation = 12345 -# write_handle = b"test-handle" -# stream = _AsyncWriteObjectStream( -# mock_client, -# BUCKET, -# OBJECT, -# generation_number=generation, -# write_handle=write_handle, -# ) - -# assert stream.generation_number == generation -# assert stream.write_handle == write_handle - - -# def test_async_write_object_stream_init_raises_value_error(): -# """Test that the constructor raises ValueError for missing arguments.""" -# with pytest.raises(ValueError, match="client must be provided"): -# _AsyncWriteObjectStream(None, BUCKET, OBJECT) - -# with pytest.raises(ValueError, match="bucket_name must be provided"): -# _AsyncWriteObjectStream(mock.Mock(), None, OBJECT) - -# with pytest.raises(ValueError, match="object_name must be provided"): -# _AsyncWriteObjectStream(mock.Mock(), BUCKET, None) - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_open_for_new_object(mock_async_bidi_rpc, mock_client): -# """Test opening a stream for a new object.""" -# # Arrange -# socket_like_rpc = mock.AsyncMock() -# mock_async_bidi_rpc.return_value = socket_like_rpc -# socket_like_rpc.open = mock.AsyncMock() - -# mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) -# mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) -# mock_response.resource.generation = GENERATION -# mock_response.resource.size = 0 -# mock_response.write_handle = WRITE_HANDLE -# socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) - -# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - -# # Act -# await stream.open() - -# # Assert -# assert stream._is_stream_open -# socket_like_rpc.open.assert_called_once() -# socket_like_rpc.recv.assert_called_once() -# assert stream.generation_number == GENERATION -# assert stream.write_handle == WRITE_HANDLE -# assert stream.persisted_size == 0 - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_open_for_existing_object(mock_async_bidi_rpc, mock_client): -# """Test opening a stream for an existing object.""" -# # Arrange -# socket_like_rpc = mock.AsyncMock() -# mock_async_bidi_rpc.return_value = socket_like_rpc -# socket_like_rpc.open = mock.AsyncMock() - -# mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) -# mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) -# mock_response.resource.size = 1024 -# mock_response.resource.generation = GENERATION -# mock_response.write_handle = WRITE_HANDLE -# socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) - -# stream = _AsyncWriteObjectStream( -# mock_client, BUCKET, OBJECT, generation_number=GENERATION -# ) - -# # Act -# await stream.open() - -# # Assert -# assert stream._is_stream_open -# socket_like_rpc.open.assert_called_once() -# socket_like_rpc.recv.assert_called_once() -# assert stream.generation_number == GENERATION -# assert stream.write_handle == WRITE_HANDLE -# assert stream.persisted_size == 1024 - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_open_when_already_open_raises_error(mock_async_bidi_rpc, mock_client): -# """Test that opening an already open stream raises a ValueError.""" -# # Arrange -# socket_like_rpc = mock.AsyncMock() -# mock_async_bidi_rpc.return_value = socket_like_rpc -# socket_like_rpc.open = mock.AsyncMock() - -# mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) -# mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) -# mock_response.resource.generation = GENERATION -# mock_response.resource.size = 0 -# mock_response.write_handle = WRITE_HANDLE -# socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) - -# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) -# await stream.open() - -# # Act & Assert -# with pytest.raises(ValueError, match="Stream is already open"): -# await stream.open() - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_open_raises_error_on_missing_object_resource( -# mock_async_bidi_rpc, mock_client -# ): -# """Test that open raises ValueError if object_resource is not in the response.""" -# socket_like_rpc = mock.AsyncMock() -# mock_async_bidi_rpc.return_value = socket_like_rpc - -# mock_reponse = mock.AsyncMock() -# type(mock_reponse).resource = mock.PropertyMock(return_value=None) -# socket_like_rpc.recv.return_value = mock_reponse - -# # Note: Don't use below code as unittest library automatically assigns an -# # `AsyncMock` object to an attribute, if not set. -# # socket_like_rpc.recv.return_value = mock.AsyncMock( -# # return_value=_storage_v2.BidiWriteObjectResponse(resource=None) -# # ) - -# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) -# with pytest.raises( -# ValueError, match="Failed to obtain object resource after opening the stream" -# ): -# await stream.open() - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_open_raises_error_on_missing_generation( -# mock_async_bidi_rpc, mock_client -# ): -# """Test that open raises ValueError if generation is not in the response.""" -# socket_like_rpc = mock.AsyncMock() -# mock_async_bidi_rpc.return_value = socket_like_rpc - -# # Configure the mock response object -# mock_response = mock.AsyncMock() -# type(mock_response.resource).generation = mock.PropertyMock(return_value=None) -# socket_like_rpc.recv.return_value = mock_response - -# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) -# with pytest.raises( -# ValueError, match="Failed to obtain object generation after opening the stream" -# ): -# await stream.open() - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_open_raises_error_on_missing_write_handle( -# mock_async_bidi_rpc, mock_client -# ): -# """Test that open raises ValueError if write_handle is not in the response.""" -# socket_like_rpc = mock.AsyncMock() -# mock_async_bidi_rpc.return_value = socket_like_rpc -# socket_like_rpc.recv = mock.AsyncMock( -# return_value=_storage_v2.BidiWriteObjectResponse( -# resource=_storage_v2.Object(generation=GENERATION), write_handle=None -# ) -# ) -# stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) -# with pytest.raises(ValueError, match="Failed to obtain write_handle"): -# await stream.open() - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_close(mock_cls_async_bidi_rpc, mock_client): -# """Test that close successfully closes the stream.""" -# # Arrange -# write_obj_stream = await instantiate_write_obj_stream( -# mock_client, mock_cls_async_bidi_rpc, open=True -# ) - -# # Act -# await write_obj_stream.close() - -# # Assert -# write_obj_stream.socket_like_rpc.close.assert_called_once() -# assert not write_obj_stream.is_stream_open - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_close_without_open_should_raise_error( -# mock_cls_async_bidi_rpc, mock_client -# ): -# """Test that closing a stream that is not open raises a ValueError.""" -# # Arrange -# write_obj_stream = await instantiate_write_obj_stream( -# mock_client, mock_cls_async_bidi_rpc, open=False -# ) - -# # Act & Assert -# with pytest.raises(ValueError, match="Stream is not open"): -# await write_obj_stream.close() - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_send(mock_cls_async_bidi_rpc, mock_client): -# """Test that send calls the underlying rpc's send method.""" -# # Arrange -# write_obj_stream = await instantiate_write_obj_stream( -# mock_client, mock_cls_async_bidi_rpc, open=True -# ) - -# # Act -# bidi_write_object_request = _storage_v2.BidiWriteObjectRequest() -# await write_obj_stream.send(bidi_write_object_request) - -# # Assert -# write_obj_stream.socket_like_rpc.send.assert_called_once_with( -# bidi_write_object_request -# ) - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_send_without_open_should_raise_error( -# mock_cls_async_bidi_rpc, mock_client -# ): -# """Test that sending on a stream that is not open raises a ValueError.""" -# # Arrange -# write_obj_stream = await instantiate_write_obj_stream( -# mock_client, mock_cls_async_bidi_rpc, open=False -# ) - -# # Act & Assert -# with pytest.raises(ValueError, match="Stream is not open"): -# await write_obj_stream.send(_storage_v2.BidiWriteObjectRequest()) - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_recv(mock_cls_async_bidi_rpc, mock_client): -# """Test that recv calls the underlying rpc's recv method.""" -# # Arrange -# write_obj_stream = await instantiate_write_obj_stream( -# mock_client, mock_cls_async_bidi_rpc, open=True -# ) -# bidi_write_object_response = _storage_v2.BidiWriteObjectResponse() -# write_obj_stream.socket_like_rpc.recv = AsyncMock( -# return_value=bidi_write_object_response -# ) - -# # Act -# response = await write_obj_stream.recv() - -# # Assert -# write_obj_stream.socket_like_rpc.recv.assert_called_once() -# assert response == bidi_write_object_response - - -# @pytest.mark.asyncio -# @mock.patch( -# "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -# ) -# async def test_recv_without_open_should_raise_error( -# mock_cls_async_bidi_rpc, mock_client -# ): -# """Test that receiving on a stream that is not open raises a ValueError.""" -# # Arrange -# write_obj_stream = await instantiate_write_obj_stream( -# mock_client, mock_cls_async_bidi_rpc, open=False -# ) - -# # Act & Assert -# with pytest.raises(ValueError, match="Stream is not open"): -# await write_obj_stream.recv() +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import unittest.mock as mock +from unittest.mock import AsyncMock, MagicMock +import pytest + +from google.cloud.storage._experimental.asyncio.async_write_object_stream import ( + _AsyncWriteObjectStream, +) +from google.cloud import _storage_v2 + +BUCKET = "my-bucket" +OBJECT = "my-object" +GENERATION = 12345 +WRITE_HANDLE = b"test-handle" +FULL_BUCKET_PATH = f"projects/_/buckets/{BUCKET}" + + +class TestAsyncWriteObjectStream(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.mock_client = MagicMock() + # Mocking transport internal structures + mock_transport = MagicMock() + mock_transport.bidi_write_object = mock.sentinel.bidi_write_object + mock_transport._wrapped_methods = { + mock.sentinel.bidi_write_object: mock.sentinel.wrapped_bidi_write_object + } + self.mock_client._client._transport = mock_transport + + # ------------------------------------------------------------------------- + # Initialization Tests + # ------------------------------------------------------------------------- + + def test_init_basic(self): + stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + self.assertEqual(stream.bucket_name, BUCKET) + self.assertEqual(stream.object_name, OBJECT) + self.assertEqual(stream._full_bucket_name, FULL_BUCKET_PATH) + self.assertEqual(stream.metadata, (("x-goog-request-params", f"bucket={FULL_BUCKET_PATH}"),)) + self.assertFalse(stream.is_stream_open) + + def test_init_raises_value_error(self): + with self.assertRaisesRegex(ValueError, "client must be provided"): + _AsyncWriteObjectStream(None, BUCKET, OBJECT) + with self.assertRaisesRegex(ValueError, "bucket_name must be provided"): + _AsyncWriteObjectStream(self.mock_client, None, OBJECT) + with self.assertRaisesRegex(ValueError, "object_name must be provided"): + _AsyncWriteObjectStream(self.mock_client, BUCKET, None) + + # ------------------------------------------------------------------------- + # Open Stream Tests + # ------------------------------------------------------------------------- + + @mock.patch("google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc") + async def test_open_new_object(self, mock_rpc_cls): + mock_rpc = mock_rpc_cls.return_value + mock_rpc.open = AsyncMock() + + # We don't use spec here to avoid descriptor issues with nested protos + mock_response = MagicMock() + mock_response.persisted_size = 0 + mock_response.resource.generation = GENERATION + mock_response.resource.size = 0 + mock_response.write_handle = WRITE_HANDLE + mock_rpc.recv = AsyncMock(return_value=mock_response) + + stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + await stream.open() + + # Check if BidiRpc was initialized with WriteObjectSpec + call_args = mock_rpc_cls.call_args + initial_request = call_args.kwargs["initial_request"] + self.assertIsNotNone(initial_request.write_object_spec) + self.assertEqual(initial_request.write_object_spec.resource.name, OBJECT) + self.assertTrue(initial_request.write_object_spec.appendable) + + self.assertTrue(stream.is_stream_open) + self.assertEqual(stream.write_handle, WRITE_HANDLE) + self.assertEqual(stream.generation_number, GENERATION) + + @mock.patch("google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc") + async def test_open_existing_object_with_token(self, mock_rpc_cls): + mock_rpc = mock_rpc_cls.return_value + mock_rpc.open = AsyncMock() + + # Ensure resource is None so persisted_size logic doesn't get overwritten by child mocks + mock_response = MagicMock() + mock_response.persisted_size = 1024 + mock_response.resource = None + mock_response.write_handle = WRITE_HANDLE + mock_rpc.recv = AsyncMock(return_value=mock_response) + + stream = _AsyncWriteObjectStream( + self.mock_client, BUCKET, OBJECT, + generation_number=GENERATION, + routing_token="token-123" + ) + await stream.open() + + # Verify AppendObjectSpec attributes + initial_request = mock_rpc_cls.call_args.kwargs["initial_request"] + self.assertIsNotNone(initial_request.append_object_spec) + self.assertEqual(initial_request.append_object_spec.generation, GENERATION) + self.assertEqual(initial_request.append_object_spec.routing_token, "token-123") + self.assertEqual(stream.persisted_size, 1024) + + @mock.patch("google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc") + async def test_open_metadata_merging(self, mock_rpc_cls): + mock_rpc = mock_rpc_cls.return_value + mock_rpc.open = AsyncMock() + mock_rpc.recv = AsyncMock(return_value=MagicMock(resource=None)) + + stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + extra_metadata = [("x-custom", "val"), ("x-goog-request-params", "extra=param")] + + await stream.open(metadata=extra_metadata) + + # Verify that metadata combined bucket and extra params + passed_metadata = mock_rpc_cls.call_args.kwargs["metadata"] + meta_dict = dict(passed_metadata) + self.assertEqual(meta_dict["x-custom"], "val") + # Params should be comma separated + params = meta_dict["x-goog-request-params"] + self.assertIn(f"bucket={FULL_BUCKET_PATH}", params) + self.assertIn("extra=param", params) + + async def test_open_already_open_raises(self): + stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + stream._is_stream_open = True + with self.assertRaisesRegex(ValueError, "already open"): + await stream.open() + + # ------------------------------------------------------------------------- + # Send & Recv & Close Tests + # ------------------------------------------------------------------------- + + @mock.patch("google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc") + async def test_send_and_recv_logic(self, mock_rpc_cls): + # Setup open stream + mock_rpc = mock_rpc_cls.return_value + mock_rpc.open = AsyncMock() + mock_rpc.send = AsyncMock() # Crucial: Must be AsyncMock + mock_rpc.recv = AsyncMock(return_value=MagicMock(resource=None)) + + stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + await stream.open() + + # Test Send + req = _storage_v2.BidiWriteObjectRequest(write_offset=0) + await stream.send(req) + mock_rpc.send.assert_awaited_with(req) + + # Test Recv with state update + mock_response = MagicMock() + mock_response.persisted_size = 5000 + mock_response.write_handle = b"new-handle" + mock_response.resource = None + mock_rpc.recv.return_value = mock_response + + res = await stream.recv() + self.assertEqual(res.persisted_size, 5000) + self.assertEqual(stream.persisted_size, 5000) + self.assertEqual(stream.write_handle, b"new-handle") + + async def test_close_success(self): + stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + stream._is_stream_open = True + stream.socket_like_rpc = AsyncMock() + stream.socket_like_rpc.close = AsyncMock() + + await stream.close() + stream.socket_like_rpc.close.assert_awaited_once() + self.assertFalse(stream.is_stream_open) + + async def test_methods_require_open_raises(self): + stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + with self.assertRaisesRegex(ValueError, "Stream is not open"): + await stream.send(MagicMock()) + with self.assertRaisesRegex(ValueError, "Stream is not open"): + await stream.recv() + with self.assertRaisesRegex(ValueError, "Stream is not open"): + await stream.close() From e8696d708a86ccc31c4e12a7f40eead7aea4d501 Mon Sep 17 00:00:00 2001 From: Pulkit Aggarwal Date: Tue, 13 Jan 2026 09:17:09 +0000 Subject: [PATCH 08/10] fix lint errors --- .../asyncio/async_appendable_object_writer.py | 2 +- .../_experimental/asyncio/retry/_helpers.py | 16 ++-- .../retry/writes_resumption_strategy.py | 1 - .../retry/test_writes_resumption_strategy.py | 73 ++++++-------- .../test_async_appendable_object_writer.py | 96 ++++++++++++++----- .../asyncio/test_async_write_object_stream.py | 37 +++++-- 6 files changed, 138 insertions(+), 87 deletions(-) diff --git a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py index 99cf391b5..6ec1e7315 100644 --- a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py +++ b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py @@ -74,7 +74,7 @@ def _is_write_retryable(exc): exceptions.ServiceUnavailable, exceptions.DeadlineExceeded, exceptions.TooManyRequests, - BidiWriteObjectRedirectedError + BidiWriteObjectRedirectedError, ), ): logger.info(f"Retryable write exception encountered: {exc}") diff --git a/google/cloud/storage/_experimental/asyncio/retry/_helpers.py b/google/cloud/storage/_experimental/asyncio/retry/_helpers.py index dcc830cc5..d9ad2462e 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/_helpers.py +++ b/google/cloud/storage/_experimental/asyncio/retry/_helpers.py @@ -18,7 +18,10 @@ from typing import Tuple, Optional from google.api_core import exceptions -from google.cloud._storage_v2.types import BidiReadObjectRedirectedError, BidiWriteObjectRedirectedError +from google.cloud._storage_v2.types import ( + BidiReadObjectRedirectedError, + BidiWriteObjectRedirectedError, +) from google.rpc import status_pb2 _BIDI_READ_REDIRECTED_TYPE_URL = ( @@ -86,6 +89,7 @@ def _handle_redirect( return routing_token, read_handle + def _extract_bidi_writes_redirect_proto(exc: Exception): grpc_error = None if isinstance(exc, exceptions.Aborted) and exc.errors: @@ -112,14 +116,10 @@ def _extract_bidi_writes_redirect_proto(exc: Exception): status_proto.ParseFromString(status_details_bin) for detail in status_proto.details: if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL: - redirect_proto = ( - BidiWriteObjectRedirectedError.deserialize( - detail.value - ) + redirect_proto = BidiWriteObjectRedirectedError.deserialize( + detail.value ) return redirect_proto except Exception: - logger.error( - "Error unpacking redirect details from gRPC error." - ) + logger.error("Error unpacking redirect details from gRPC error.") pass diff --git a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py index 4ad20662b..e9d48a6be 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py @@ -16,7 +16,6 @@ import google_crc32c from google.api_core import exceptions -from google.rpc import status_pb2 from google.cloud._storage_v2.types import storage as storage_type from google.cloud._storage_v2.types.storage import BidiWriteObjectRedirectedError from google.cloud.storage._experimental.asyncio.retry.base_strategy import ( diff --git a/tests/unit/asyncio/retry/test_writes_resumption_strategy.py b/tests/unit/asyncio/retry/test_writes_resumption_strategy.py index 556920eba..64efedeb0 100644 --- a/tests/unit/asyncio/retry/test_writes_resumption_strategy.py +++ b/tests/unit/asyncio/retry/test_writes_resumption_strategy.py @@ -112,9 +112,7 @@ def test_generate_requests_checksum_verification(self): strategy = self._make_one() chunk_data = b"test_data" mock_buffer = io.BytesIO(chunk_data) - write_state = _WriteState( - chunk_size=10, user_buffer=mock_buffer - ) + write_state = _WriteState(chunk_size=10, user_buffer=mock_buffer) state = {"write_state": write_state} requests = strategy.generate_requests(state) @@ -129,9 +127,7 @@ def test_generate_requests_flush_logic_exact_interval(self): mock_buffer = io.BytesIO(b"A" * 12) # 2 byte chunks, flush every 4 bytes write_state = _WriteState( - chunk_size=2, - user_buffer=mock_buffer, - flush_interval=4 + chunk_size=2, user_buffer=mock_buffer, flush_interval=4 ) state = {"write_state": write_state} @@ -153,9 +149,7 @@ def test_generate_requests_flush_logic_none_interval(self): strategy = self._make_one() mock_buffer = io.BytesIO(b"A" * 10) write_state = _WriteState( - chunk_size=2, - user_buffer=mock_buffer, - flush_interval=None + chunk_size=2, user_buffer=mock_buffer, flush_interval=None ) state = {"write_state": write_state} @@ -170,9 +164,7 @@ def test_generate_requests_flush_logic_data_less_than_interval(self): mock_buffer = io.BytesIO(b"A" * 5) # Flush every 10 bytes write_state = _WriteState( - chunk_size=2, - user_buffer=mock_buffer, - flush_interval=10 + chunk_size=2, user_buffer=mock_buffer, flush_interval=10 ) state = {"write_state": write_state} @@ -188,9 +180,7 @@ def test_generate_requests_honors_finalized_state(self): """If state is already finalized, no requests should be generated.""" strategy = self._make_one() mock_buffer = io.BytesIO(b"data") - write_state = _WriteState( - chunk_size=4, user_buffer=mock_buffer - ) + write_state = _WriteState(chunk_size=4, user_buffer=mock_buffer) write_state.is_finalized = True state = {"write_state": write_state} @@ -203,7 +193,7 @@ async def test_generate_requests_after_failure_and_recovery(self): Verify recovery and resumption flow (Integration of recover + generate). """ strategy = self._make_one() - mock_buffer = io.BytesIO(b"0123456789abcdef") # 16 bytes + mock_buffer = io.BytesIO(b"0123456789abcdef") # 16 bytes mock_spec = storage_type.AppendObjectSpec(object_="test-object") write_state = _WriteState(mock_spec, chunk_size=4, user_buffer=mock_buffer) state = {"write_state": write_state} @@ -246,9 +236,7 @@ async def test_generate_requests_after_failure_and_recovery(self): def test_update_state_from_response_all_fields(self): """Verify all fields from a BidiWriteObjectResponse update the state.""" strategy = self._make_one() - write_state = _WriteState( - chunk_size=4, user_buffer=io.BytesIO() - ) + write_state = _WriteState(chunk_size=4, user_buffer=io.BytesIO()) state = {"write_state": write_state} # 1. Update persisted_size @@ -275,9 +263,7 @@ def test_update_state_from_response_all_fields(self): def test_update_state_from_response_none(self): """Verify None response doesn't crash.""" strategy = self._make_one() - write_state = _WriteState( - chunk_size=4, user_buffer=io.BytesIO() - ) + write_state = _WriteState(chunk_size=4, user_buffer=io.BytesIO()) state = {"write_state": write_state} strategy.update_state_from_response(None, state) self.assertEqual(write_state.persisted_size, 0) @@ -291,9 +277,7 @@ async def test_recover_state_on_failure_rewind_logic(self): """Verify buffer seek and counter resets on generic failure (Non-redirect).""" strategy = self._make_one() mock_buffer = io.BytesIO(b"0123456789") - write_state = _WriteState( - chunk_size=2, user_buffer=mock_buffer - ) + write_state = _WriteState(chunk_size=2, user_buffer=mock_buffer) # Simulate progress: sent 8 bytes, but server only persisted 4 write_state.bytes_sent = 8 @@ -302,7 +286,9 @@ async def test_recover_state_on_failure_rewind_logic(self): mock_buffer.seek(8) # Simulate generic 503 error without trailers - await strategy.recover_state_on_failure(exceptions.ServiceUnavailable("busy"), {"write_state": write_state}) + await strategy.recover_state_on_failure( + exceptions.ServiceUnavailable("busy"), {"write_state": write_state} + ) # Buffer must be seeked back to 4 self.assertEqual(mock_buffer.tell(), 4) @@ -314,12 +300,12 @@ async def test_recover_state_on_failure_rewind_logic(self): async def test_recover_state_on_failure_direct_redirect(self): """Verify handling when the error is a BidiWriteObjectRedirectedError.""" strategy = self._make_one() - write_state = _WriteState( - chunk_size=4, user_buffer=io.BytesIO() - ) + write_state = _WriteState(chunk_size=4, user_buffer=io.BytesIO()) state = {"write_state": write_state} - redirect = BidiWriteObjectRedirectedError(routing_token="tok-1", write_handle=b"h-1") + redirect = BidiWriteObjectRedirectedError( + routing_token="tok-1", write_handle=b"h-1" + ) await strategy.recover_state_on_failure(redirect, state) @@ -330,9 +316,7 @@ async def test_recover_state_on_failure_direct_redirect(self): async def test_recover_state_on_failure_wrapped_redirect(self): """Verify handling when RedirectedError is inside Aborted.errors.""" strategy = self._make_one() - write_state = _WriteState( - chunk_size=4, user_buffer=io.BytesIO() - ) + write_state = _WriteState(chunk_size=4, user_buffer=io.BytesIO()) redirect = BidiWriteObjectRedirectedError(routing_token="tok-wrapped") # google-api-core Aborted often wraps multiple errors @@ -346,9 +330,7 @@ async def test_recover_state_on_failure_wrapped_redirect(self): async def test_recover_state_on_failure_trailer_metadata_redirect(self): """Verify complex parsing from 'grpc-status-details-bin' in trailers.""" strategy = self._make_one() - write_state = _WriteState( - chunk_size=4, user_buffer=io.BytesIO() - ) + write_state = _WriteState(chunk_size=4, user_buffer=io.BytesIO()) # 1. Setup Redirect Proto redirect_proto = BidiWriteObjectRedirectedError(routing_token="metadata-token") @@ -356,29 +338,34 @@ async def test_recover_state_on_failure_trailer_metadata_redirect(self): # 2. Setup Status Proto Detail status = status_pb2.Status() detail = status.details.add() - detail.type_url = "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + detail.type_url = ( + "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + ) # In a real environment, detail.value is the serialized proto detail.value = BidiWriteObjectRedirectedError.to_json(redirect_proto).encode() # 3. Create Mock Error with Trailers mock_error = mock.MagicMock(spec=exceptions.Aborted) - mock_error.errors = [] # No direct errors + mock_error.errors = [] # No direct errors mock_error.trailing_metadata.return_value = [ ("grpc-status-details-bin", status.SerializeToString()) ] # 4. Patch deserialize to handle the binary value - with mock.patch("google.cloud._storage_v2.types.storage.BidiWriteObjectRedirectedError.deserialize", return_value=redirect_proto): - await strategy.recover_state_on_failure(mock_error, {"write_state": write_state}) + with mock.patch( + "google.cloud._storage_v2.types.storage.BidiWriteObjectRedirectedError.deserialize", + return_value=redirect_proto, + ): + await strategy.recover_state_on_failure( + mock_error, {"write_state": write_state} + ) self.assertEqual(write_state.routing_token, "metadata-token") def test_write_state_initialization(self): """Verify WriteState starts with clean counters.""" buffer = io.BytesIO(b"test") - ws = _WriteState( - chunk_size=10, user_buffer=buffer, flush_interval=100 - ) + ws = _WriteState(chunk_size=10, user_buffer=buffer, flush_interval=100) self.assertEqual(ws.persisted_size, 0) self.assertEqual(ws.bytes_sent, 0) diff --git a/tests/unit/asyncio/test_async_appendable_object_writer.py b/tests/unit/asyncio/test_async_appendable_object_writer.py index 2b8680e6e..f75f15c37 100644 --- a/tests/unit/asyncio/test_async_appendable_object_writer.py +++ b/tests/unit/asyncio/test_async_appendable_object_writer.py @@ -60,7 +60,9 @@ def test_aborted_with_trailers(self): # Setup Status with Redirect Detail status = status_pb2.Status() detail = status.details.add() - detail.type_url = "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + detail.type_url = ( + "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + ) # Mock error with trailing_metadata method mock_grpc_error = MagicMock() @@ -83,7 +85,7 @@ def test_non_retryable_errors(self): self.assertFalse(_is_write_retryable(exceptions.NotFound("404"))) -class TestAsyncAppendableObjectWriter(unittest.IsolatedAsyncioTestCase): +class TestAsyncAppendableObjectWriter(unittest.TestCase): def setUp(self): self.mock_client = mock.AsyncMock() # Internal stream class patch @@ -109,9 +111,7 @@ def tearDown(self): self.mock_stream_patcher.stop() def _make_one(self, **kwargs): - return AsyncAppendableObjectWriter( - self.mock_client, BUCKET, OBJECT, **kwargs - ) + return AsyncAppendableObjectWriter(self.mock_client, BUCKET, OBJECT, **kwargs) # ------------------------------------------------------------------------- # Initialization & Configuration Tests @@ -131,14 +131,20 @@ def test_init_with_writer_options(self): def test_init_validation_chunk_size_raises(self): with self.assertRaises(exceptions.OutOfRange): - self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES - 1}) + self._make_one( + writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES - 1} + ) def test_init_validation_multiple_raises(self): with self.assertRaises(exceptions.OutOfRange): - self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES + 1}) + self._make_one( + writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES + 1} + ) def test_init_raises_if_crc32c_missing(self): - with mock.patch("google.cloud.storage._experimental.asyncio._utils.google_crc32c") as mock_crc: + with mock.patch( + "google.cloud.storage._experimental.asyncio._utils.google_crc32c" + ) as mock_crc: mock_crc.implementation = "python" with self.assertRaises(exceptions.FailedPrecondition): self._make_one() @@ -147,12 +153,15 @@ def test_init_raises_if_crc32c_missing(self): # Stream Lifecycle Tests # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_state_lookup_success(self): writer = self._make_one() writer._is_stream_open = True writer.write_obj_stream = self.mock_stream - self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse(persisted_size=100) + self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( + persisted_size=100 + ) size = await writer.state_lookup() @@ -160,11 +169,13 @@ async def test_state_lookup_success(self): self.assertEqual(size, 100) self.assertEqual(writer.persisted_size, 100) + @pytest.mark.asyncio async def test_state_lookup_raises_if_not_open(self): writer = self._make_one() with self.assertRaisesRegex(ValueError, "Stream is not open"): await writer.state_lookup() + @pytest.mark.asyncio async def test_open_success(self): writer = self._make_one() self.mock_stream.generation_number = 456 @@ -178,22 +189,27 @@ async def test_open_success(self): self.assertEqual(writer.write_handle, b"new-h") self.mock_stream.open.assert_awaited_once() + @pytest.mark.asyncio async def test_open_already_open_raises(self): writer = self._make_one() writer._is_stream_open = True with self.assertRaisesRegex(ValueError, "already open"): await writer.open() + @pytest.mark.asyncio def test_on_open_error_redirection(self): """Verify redirect info is extracted from helper.""" writer = self._make_one() redirect = BidiWriteObjectRedirectedError( routing_token="rt1", write_handle=storage_type.BidiWriteHandle(handle=b"h1"), - generation=777 + generation=777, ) - with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._extract_bidi_writes_redirect_proto", return_value=redirect): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._extract_bidi_writes_redirect_proto", + return_value=redirect, + ): writer._on_open_error(exceptions.Aborted("redirect")) self.assertEqual(writer._routing_token, "rt1") @@ -204,6 +220,7 @@ def test_on_open_error_redirection(self): # Append & Integration Tests # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_append_integration_basic(self): """Verify append orchestrates manager and drives the internal generator.""" writer = self._make_one() @@ -213,7 +230,10 @@ async def test_append_integration_basic(self): data = b"test-data" - with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager" + ) as MockManager: + async def mock_execute(state, policy): factory = MockManager.call_args[0][1] dummy_reqs = [storage_type.BidiWriteObjectRequest()] @@ -222,11 +242,12 @@ async def mock_execute(state, policy): self.mock_stream.recv.side_effect = [ storage_type.BidiWriteObjectResponse( persisted_size=len(data), - write_handle=storage_type.BidiWriteHandle(handle=b"h2") + write_handle=storage_type.BidiWriteHandle(handle=b"h2"), ), - None + None, ] - async for _ in gen: pass + async for _ in gen: + pass MockManager.return_value.execute.side_effect = mock_execute await writer.append(data) @@ -236,6 +257,7 @@ async def mock_execute(state, policy): self.assertTrue(sent_req.state_lookup) self.assertTrue(sent_req.flush) + @pytest.mark.asyncio async def test_append_recovery_reopens_stream(self): """Verifies re-opening logic on retry.""" writer = self._make_one(write_handle=b"h1") @@ -250,18 +272,26 @@ async def mock_open(metadata=None): writer.persisted_size = 5 writer.write_handle = b"h_recovered" - with mock.patch.object(writer, "open", side_effect=mock_open) as mock_writer_open: - with mock.patch("google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager") as MockManager: + with mock.patch.object( + writer, "open", side_effect=mock_open + ) as mock_writer_open: + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._BidiStreamRetryManager" + ) as MockManager: + async def mock_execute(state, policy): factory = MockManager.call_args[0][1] # Simulate Attempt 1 fail gen1 = factory([], state) - try: await gen1.__anext__() - except: pass + try: + await gen1.__anext__() + except Exception: + pass # Simulate Attempt 2 gen2 = factory([], state) self.mock_stream.recv.return_value = None - async for _ in gen2: pass + async for _ in gen2: + pass MockManager.return_value.execute.side_effect = mock_execute await writer.append(b"0123456789") @@ -270,6 +300,7 @@ async def mock_execute(state, policy): mock_writer_open.assert_awaited() self.assertEqual(writer.persisted_size, 5) + @pytest.mark.asyncio async def test_append_unimplemented_string_raises(self): writer = self._make_one() with self.assertRaises(NotImplementedError): @@ -279,19 +310,23 @@ async def test_append_unimplemented_string_raises(self): # Flush, Close, Finalize # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_flush_resets_counters(self): writer = self._make_one() writer._is_stream_open = True writer.write_obj_stream = self.mock_stream writer.bytes_appended_since_last_flush = 100 - self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse(persisted_size=200) + self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( + persisted_size=200 + ) await writer.flush() self.assertEqual(writer.bytes_appended_since_last_flush, 0) self.assertEqual(writer.persisted_size, 200) + @pytest.mark.asyncio async def test_simple_flush(self): writer = self._make_one() writer._is_stream_open = True @@ -300,9 +335,12 @@ async def test_simple_flush(self): await writer.simple_flush() - self.mock_stream.send.assert_awaited_with(storage_type.BidiWriteObjectRequest(flush=True)) + self.mock_stream.send.assert_awaited_with( + storage_type.BidiWriteObjectRequest(flush=True) + ) self.assertEqual(writer.bytes_appended_since_last_flush, 0) + @pytest.mark.asyncio async def test_close_without_finalize(self): writer = self._make_one() writer._is_stream_open = True @@ -315,22 +353,28 @@ async def test_close_without_finalize(self): self.assertFalse(writer._is_stream_open) self.assertEqual(size, 50) + @pytest.mark.asyncio async def test_finalize_lifecycle(self): writer = self._make_one() writer._is_stream_open = True writer.write_obj_stream = self.mock_stream resource = storage_type.Object(size=999) - self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse(resource=resource) + self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( + resource=resource + ) res = await writer.finalize() self.assertEqual(res, resource) self.assertEqual(writer.persisted_size, 999) - self.mock_stream.send.assert_awaited_with(storage_type.BidiWriteObjectRequest(finish_write=True)) + self.mock_stream.send.assert_awaited_with( + storage_type.BidiWriteObjectRequest(finish_write=True) + ) self.mock_stream.close.assert_awaited() self.assertFalse(writer._is_stream_open) + @pytest.mark.asyncio async def test_close_with_finalize_on_close(self): writer = self._make_one() writer._is_stream_open = True @@ -343,6 +387,7 @@ async def test_close_with_finalize_on_close(self): # Helper Integration Tests # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_append_from_file_integration(self): writer = self._make_one() writer._is_stream_open = True @@ -353,6 +398,7 @@ async def test_append_from_file_integration(self): self.assertEqual(writer.append.await_count, 3) + @pytest.mark.asyncio async def test_methods_require_open_stream_raises(self): writer = self._make_one() methods = [ @@ -361,7 +407,7 @@ async def test_methods_require_open_stream_raises(self): writer.simple_flush(), writer.close(), writer.finalize(), - writer.state_lookup() + writer.state_lookup(), ] for coro in methods: with self.assertRaisesRegex(ValueError, "Stream is not open"): diff --git a/tests/unit/asyncio/test_async_write_object_stream.py b/tests/unit/asyncio/test_async_write_object_stream.py index 283bf9b42..42611b0b8 100644 --- a/tests/unit/asyncio/test_async_write_object_stream.py +++ b/tests/unit/asyncio/test_async_write_object_stream.py @@ -29,7 +29,7 @@ FULL_BUCKET_PATH = f"projects/_/buckets/{BUCKET}" -class TestAsyncWriteObjectStream(unittest.IsolatedAsyncioTestCase): +class TestAsyncWriteObjectStream(unittest.TestCase): def setUp(self): self.mock_client = MagicMock() # Mocking transport internal structures @@ -49,7 +49,9 @@ def test_init_basic(self): self.assertEqual(stream.bucket_name, BUCKET) self.assertEqual(stream.object_name, OBJECT) self.assertEqual(stream._full_bucket_name, FULL_BUCKET_PATH) - self.assertEqual(stream.metadata, (("x-goog-request-params", f"bucket={FULL_BUCKET_PATH}"),)) + self.assertEqual( + stream.metadata, (("x-goog-request-params", f"bucket={FULL_BUCKET_PATH}"),) + ) self.assertFalse(stream.is_stream_open) def test_init_raises_value_error(self): @@ -64,7 +66,10 @@ def test_init_raises_value_error(self): # Open Stream Tests # ------------------------------------------------------------------------- - @mock.patch("google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc") + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" + ) + @pytest.mark.asyncio async def test_open_new_object(self, mock_rpc_cls): mock_rpc = mock_rpc_cls.return_value mock_rpc.open = AsyncMock() @@ -91,7 +96,10 @@ async def test_open_new_object(self, mock_rpc_cls): self.assertEqual(stream.write_handle, WRITE_HANDLE) self.assertEqual(stream.generation_number, GENERATION) - @mock.patch("google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc") + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" + ) + @pytest.mark.asyncio async def test_open_existing_object_with_token(self, mock_rpc_cls): mock_rpc = mock_rpc_cls.return_value mock_rpc.open = AsyncMock() @@ -104,9 +112,11 @@ async def test_open_existing_object_with_token(self, mock_rpc_cls): mock_rpc.recv = AsyncMock(return_value=mock_response) stream = _AsyncWriteObjectStream( - self.mock_client, BUCKET, OBJECT, + self.mock_client, + BUCKET, + OBJECT, generation_number=GENERATION, - routing_token="token-123" + routing_token="token-123", ) await stream.open() @@ -117,7 +127,10 @@ async def test_open_existing_object_with_token(self, mock_rpc_cls): self.assertEqual(initial_request.append_object_spec.routing_token, "token-123") self.assertEqual(stream.persisted_size, 1024) - @mock.patch("google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc") + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" + ) + @pytest.mark.asyncio async def test_open_metadata_merging(self, mock_rpc_cls): mock_rpc = mock_rpc_cls.return_value mock_rpc.open = AsyncMock() @@ -137,6 +150,7 @@ async def test_open_metadata_merging(self, mock_rpc_cls): self.assertIn(f"bucket={FULL_BUCKET_PATH}", params) self.assertIn("extra=param", params) + @pytest.mark.asyncio async def test_open_already_open_raises(self): stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) stream._is_stream_open = True @@ -147,12 +161,15 @@ async def test_open_already_open_raises(self): # Send & Recv & Close Tests # ------------------------------------------------------------------------- - @mock.patch("google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc") + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" + ) + @pytest.mark.asyncio async def test_send_and_recv_logic(self, mock_rpc_cls): # Setup open stream mock_rpc = mock_rpc_cls.return_value mock_rpc.open = AsyncMock() - mock_rpc.send = AsyncMock() # Crucial: Must be AsyncMock + mock_rpc.send = AsyncMock() # Crucial: Must be AsyncMock mock_rpc.recv = AsyncMock(return_value=MagicMock(resource=None)) stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) @@ -175,6 +192,7 @@ async def test_send_and_recv_logic(self, mock_rpc_cls): self.assertEqual(stream.persisted_size, 5000) self.assertEqual(stream.write_handle, b"new-handle") + @pytest.mark.asyncio async def test_close_success(self): stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) stream._is_stream_open = True @@ -185,6 +203,7 @@ async def test_close_success(self): stream.socket_like_rpc.close.assert_awaited_once() self.assertFalse(stream.is_stream_open) + @pytest.mark.asyncio async def test_methods_require_open_raises(self): stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) with self.assertRaisesRegex(ValueError, "Stream is not open"): From 8c81047445f46e281c9c76db94ec6d2aa7f8aa47 Mon Sep 17 00:00:00 2001 From: Pulkit Aggarwal Date: Tue, 13 Jan 2026 10:53:17 +0000 Subject: [PATCH 09/10] refactor unit tests --- .../retry/writes_resumption_strategy.py | 27 +- .../retry/test_writes_resumption_strategy.py | 188 +++++------- .../test_async_appendable_object_writer.py | 290 +++++++++--------- tests/unit/asyncio/test_async_grpc_client.py | 8 +- .../asyncio/test_async_write_object_stream.py | 125 ++++---- 5 files changed, 313 insertions(+), 325 deletions(-) diff --git a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py index e9d48a6be..bef72ce64 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py @@ -135,24 +135,19 @@ async def recover_state_on_failure( """ write_state: _WriteState = state["write_state"] - grpc_error = None - if isinstance(error, exceptions.Aborted) and error.errors: - grpc_error = error.errors[0] - - if grpc_error: - # Extract routing token and potentially a new write handle for redirection. - if isinstance(grpc_error, BidiWriteObjectRedirectedError): - write_state.routing_token = grpc_error.routing_token - if grpc_error.write_handle: - write_state.write_handle = grpc_error.write_handle - return + redirect_proto = None + if isinstance(error, BidiWriteObjectRedirectedError): + redirect_proto = error + else: redirect_proto = _extract_bidi_writes_redirect_proto(error) - if redirect_proto: - if redirect_proto.routing_token: - write_state.routing_token = redirect_proto.routing_token - if redirect_proto.write_handle: - write_state.write_handle = redirect_proto.write_handle + + # Extract routing token and potentially a new write handle for redirection. + if redirect_proto: + if redirect_proto.routing_token: + write_state.routing_token = redirect_proto.routing_token + if redirect_proto.write_handle: + write_state.write_handle = redirect_proto.write_handle # We must assume any data sent beyond 'persisted_size' was lost. # Reset the user buffer to the last known good byte confirmed by the server. diff --git a/tests/unit/asyncio/retry/test_writes_resumption_strategy.py b/tests/unit/asyncio/retry/test_writes_resumption_strategy.py index 64efedeb0..138035e3b 100644 --- a/tests/unit/asyncio/retry/test_writes_resumption_strategy.py +++ b/tests/unit/asyncio/retry/test_writes_resumption_strategy.py @@ -13,7 +13,6 @@ # limitations under the License. import io -import unittest import unittest.mock as mock from datetime import datetime @@ -30,17 +29,21 @@ from google.cloud._storage_v2.types.storage import BidiWriteObjectRedirectedError -class TestWriteResumptionStrategy(unittest.TestCase): - def _make_one(self): - return _WriteResumptionStrategy() +@pytest.fixture +def strategy(): + """Fixture to provide a WriteResumptionStrategy instance.""" + return _WriteResumptionStrategy() + + +class TestWriteResumptionStrategy: + """Test suite for WriteResumptionStrategy.""" # ------------------------------------------------------------------------- # Tests for generate_requests # ------------------------------------------------------------------------- - def test_generate_requests_initial_chunking(self): + def test_generate_requests_initial_chunking(self, strategy): """Verify initial data generation starts at offset 0 and chunks correctly.""" - strategy = self._make_one() mock_buffer = io.BytesIO(b"abcdefghij") write_state = _WriteState(chunk_size=3, user_buffer=mock_buffer) state = {"write_state": write_state} @@ -48,30 +51,29 @@ def test_generate_requests_initial_chunking(self): requests = strategy.generate_requests(state) # Expected: 4 requests (3, 3, 3, 1) - self.assertEqual(len(requests), 4) + assert len(requests) == 4 # Verify Request 1 - self.assertEqual(requests[0].write_offset, 0) - self.assertEqual(requests[0].checksummed_data.content, b"abc") + assert requests[0].write_offset == 0 + assert requests[0].checksummed_data.content == b"abc" # Verify Request 2 - self.assertEqual(requests[1].write_offset, 3) - self.assertEqual(requests[1].checksummed_data.content, b"def") + assert requests[1].write_offset == 3 + assert requests[1].checksummed_data.content == b"def" # Verify Request 3 - self.assertEqual(requests[2].write_offset, 6) - self.assertEqual(requests[2].checksummed_data.content, b"ghi") + assert requests[2].write_offset == 6 + assert requests[2].checksummed_data.content == b"ghi" # Verify Request 4 - self.assertEqual(requests[3].write_offset, 9) - self.assertEqual(requests[3].checksummed_data.content, b"j") + assert requests[3].write_offset == 9 + assert requests[3].checksummed_data.content == b"j" - def test_generate_requests_resumption(self): + def test_generate_requests_resumption(self, strategy): """ Verify request generation when resuming. The strategy should generate chunks starting from the current 'bytes_sent'. """ - strategy = self._make_one() mock_buffer = io.BytesIO(b"0123456789") write_state = _WriteState(chunk_size=4, user_buffer=mock_buffer) @@ -86,30 +88,28 @@ def test_generate_requests_resumption(self): requests = strategy.generate_requests(state) # Since 4 bytes are done, we expect remaining 6 bytes: [4 bytes, 2 bytes] - self.assertEqual(len(requests), 2) + assert len(requests) == 2 # Check first generated request starts at offset 4 - self.assertEqual(requests[0].write_offset, 4) - self.assertEqual(requests[0].checksummed_data.content, b"4567") + assert requests[0].write_offset == 4 + assert requests[0].checksummed_data.content == b"4567" # Check second generated request starts at offset 8 - self.assertEqual(requests[1].write_offset, 8) - self.assertEqual(requests[1].checksummed_data.content, b"89") + assert requests[1].write_offset == 8 + assert requests[1].checksummed_data.content == b"89" - def test_generate_requests_empty_file(self): + def test_generate_requests_empty_file(self, strategy): """Verify request sequence for an empty file.""" - strategy = self._make_one() mock_buffer = io.BytesIO(b"") write_state = _WriteState(chunk_size=4, user_buffer=mock_buffer) state = {"write_state": write_state} requests = strategy.generate_requests(state) - self.assertEqual(len(requests), 0) + assert len(requests) == 0 - def test_generate_requests_checksum_verification(self): + def test_generate_requests_checksum_verification(self, strategy): """Verify CRC32C is calculated correctly for each chunk.""" - strategy = self._make_one() chunk_data = b"test_data" mock_buffer = io.BytesIO(chunk_data) write_state = _WriteState(chunk_size=10, user_buffer=mock_buffer) @@ -119,11 +119,10 @@ def test_generate_requests_checksum_verification(self): expected_crc = google_crc32c.Checksum(chunk_data).digest() expected_int = int.from_bytes(expected_crc, "big") - self.assertEqual(requests[0].checksummed_data.crc32c, expected_int) + assert requests[0].checksummed_data.crc32c == expected_int - def test_generate_requests_flush_logic_exact_interval(self): + def test_generate_requests_flush_logic_exact_interval(self, strategy): """Verify the flush bit is set exactly when the interval is reached.""" - strategy = self._make_one() mock_buffer = io.BytesIO(b"A" * 12) # 2 byte chunks, flush every 4 bytes write_state = _WriteState( @@ -134,19 +133,22 @@ def test_generate_requests_flush_logic_exact_interval(self): requests = strategy.generate_requests(state) # Request index 1 (4 bytes total) should have flush=True - self.assertFalse(requests[0].flush) - self.assertTrue(requests[1].flush) + assert requests[0].flush == False + assert requests[1].flush == True + + # Request index 2 (8 bytes total) should have flush=True + assert requests[2].flush == False + assert requests[3].flush == True - # Request index 3 (8 bytes total) should have flush=True - self.assertFalse(requests[2].flush) - self.assertTrue(requests[3].flush) + # Request index 3 (12 bytes total) should have flush=True + assert requests[4].flush == False + assert requests[5].flush == True # Verify counter reset in state - self.assertEqual(write_state.bytes_since_last_flush, 0) + assert write_state.bytes_since_last_flush == 0 - def test_generate_requests_flush_logic_none_interval(self): + def test_generate_requests_flush_logic_none_interval(self, strategy): """Verify flush is never set if interval is None.""" - strategy = self._make_one() mock_buffer = io.BytesIO(b"A" * 10) write_state = _WriteState( chunk_size=2, user_buffer=mock_buffer, flush_interval=None @@ -156,11 +158,10 @@ def test_generate_requests_flush_logic_none_interval(self): requests = strategy.generate_requests(state) for req in requests: - self.assertFalse(req.flush) + assert req.flush == False - def test_generate_requests_flush_logic_data_less_than_interval(self): + def test_generate_requests_flush_logic_data_less_than_interval(self, strategy): """Verify flush is not set if data sent is less than interval.""" - strategy = self._make_one() mock_buffer = io.BytesIO(b"A" * 5) # Flush every 10 bytes write_state = _WriteState( @@ -172,30 +173,27 @@ def test_generate_requests_flush_logic_data_less_than_interval(self): # Total 5 bytes < 10 bytes interval for req in requests: - self.assertFalse(req.flush) + assert req.flush == False - self.assertEqual(write_state.bytes_since_last_flush, 5) + assert write_state.bytes_since_last_flush == 5 - def test_generate_requests_honors_finalized_state(self): + def test_generate_requests_honors_finalized_state(self, strategy): """If state is already finalized, no requests should be generated.""" - strategy = self._make_one() mock_buffer = io.BytesIO(b"data") write_state = _WriteState(chunk_size=4, user_buffer=mock_buffer) write_state.is_finalized = True state = {"write_state": write_state} requests = strategy.generate_requests(state) - self.assertEqual(len(requests), 0) + assert len(requests) == 0 @pytest.mark.asyncio - async def test_generate_requests_after_failure_and_recovery(self): + async def test_generate_requests_after_failure_and_recovery(self, strategy): """ Verify recovery and resumption flow (Integration of recover + generate). """ - strategy = self._make_one() mock_buffer = io.BytesIO(b"0123456789abcdef") # 16 bytes - mock_spec = storage_type.AppendObjectSpec(object_="test-object") - write_state = _WriteState(mock_spec, chunk_size=4, user_buffer=mock_buffer) + write_state = _WriteState(chunk_size=4, user_buffer=mock_buffer) state = {"write_state": write_state} # Simulate initial progress: sent 8 bytes @@ -215,27 +213,26 @@ async def test_generate_requests_after_failure_and_recovery(self): # Assertions after recovery # 1. Buffer should rewind to persisted_size (4) - self.assertEqual(mock_buffer.tell(), 4) + assert mock_buffer.tell() == 4 # 2. bytes_sent should track persisted_size (4) - self.assertEqual(write_state.bytes_sent, 4) + assert write_state.bytes_sent == 4 requests = strategy.generate_requests(state) # Remaining data from offset 4 to 16 (12 bytes total) # Chunks: [4-8], [8-12], [12-16] - self.assertEqual(len(requests), 3) + assert len(requests) == 3 # Verify resumption offset - self.assertEqual(requests[0].write_offset, 4) - self.assertEqual(requests[0].checksummed_data.content, b"4567") + assert requests[0].write_offset == 4 + assert requests[0].checksummed_data.content == b"4567" # ------------------------------------------------------------------------- # Tests for update_state_from_response # ------------------------------------------------------------------------- - def test_update_state_from_response_all_fields(self): + def test_update_state_from_response_all_fields(self, strategy): """Verify all fields from a BidiWriteObjectResponse update the state.""" - strategy = self._make_one() write_state = _WriteState(chunk_size=4, user_buffer=io.BytesIO()) state = {"write_state": write_state} @@ -243,39 +240,37 @@ def test_update_state_from_response_all_fields(self): strategy.update_state_from_response( storage_type.BidiWriteObjectResponse(persisted_size=123), state ) - self.assertEqual(write_state.persisted_size, 123) + assert write_state.persisted_size == 123 # 2. Update write_handle handle = storage_type.BidiWriteHandle(handle=b"new-handle") strategy.update_state_from_response( storage_type.BidiWriteObjectResponse(write_handle=handle), state ) - self.assertEqual(write_state.write_handle, handle) + assert write_state.write_handle == handle # 3. Update from Resource (finalization) resource = storage_type.Object(size=1000, finalize_time=datetime.now()) strategy.update_state_from_response( storage_type.BidiWriteObjectResponse(resource=resource), state ) - self.assertEqual(write_state.persisted_size, 1000) - self.assertTrue(write_state.is_finalized) + assert write_state.persisted_size == 1000 + assert write_state.is_finalized - def test_update_state_from_response_none(self): + def test_update_state_from_response_none(self, strategy): """Verify None response doesn't crash.""" - strategy = self._make_one() write_state = _WriteState(chunk_size=4, user_buffer=io.BytesIO()) state = {"write_state": write_state} strategy.update_state_from_response(None, state) - self.assertEqual(write_state.persisted_size, 0) + assert write_state.persisted_size == 0 # ------------------------------------------------------------------------- # Tests for recover_state_on_failure # ------------------------------------------------------------------------- @pytest.mark.asyncio - async def test_recover_state_on_failure_rewind_logic(self): + async def test_recover_state_on_failure_rewind_logic(self, strategy): """Verify buffer seek and counter resets on generic failure (Non-redirect).""" - strategy = self._make_one() mock_buffer = io.BytesIO(b"0123456789") write_state = _WriteState(chunk_size=2, user_buffer=mock_buffer) @@ -291,31 +286,29 @@ async def test_recover_state_on_failure_rewind_logic(self): ) # Buffer must be seeked back to 4 - self.assertEqual(mock_buffer.tell(), 4) - self.assertEqual(write_state.bytes_sent, 4) + assert mock_buffer.tell() == 4 + assert write_state.bytes_sent == 4 # Flush counter must be reset to avoid incorrect firing after resume - self.assertEqual(write_state.bytes_since_last_flush, 0) + assert write_state.bytes_since_last_flush == 0 @pytest.mark.asyncio - async def test_recover_state_on_failure_direct_redirect(self): + async def test_recover_state_on_failure_direct_redirect(self, strategy): """Verify handling when the error is a BidiWriteObjectRedirectedError.""" - strategy = self._make_one() write_state = _WriteState(chunk_size=4, user_buffer=io.BytesIO()) state = {"write_state": write_state} redirect = BidiWriteObjectRedirectedError( - routing_token="tok-1", write_handle=b"h-1" + routing_token="tok-1", write_handle=storage_type.BidiWriteHandle(handle=b"h-1"), ) await strategy.recover_state_on_failure(redirect, state) - self.assertEqual(write_state.routing_token, "tok-1") - self.assertEqual(write_state.write_handle, b"h-1") + assert write_state.routing_token == "tok-1" + assert write_state.write_handle.handle == b"h-1" @pytest.mark.asyncio - async def test_recover_state_on_failure_wrapped_redirect(self): + async def test_recover_state_on_failure_wrapped_redirect(self, strategy): """Verify handling when RedirectedError is inside Aborted.errors.""" - strategy = self._make_one() write_state = _WriteState(chunk_size=4, user_buffer=io.BytesIO()) redirect = BidiWriteObjectRedirectedError(routing_token="tok-wrapped") @@ -324,51 +317,38 @@ async def test_recover_state_on_failure_wrapped_redirect(self): await strategy.recover_state_on_failure(error, {"write_state": write_state}) - self.assertEqual(write_state.routing_token, "tok-wrapped") + assert write_state.routing_token == "tok-wrapped" @pytest.mark.asyncio - async def test_recover_state_on_failure_trailer_metadata_redirect(self): + async def test_recover_state_on_failure_trailer_metadata_redirect(self, strategy): """Verify complex parsing from 'grpc-status-details-bin' in trailers.""" - strategy = self._make_one() write_state = _WriteState(chunk_size=4, user_buffer=io.BytesIO()) - # 1. Setup Redirect Proto redirect_proto = BidiWriteObjectRedirectedError(routing_token="metadata-token") - - # 2. Setup Status Proto Detail status = status_pb2.Status() detail = status.details.add() - detail.type_url = ( - "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" - ) - # In a real environment, detail.value is the serialized proto - detail.value = BidiWriteObjectRedirectedError.to_json(redirect_proto).encode() + detail.type_url = "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + detail.value = BidiWriteObjectRedirectedError.serialize(redirect_proto) - # 3. Create Mock Error with Trailers - mock_error = mock.MagicMock(spec=exceptions.Aborted) - mock_error.errors = [] # No direct errors + # FIX: No spec= here, because Aborted doesn't have trailing_metadata in its base definition + mock_error = mock.MagicMock() + mock_error.errors = [] mock_error.trailing_metadata.return_value = [ ("grpc-status-details-bin", status.SerializeToString()) ] - # 4. Patch deserialize to handle the binary value - with mock.patch( - "google.cloud._storage_v2.types.storage.BidiWriteObjectRedirectedError.deserialize", - return_value=redirect_proto, - ): - await strategy.recover_state_on_failure( - mock_error, {"write_state": write_state} - ) + with mock.patch("google.cloud.storage._experimental.asyncio.retry.writes_resumption_strategy._extract_bidi_writes_redirect_proto", return_value=redirect_proto): + await strategy.recover_state_on_failure(mock_error, {"write_state": write_state}) - self.assertEqual(write_state.routing_token, "metadata-token") + assert write_state.routing_token == "metadata-token" def test_write_state_initialization(self): """Verify WriteState starts with clean counters.""" buffer = io.BytesIO(b"test") ws = _WriteState(chunk_size=10, user_buffer=buffer, flush_interval=100) - self.assertEqual(ws.persisted_size, 0) - self.assertEqual(ws.bytes_sent, 0) - self.assertEqual(ws.bytes_since_last_flush, 0) - self.assertEqual(ws.flush_interval, 100) - self.assertFalse(ws.is_finalized) + assert ws.persisted_size == 0 + assert ws.bytes_sent == 0 + assert ws.bytes_since_last_flush == 0 + assert ws.flush_interval == 100 + assert not ws.is_finalized diff --git a/tests/unit/asyncio/test_async_appendable_object_writer.py b/tests/unit/asyncio/test_async_appendable_object_writer.py index f75f15c37..9f8fc3c1e 100644 --- a/tests/unit/asyncio/test_async_appendable_object_writer.py +++ b/tests/unit/asyncio/test_async_appendable_object_writer.py @@ -38,25 +38,25 @@ EIGHT_MIB = 8 * 1024 * 1024 -class TestIsWriteRetryable(unittest.TestCase): +class TestIsWriteRetryable: """Exhaustive tests for retry predicate logic.""" - def test_standard_transient_errors(self): + def test_standard_transient_errors(self, mock_appendable_writer): for exc in [ exceptions.InternalServerError("500"), exceptions.ServiceUnavailable("503"), exceptions.DeadlineExceeded("timeout"), exceptions.TooManyRequests("429"), ]: - self.assertTrue(_is_write_retryable(exc)) + assert _is_write_retryable(exc) - def test_aborted_with_redirect_proto(self): + def test_aborted_with_redirect_proto(self, mock_appendable_writer): # Direct redirect error wrapped in Aborted redirect = BidiWriteObjectRedirectedError(routing_token="token") exc = exceptions.Aborted("aborted", errors=[redirect]) - self.assertTrue(_is_write_retryable(exc)) + assert _is_write_retryable(exc) - def test_aborted_with_trailers(self): + def test_aborted_with_trailers(self, mock_appendable_writer): # Setup Status with Redirect Detail status = status_pb2.Status() detail = status.details.add() @@ -72,134 +72,144 @@ def test_aborted_with_trailers(self): # Aborted wraps the grpc error exc = exceptions.Aborted("aborted", errors=[mock_grpc_error]) - self.assertTrue(_is_write_retryable(exc)) + assert _is_write_retryable(exc) - def test_aborted_without_metadata(self): + def test_aborted_without_metadata(self, mock_appendable_writer): mock_grpc_error = MagicMock() mock_grpc_error.trailing_metadata.return_value = [] exc = exceptions.Aborted("bare aborted", errors=[mock_grpc_error]) - self.assertFalse(_is_write_retryable(exc)) - - def test_non_retryable_errors(self): - self.assertFalse(_is_write_retryable(exceptions.BadRequest("400"))) - self.assertFalse(_is_write_retryable(exceptions.NotFound("404"))) - - -class TestAsyncAppendableObjectWriter(unittest.TestCase): - def setUp(self): - self.mock_client = mock.AsyncMock() - # Internal stream class patch - self.mock_stream_patcher = mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" - ) - self.mock_stream_cls = self.mock_stream_patcher.start() - self.mock_stream = self.mock_stream_cls.return_value - - # Configure all async methods explicitly - self.mock_stream.open = AsyncMock() - self.mock_stream.close = AsyncMock() - self.mock_stream.send = AsyncMock() - self.mock_stream.recv = AsyncMock() - - # Default mock properties - self.mock_stream.is_stream_open = False - self.mock_stream.persisted_size = 0 - self.mock_stream.generation_number = GENERATION - self.mock_stream.write_handle = WRITE_HANDLE - - def tearDown(self): - self.mock_stream_patcher.stop() - - def _make_one(self, **kwargs): - return AsyncAppendableObjectWriter(self.mock_client, BUCKET, OBJECT, **kwargs) + assert not _is_write_retryable(exc) + + def test_non_retryable_errors(self, mock_appendable_writer): + assert not _is_write_retryable(exceptions.BadRequest("400")) + assert not _is_write_retryable(exceptions.NotFound("404")) + + +@pytest.fixture +def mock_appendable_writer(): + """Fixture to provide a mock AsyncAppendableObjectWriter setup.""" + mock_client = mock.AsyncMock() + # Internal stream class patch + stream_patcher = mock.patch( + "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" + ) + mock_stream_cls = stream_patcher.start() + mock_stream = mock_stream_cls.return_value + + # Configure all async methods explicitly + mock_stream.open = AsyncMock() + mock_stream.close = AsyncMock() + mock_stream.send = AsyncMock() + mock_stream.recv = AsyncMock() + + # Default mock properties + mock_stream.is_stream_open = False + mock_stream.persisted_size = 0 + mock_stream.generation_number = GENERATION + mock_stream.write_handle = WRITE_HANDLE + + yield { + "mock_client": mock_client, + "mock_stream_cls": mock_stream_cls, + "mock_stream": mock_stream, + "stream_patcher": stream_patcher, + } + + stream_patcher.stop() + + +class TestAsyncAppendableObjectWriter: + def _make_one(self, mock_client, **kwargs): + return AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT, **kwargs) # ------------------------------------------------------------------------- # Initialization & Configuration Tests # ------------------------------------------------------------------------- - def test_init_defaults(self): - writer = self._make_one() - self.assertEqual(writer.bucket_name, BUCKET) - self.assertEqual(writer.object_name, OBJECT) - self.assertIsNone(writer.persisted_size) - self.assertEqual(writer.bytes_appended_since_last_flush, 0) - self.assertEqual(writer.flush_interval, _DEFAULT_FLUSH_INTERVAL_BYTES) + def test_init_defaults(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + assert writer.bucket_name == BUCKET + assert writer.object_name == OBJECT + assert writer.persisted_size is None + assert writer.bytes_appended_since_last_flush == 0 + assert writer.flush_interval == _DEFAULT_FLUSH_INTERVAL_BYTES - def test_init_with_writer_options(self): - writer = self._make_one(writer_options={"FLUSH_INTERVAL_BYTES": EIGHT_MIB}) - self.assertEqual(writer.flush_interval, EIGHT_MIB) + def test_init_with_writer_options(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"], writer_options={"FLUSH_INTERVAL_BYTES": EIGHT_MIB}) + assert writer.flush_interval == EIGHT_MIB - def test_init_validation_chunk_size_raises(self): - with self.assertRaises(exceptions.OutOfRange): + def test_init_validation_chunk_size_raises(self, mock_appendable_writer): + with pytest.raises(exceptions.OutOfRange): self._make_one( + mock_appendable_writer["mock_client"], writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES - 1} ) - def test_init_validation_multiple_raises(self): - with self.assertRaises(exceptions.OutOfRange): + def test_init_validation_multiple_raises(self, mock_appendable_writer): + with pytest.raises(exceptions.OutOfRange): self._make_one( + mock_appendable_writer["mock_client"], writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES + 1} ) - def test_init_raises_if_crc32c_missing(self): + def test_init_raises_if_crc32c_missing(self, mock_appendable_writer): with mock.patch( "google.cloud.storage._experimental.asyncio._utils.google_crc32c" ) as mock_crc: mock_crc.implementation = "python" - with self.assertRaises(exceptions.FailedPrecondition): - self._make_one() + with pytest.raises(exceptions.FailedPrecondition): + self._make_one(mock_appendable_writer["mock_client"]) # ------------------------------------------------------------------------- # Stream Lifecycle Tests # ------------------------------------------------------------------------- @pytest.mark.asyncio - async def test_state_lookup_success(self): - writer = self._make_one() + async def test_state_lookup_success(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream + writer.write_obj_stream = mock_appendable_writer['mock_stream'] - self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( + mock_appendable_writer['mock_stream'].recv.return_value = storage_type.BidiWriteObjectResponse( persisted_size=100 ) size = await writer.state_lookup() - self.mock_stream.send.assert_awaited_once() - self.assertEqual(size, 100) - self.assertEqual(writer.persisted_size, 100) + mock_appendable_writer['mock_stream'].send.assert_awaited_once() + assert size == 100 + assert writer.persisted_size == 100 @pytest.mark.asyncio - async def test_state_lookup_raises_if_not_open(self): - writer = self._make_one() - with self.assertRaisesRegex(ValueError, "Stream is not open"): + async def test_state_lookup_raises_if_not_open(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) + with pytest.raises(ValueError, match="Stream is not open"): await writer.state_lookup() @pytest.mark.asyncio - async def test_open_success(self): - writer = self._make_one() - self.mock_stream.generation_number = 456 - self.mock_stream.write_handle = b"new-h" - self.mock_stream.persisted_size = 0 + async def test_open_success(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) + mock_appendable_writer['mock_stream'].generation_number = 456 + mock_appendable_writer['mock_stream'].write_handle = b"new-h" + mock_appendable_writer['mock_stream'].persisted_size = 0 await writer.open() - self.assertTrue(writer._is_stream_open) - self.assertEqual(writer.generation, 456) - self.assertEqual(writer.write_handle, b"new-h") - self.mock_stream.open.assert_awaited_once() + assert writer._is_stream_open + assert writer.generation == 456 + assert writer.write_handle == b"new-h" + mock_appendable_writer['mock_stream'].open.assert_awaited_once() @pytest.mark.asyncio - async def test_open_already_open_raises(self): - writer = self._make_one() + async def test_open_already_open_raises(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) writer._is_stream_open = True - with self.assertRaisesRegex(ValueError, "already open"): + with pytest.raises(ValueError, match="already open"): await writer.open() - @pytest.mark.asyncio - def test_on_open_error_redirection(self): + def test_on_open_error_redirection(self, mock_appendable_writer): """Verify redirect info is extracted from helper.""" - writer = self._make_one() + writer = self._make_one(mock_appendable_writer['mock_client']) redirect = BidiWriteObjectRedirectedError( routing_token="rt1", write_handle=storage_type.BidiWriteHandle(handle=b"h1"), @@ -212,20 +222,20 @@ def test_on_open_error_redirection(self): ): writer._on_open_error(exceptions.Aborted("redirect")) - self.assertEqual(writer._routing_token, "rt1") - self.assertEqual(writer.write_handle.handle, b"h1") - self.assertEqual(writer.generation, 777) + assert writer._routing_token == "rt1" + assert writer.write_handle.handle == b"h1" + assert writer.generation == 777 # ------------------------------------------------------------------------- # Append & Integration Tests # ------------------------------------------------------------------------- @pytest.mark.asyncio - async def test_append_integration_basic(self): + async def test_append_integration_basic(self, mock_appendable_writer): """Verify append orchestrates manager and drives the internal generator.""" - writer = self._make_one() + writer = self._make_one(mock_appendable_writer['mock_client']) writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream + writer.write_obj_stream = mock_appendable_writer['mock_stream'] writer.persisted_size = 0 data = b"test-data" @@ -239,7 +249,7 @@ async def mock_execute(state, policy): dummy_reqs = [storage_type.BidiWriteObjectRequest()] gen = factory(dummy_reqs, state) - self.mock_stream.recv.side_effect = [ + mock_appendable_writer['mock_stream'].recv.side_effect = [ storage_type.BidiWriteObjectResponse( persisted_size=len(data), write_handle=storage_type.BidiWriteHandle(handle=b"h2"), @@ -252,22 +262,22 @@ async def mock_execute(state, policy): MockManager.return_value.execute.side_effect = mock_execute await writer.append(data) - self.assertEqual(writer.persisted_size, len(data)) - sent_req = self.mock_stream.send.call_args[0][0] - self.assertTrue(sent_req.state_lookup) - self.assertTrue(sent_req.flush) + assert writer.persisted_size == len(data) + sent_req = mock_appendable_writer['mock_stream'].send.call_args[0][0] + assert sent_req.state_lookup + assert sent_req.flush @pytest.mark.asyncio - async def test_append_recovery_reopens_stream(self): + async def test_append_recovery_reopens_stream(self, mock_appendable_writer): """Verifies re-opening logic on retry.""" - writer = self._make_one(write_handle=b"h1") + writer = self._make_one(mock_appendable_writer['mock_client'], write_handle=b"h1") writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream + writer.write_obj_stream = mock_appendable_writer['mock_stream'] # Setup mock to allow close() call - self.mock_stream.is_stream_open = True + mock_appendable_writer['mock_stream'].is_stream_open = True async def mock_open(metadata=None): - writer.write_obj_stream = self.mock_stream + writer.write_obj_stream = mock_appendable_writer['mock_stream'] writer._is_stream_open = True writer.persisted_size = 5 writer.write_handle = b"h_recovered" @@ -289,21 +299,21 @@ async def mock_execute(state, policy): pass # Simulate Attempt 2 gen2 = factory([], state) - self.mock_stream.recv.return_value = None + mock_appendable_writer['mock_stream'].recv.return_value = None async for _ in gen2: pass MockManager.return_value.execute.side_effect = mock_execute await writer.append(b"0123456789") - self.mock_stream.close.assert_awaited() + mock_appendable_writer['mock_stream'].close.assert_awaited() mock_writer_open.assert_awaited() - self.assertEqual(writer.persisted_size, 5) + assert writer.persisted_size == 5 @pytest.mark.asyncio - async def test_append_unimplemented_string_raises(self): - writer = self._make_one() - with self.assertRaises(NotImplementedError): + async def test_append_unimplemented_string_raises(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) + with pytest.raises(NotImplementedError): await writer.append_from_string("test") # ------------------------------------------------------------------------- @@ -311,72 +321,72 @@ async def test_append_unimplemented_string_raises(self): # ------------------------------------------------------------------------- @pytest.mark.asyncio - async def test_flush_resets_counters(self): - writer = self._make_one() + async def test_flush_resets_counters(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream + writer.write_obj_stream = mock_appendable_writer['mock_stream'] writer.bytes_appended_since_last_flush = 100 - self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( + mock_appendable_writer['mock_stream'].recv.return_value = storage_type.BidiWriteObjectResponse( persisted_size=200 ) await writer.flush() - self.assertEqual(writer.bytes_appended_since_last_flush, 0) - self.assertEqual(writer.persisted_size, 200) + assert writer.bytes_appended_since_last_flush == 0 + assert writer.persisted_size == 200 @pytest.mark.asyncio - async def test_simple_flush(self): - writer = self._make_one() + async def test_simple_flush(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream + writer.write_obj_stream = mock_appendable_writer['mock_stream'] writer.bytes_appended_since_last_flush = 50 await writer.simple_flush() - self.mock_stream.send.assert_awaited_with( + mock_appendable_writer['mock_stream'].send.assert_awaited_with( storage_type.BidiWriteObjectRequest(flush=True) ) - self.assertEqual(writer.bytes_appended_since_last_flush, 0) + assert writer.bytes_appended_since_last_flush == 0 @pytest.mark.asyncio - async def test_close_without_finalize(self): - writer = self._make_one() + async def test_close_without_finalize(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream + writer.write_obj_stream = mock_appendable_writer['mock_stream'] writer.persisted_size = 50 size = await writer.close() - self.mock_stream.close.assert_awaited() - self.assertFalse(writer._is_stream_open) - self.assertEqual(size, 50) + mock_appendable_writer['mock_stream'].close.assert_awaited() + assert not writer._is_stream_open + assert size == 50 @pytest.mark.asyncio - async def test_finalize_lifecycle(self): - writer = self._make_one() + async def test_finalize_lifecycle(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) writer._is_stream_open = True - writer.write_obj_stream = self.mock_stream + writer.write_obj_stream = mock_appendable_writer['mock_stream'] resource = storage_type.Object(size=999) - self.mock_stream.recv.return_value = storage_type.BidiWriteObjectResponse( + mock_appendable_writer['mock_stream'].recv.return_value = storage_type.BidiWriteObjectResponse( resource=resource ) res = await writer.finalize() - self.assertEqual(res, resource) - self.assertEqual(writer.persisted_size, 999) - self.mock_stream.send.assert_awaited_with( + assert res == resource + assert writer.persisted_size == 999 + mock_appendable_writer['mock_stream'].send.assert_awaited_with( storage_type.BidiWriteObjectRequest(finish_write=True) ) - self.mock_stream.close.assert_awaited() - self.assertFalse(writer._is_stream_open) + mock_appendable_writer['mock_stream'].close.assert_awaited() + assert not writer._is_stream_open @pytest.mark.asyncio - async def test_close_with_finalize_on_close(self): - writer = self._make_one() + async def test_close_with_finalize_on_close(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) writer._is_stream_open = True writer.finalize = AsyncMock() @@ -388,19 +398,19 @@ async def test_close_with_finalize_on_close(self): # ------------------------------------------------------------------------- @pytest.mark.asyncio - async def test_append_from_file_integration(self): - writer = self._make_one() + async def test_append_from_file_integration(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) writer._is_stream_open = True writer.append = AsyncMock() fp = io.BytesIO(b"a" * 12) await writer.append_from_file(fp, block_size=4) - self.assertEqual(writer.append.await_count, 3) + assert writer.append.await_count == 3 @pytest.mark.asyncio - async def test_methods_require_open_stream_raises(self): - writer = self._make_one() + async def test_methods_require_open_stream_raises(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer['mock_client']) methods = [ writer.append(b"data"), writer.flush(), @@ -410,5 +420,5 @@ async def test_methods_require_open_stream_raises(self): writer.state_lookup(), ] for coro in methods: - with self.assertRaisesRegex(ValueError, "Stream is not open"): + with pytest.raises(ValueError, match="Stream is not open"): await coro diff --git a/tests/unit/asyncio/test_async_grpc_client.py b/tests/unit/asyncio/test_async_grpc_client.py index eb06ab938..7321f99ad 100644 --- a/tests/unit/asyncio/test_async_grpc_client.py +++ b/tests/unit/asyncio/test_async_grpc_client.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +import pytest from unittest import mock from google.auth import credentials as auth_credentials from google.auth.credentials import AnonymousCredentials @@ -24,7 +24,7 @@ def _make_credentials(spec=None): return mock.Mock(spec=spec) -class TestAsyncGrpcClient(unittest.TestCase): +class TestAsyncGrpcClient: @mock.patch("google.cloud._storage_v2.StorageAsyncClient") def test_constructor_default_options(self, mock_async_storage_client): from google.cloud.storage._experimental.asyncio import async_grpc_client @@ -111,7 +111,7 @@ def test_grpc_client_property(self, mock_grpc_gapic_client): client_info=mock_client_info, client_options=mock_client_options, ) - self.assertIs(retrieved_client, mock_grpc_gapic_client.return_value) + assert retrieved_client is mock_grpc_gapic_client.return_value @mock.patch("google.cloud._storage_v2.StorageAsyncClient") def test_grpc_client_with_anon_creds(self, mock_grpc_gapic_client): @@ -131,7 +131,7 @@ def test_grpc_client_with_anon_creds(self, mock_grpc_gapic_client): retrieved_client = client.grpc_client # Assert - self.assertIs(retrieved_client, mock_grpc_gapic_client.return_value) + assert retrieved_client is mock_grpc_gapic_client.return_value mock_transport_cls.create_channel.assert_called_once_with( attempt_direct_path=True, credentials=anonymous_creds diff --git a/tests/unit/asyncio/test_async_write_object_stream.py b/tests/unit/asyncio/test_async_write_object_stream.py index 42611b0b8..aec0b3794 100644 --- a/tests/unit/asyncio/test_async_write_object_stream.py +++ b/tests/unit/asyncio/test_async_write_object_stream.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest import unittest.mock as mock from unittest.mock import AsyncMock, MagicMock import pytest @@ -29,38 +28,42 @@ FULL_BUCKET_PATH = f"projects/_/buckets/{BUCKET}" -class TestAsyncWriteObjectStream(unittest.TestCase): - def setUp(self): - self.mock_client = MagicMock() - # Mocking transport internal structures - mock_transport = MagicMock() - mock_transport.bidi_write_object = mock.sentinel.bidi_write_object - mock_transport._wrapped_methods = { - mock.sentinel.bidi_write_object: mock.sentinel.wrapped_bidi_write_object - } - self.mock_client._client._transport = mock_transport +@pytest.fixture +def mock_client(): + """Fixture to provide a mock gRPC client.""" + client = MagicMock() + # Mocking transport internal structures + mock_transport = MagicMock() + mock_transport.bidi_write_object = mock.sentinel.bidi_write_object + mock_transport._wrapped_methods = { + mock.sentinel.bidi_write_object: mock.sentinel.wrapped_bidi_write_object + } + client._client._transport = mock_transport + return client + + +class TestAsyncWriteObjectStream: + """Test suite for AsyncWriteObjectStream.""" # ------------------------------------------------------------------------- # Initialization Tests # ------------------------------------------------------------------------- - def test_init_basic(self): - stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) - self.assertEqual(stream.bucket_name, BUCKET) - self.assertEqual(stream.object_name, OBJECT) - self.assertEqual(stream._full_bucket_name, FULL_BUCKET_PATH) - self.assertEqual( - stream.metadata, (("x-goog-request-params", f"bucket={FULL_BUCKET_PATH}"),) - ) - self.assertFalse(stream.is_stream_open) + def test_init_basic(self, mock_client): + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + assert stream.bucket_name == BUCKET + assert stream.object_name == OBJECT + assert stream._full_bucket_name == FULL_BUCKET_PATH + assert stream.metadata == (("x-goog-request-params", f"bucket={FULL_BUCKET_PATH}"),) + assert not stream.is_stream_open - def test_init_raises_value_error(self): - with self.assertRaisesRegex(ValueError, "client must be provided"): + def test_init_raises_value_error(self, mock_client): + with pytest.raises(ValueError, match="client must be provided"): _AsyncWriteObjectStream(None, BUCKET, OBJECT) - with self.assertRaisesRegex(ValueError, "bucket_name must be provided"): - _AsyncWriteObjectStream(self.mock_client, None, OBJECT) - with self.assertRaisesRegex(ValueError, "object_name must be provided"): - _AsyncWriteObjectStream(self.mock_client, BUCKET, None) + with pytest.raises(ValueError, match="bucket_name must be provided"): + _AsyncWriteObjectStream(mock_client, None, OBJECT) + with pytest.raises(ValueError, match="object_name must be provided"): + _AsyncWriteObjectStream(mock_client, BUCKET, None) # ------------------------------------------------------------------------- # Open Stream Tests @@ -70,7 +73,7 @@ def test_init_raises_value_error(self): "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" ) @pytest.mark.asyncio - async def test_open_new_object(self, mock_rpc_cls): + async def test_open_new_object(self, mock_rpc_cls, mock_client): mock_rpc = mock_rpc_cls.return_value mock_rpc.open = AsyncMock() @@ -82,25 +85,25 @@ async def test_open_new_object(self, mock_rpc_cls): mock_response.write_handle = WRITE_HANDLE mock_rpc.recv = AsyncMock(return_value=mock_response) - stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) await stream.open() # Check if BidiRpc was initialized with WriteObjectSpec call_args = mock_rpc_cls.call_args initial_request = call_args.kwargs["initial_request"] - self.assertIsNotNone(initial_request.write_object_spec) - self.assertEqual(initial_request.write_object_spec.resource.name, OBJECT) - self.assertTrue(initial_request.write_object_spec.appendable) + assert initial_request.write_object_spec is not None + assert initial_request.write_object_spec.resource.name == OBJECT + assert initial_request.write_object_spec.appendable - self.assertTrue(stream.is_stream_open) - self.assertEqual(stream.write_handle, WRITE_HANDLE) - self.assertEqual(stream.generation_number, GENERATION) + assert stream.is_stream_open + assert stream.write_handle == WRITE_HANDLE + assert stream.generation_number == GENERATION @mock.patch( "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" ) @pytest.mark.asyncio - async def test_open_existing_object_with_token(self, mock_rpc_cls): + async def test_open_existing_object_with_token(self, mock_rpc_cls, mock_client): mock_rpc = mock_rpc_cls.return_value mock_rpc.open = AsyncMock() @@ -112,7 +115,7 @@ async def test_open_existing_object_with_token(self, mock_rpc_cls): mock_rpc.recv = AsyncMock(return_value=mock_response) stream = _AsyncWriteObjectStream( - self.mock_client, + mock_client, BUCKET, OBJECT, generation_number=GENERATION, @@ -122,21 +125,21 @@ async def test_open_existing_object_with_token(self, mock_rpc_cls): # Verify AppendObjectSpec attributes initial_request = mock_rpc_cls.call_args.kwargs["initial_request"] - self.assertIsNotNone(initial_request.append_object_spec) - self.assertEqual(initial_request.append_object_spec.generation, GENERATION) - self.assertEqual(initial_request.append_object_spec.routing_token, "token-123") - self.assertEqual(stream.persisted_size, 1024) + assert initial_request.append_object_spec is not None + assert initial_request.append_object_spec.generation == GENERATION + assert initial_request.append_object_spec.routing_token == "token-123" + assert stream.persisted_size == 1024 @mock.patch( "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" ) @pytest.mark.asyncio - async def test_open_metadata_merging(self, mock_rpc_cls): + async def test_open_metadata_merging(self, mock_rpc_cls, mock_client): mock_rpc = mock_rpc_cls.return_value mock_rpc.open = AsyncMock() mock_rpc.recv = AsyncMock(return_value=MagicMock(resource=None)) - stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) extra_metadata = [("x-custom", "val"), ("x-goog-request-params", "extra=param")] await stream.open(metadata=extra_metadata) @@ -144,17 +147,17 @@ async def test_open_metadata_merging(self, mock_rpc_cls): # Verify that metadata combined bucket and extra params passed_metadata = mock_rpc_cls.call_args.kwargs["metadata"] meta_dict = dict(passed_metadata) - self.assertEqual(meta_dict["x-custom"], "val") + assert meta_dict["x-custom"] == "val" # Params should be comma separated params = meta_dict["x-goog-request-params"] - self.assertIn(f"bucket={FULL_BUCKET_PATH}", params) - self.assertIn("extra=param", params) + assert f"bucket={FULL_BUCKET_PATH}" in params + assert "extra=param" in params @pytest.mark.asyncio - async def test_open_already_open_raises(self): - stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + async def test_open_already_open_raises(self, mock_client): + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) stream._is_stream_open = True - with self.assertRaisesRegex(ValueError, "already open"): + with pytest.raises(ValueError, match="already open"): await stream.open() # ------------------------------------------------------------------------- @@ -165,14 +168,14 @@ async def test_open_already_open_raises(self): "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" ) @pytest.mark.asyncio - async def test_send_and_recv_logic(self, mock_rpc_cls): + async def test_send_and_recv_logic(self, mock_rpc_cls, mock_client): # Setup open stream mock_rpc = mock_rpc_cls.return_value mock_rpc.open = AsyncMock() mock_rpc.send = AsyncMock() # Crucial: Must be AsyncMock mock_rpc.recv = AsyncMock(return_value=MagicMock(resource=None)) - stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) await stream.open() # Test Send @@ -188,27 +191,27 @@ async def test_send_and_recv_logic(self, mock_rpc_cls): mock_rpc.recv.return_value = mock_response res = await stream.recv() - self.assertEqual(res.persisted_size, 5000) - self.assertEqual(stream.persisted_size, 5000) - self.assertEqual(stream.write_handle, b"new-handle") + assert res.persisted_size == 5000 + assert stream.persisted_size == 5000 + assert stream.write_handle == b"new-handle" @pytest.mark.asyncio - async def test_close_success(self): - stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) + async def test_close_success(self, mock_client): + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) stream._is_stream_open = True stream.socket_like_rpc = AsyncMock() stream.socket_like_rpc.close = AsyncMock() await stream.close() stream.socket_like_rpc.close.assert_awaited_once() - self.assertFalse(stream.is_stream_open) + assert not stream.is_stream_open @pytest.mark.asyncio - async def test_methods_require_open_raises(self): - stream = _AsyncWriteObjectStream(self.mock_client, BUCKET, OBJECT) - with self.assertRaisesRegex(ValueError, "Stream is not open"): + async def test_methods_require_open_raises(self, mock_client): + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + with pytest.raises(ValueError, match="Stream is not open"): await stream.send(MagicMock()) - with self.assertRaisesRegex(ValueError, "Stream is not open"): + with pytest.raises(ValueError, match="Stream is not open"): await stream.recv() - with self.assertRaisesRegex(ValueError, "Stream is not open"): + with pytest.raises(ValueError, match="Stream is not open"): await stream.close() From ad68e91a3e5bb86bdfb018aeae1520d888edac4a Mon Sep 17 00:00:00 2001 From: Pulkit Aggarwal Date: Wed, 14 Jan 2026 06:33:17 +0000 Subject: [PATCH 10/10] fix lint errors --- .../retry/writes_resumption_strategy.py | 1 - .../retry/test_writes_resumption_strategy.py | 32 ++++-- .../test_async_appendable_object_writer.py | 104 +++++++++--------- tests/unit/asyncio/test_async_grpc_client.py | 1 - .../asyncio/test_async_write_object_stream.py | 4 +- 5 files changed, 77 insertions(+), 65 deletions(-) diff --git a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py index bef72ce64..1995fa637 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py @@ -15,7 +15,6 @@ from typing import Any, Dict, IO, List, Optional, Union import google_crc32c -from google.api_core import exceptions from google.cloud._storage_v2.types import storage as storage_type from google.cloud._storage_v2.types.storage import BidiWriteObjectRedirectedError from google.cloud.storage._experimental.asyncio.retry.base_strategy import ( diff --git a/tests/unit/asyncio/retry/test_writes_resumption_strategy.py b/tests/unit/asyncio/retry/test_writes_resumption_strategy.py index 138035e3b..ea3c25f8f 100644 --- a/tests/unit/asyncio/retry/test_writes_resumption_strategy.py +++ b/tests/unit/asyncio/retry/test_writes_resumption_strategy.py @@ -133,16 +133,16 @@ def test_generate_requests_flush_logic_exact_interval(self, strategy): requests = strategy.generate_requests(state) # Request index 1 (4 bytes total) should have flush=True - assert requests[0].flush == False - assert requests[1].flush == True + assert requests[0].flush is False + assert requests[1].flush is True # Request index 2 (8 bytes total) should have flush=True - assert requests[2].flush == False - assert requests[3].flush == True + assert requests[2].flush is False + assert requests[3].flush is True # Request index 3 (12 bytes total) should have flush=True - assert requests[4].flush == False - assert requests[5].flush == True + assert requests[4].flush is False + assert requests[5].flush is True # Verify counter reset in state assert write_state.bytes_since_last_flush == 0 @@ -158,7 +158,7 @@ def test_generate_requests_flush_logic_none_interval(self, strategy): requests = strategy.generate_requests(state) for req in requests: - assert req.flush == False + assert req.flush is False def test_generate_requests_flush_logic_data_less_than_interval(self, strategy): """Verify flush is not set if data sent is less than interval.""" @@ -173,7 +173,7 @@ def test_generate_requests_flush_logic_data_less_than_interval(self, strategy): # Total 5 bytes < 10 bytes interval for req in requests: - assert req.flush == False + assert req.flush is False assert write_state.bytes_since_last_flush == 5 @@ -298,7 +298,8 @@ async def test_recover_state_on_failure_direct_redirect(self, strategy): state = {"write_state": write_state} redirect = BidiWriteObjectRedirectedError( - routing_token="tok-1", write_handle=storage_type.BidiWriteHandle(handle=b"h-1"), + routing_token="tok-1", + write_handle=storage_type.BidiWriteHandle(handle=b"h-1"), ) await strategy.recover_state_on_failure(redirect, state) @@ -327,7 +328,9 @@ async def test_recover_state_on_failure_trailer_metadata_redirect(self, strategy redirect_proto = BidiWriteObjectRedirectedError(routing_token="metadata-token") status = status_pb2.Status() detail = status.details.add() - detail.type_url = "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + detail.type_url = ( + "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + ) detail.value = BidiWriteObjectRedirectedError.serialize(redirect_proto) # FIX: No spec= here, because Aborted doesn't have trailing_metadata in its base definition @@ -337,8 +340,13 @@ async def test_recover_state_on_failure_trailer_metadata_redirect(self, strategy ("grpc-status-details-bin", status.SerializeToString()) ] - with mock.patch("google.cloud.storage._experimental.asyncio.retry.writes_resumption_strategy._extract_bidi_writes_redirect_proto", return_value=redirect_proto): - await strategy.recover_state_on_failure(mock_error, {"write_state": write_state}) + with mock.patch( + "google.cloud.storage._experimental.asyncio.retry.writes_resumption_strategy._extract_bidi_writes_redirect_proto", + return_value=redirect_proto, + ): + await strategy.recover_state_on_failure( + mock_error, {"write_state": write_state} + ) assert write_state.routing_token == "metadata-token" diff --git a/tests/unit/asyncio/test_async_appendable_object_writer.py b/tests/unit/asyncio/test_async_appendable_object_writer.py index 9f8fc3c1e..9c4858309 100644 --- a/tests/unit/asyncio/test_async_appendable_object_writer.py +++ b/tests/unit/asyncio/test_async_appendable_object_writer.py @@ -13,7 +13,6 @@ # limitations under the License. import io -import unittest import unittest.mock as mock from unittest.mock import AsyncMock, MagicMock import pytest @@ -135,21 +134,24 @@ def test_init_defaults(self, mock_appendable_writer): assert writer.flush_interval == _DEFAULT_FLUSH_INTERVAL_BYTES def test_init_with_writer_options(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer["mock_client"], writer_options={"FLUSH_INTERVAL_BYTES": EIGHT_MIB}) + writer = self._make_one( + mock_appendable_writer["mock_client"], + writer_options={"FLUSH_INTERVAL_BYTES": EIGHT_MIB}, + ) assert writer.flush_interval == EIGHT_MIB def test_init_validation_chunk_size_raises(self, mock_appendable_writer): with pytest.raises(exceptions.OutOfRange): self._make_one( mock_appendable_writer["mock_client"], - writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES - 1} + writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES - 1}, ) def test_init_validation_multiple_raises(self, mock_appendable_writer): with pytest.raises(exceptions.OutOfRange): self._make_one( mock_appendable_writer["mock_client"], - writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES + 1} + writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES + 1}, ) def test_init_raises_if_crc32c_missing(self, mock_appendable_writer): @@ -166,50 +168,50 @@ def test_init_raises_if_crc32c_missing(self, mock_appendable_writer): @pytest.mark.asyncio async def test_state_lookup_success(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) writer._is_stream_open = True - writer.write_obj_stream = mock_appendable_writer['mock_stream'] + writer.write_obj_stream = mock_appendable_writer["mock_stream"] - mock_appendable_writer['mock_stream'].recv.return_value = storage_type.BidiWriteObjectResponse( - persisted_size=100 - ) + mock_appendable_writer[ + "mock_stream" + ].recv.return_value = storage_type.BidiWriteObjectResponse(persisted_size=100) size = await writer.state_lookup() - mock_appendable_writer['mock_stream'].send.assert_awaited_once() + mock_appendable_writer["mock_stream"].send.assert_awaited_once() assert size == 100 assert writer.persisted_size == 100 @pytest.mark.asyncio async def test_state_lookup_raises_if_not_open(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) with pytest.raises(ValueError, match="Stream is not open"): await writer.state_lookup() @pytest.mark.asyncio async def test_open_success(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) - mock_appendable_writer['mock_stream'].generation_number = 456 - mock_appendable_writer['mock_stream'].write_handle = b"new-h" - mock_appendable_writer['mock_stream'].persisted_size = 0 + writer = self._make_one(mock_appendable_writer["mock_client"]) + mock_appendable_writer["mock_stream"].generation_number = 456 + mock_appendable_writer["mock_stream"].write_handle = b"new-h" + mock_appendable_writer["mock_stream"].persisted_size = 0 await writer.open() assert writer._is_stream_open assert writer.generation == 456 assert writer.write_handle == b"new-h" - mock_appendable_writer['mock_stream'].open.assert_awaited_once() + mock_appendable_writer["mock_stream"].open.assert_awaited_once() @pytest.mark.asyncio async def test_open_already_open_raises(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) writer._is_stream_open = True with pytest.raises(ValueError, match="already open"): await writer.open() def test_on_open_error_redirection(self, mock_appendable_writer): """Verify redirect info is extracted from helper.""" - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) redirect = BidiWriteObjectRedirectedError( routing_token="rt1", write_handle=storage_type.BidiWriteHandle(handle=b"h1"), @@ -233,9 +235,9 @@ def test_on_open_error_redirection(self, mock_appendable_writer): @pytest.mark.asyncio async def test_append_integration_basic(self, mock_appendable_writer): """Verify append orchestrates manager and drives the internal generator.""" - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) writer._is_stream_open = True - writer.write_obj_stream = mock_appendable_writer['mock_stream'] + writer.write_obj_stream = mock_appendable_writer["mock_stream"] writer.persisted_size = 0 data = b"test-data" @@ -249,7 +251,7 @@ async def mock_execute(state, policy): dummy_reqs = [storage_type.BidiWriteObjectRequest()] gen = factory(dummy_reqs, state) - mock_appendable_writer['mock_stream'].recv.side_effect = [ + mock_appendable_writer["mock_stream"].recv.side_effect = [ storage_type.BidiWriteObjectResponse( persisted_size=len(data), write_handle=storage_type.BidiWriteHandle(handle=b"h2"), @@ -263,21 +265,23 @@ async def mock_execute(state, policy): await writer.append(data) assert writer.persisted_size == len(data) - sent_req = mock_appendable_writer['mock_stream'].send.call_args[0][0] + sent_req = mock_appendable_writer["mock_stream"].send.call_args[0][0] assert sent_req.state_lookup assert sent_req.flush @pytest.mark.asyncio async def test_append_recovery_reopens_stream(self, mock_appendable_writer): """Verifies re-opening logic on retry.""" - writer = self._make_one(mock_appendable_writer['mock_client'], write_handle=b"h1") + writer = self._make_one( + mock_appendable_writer["mock_client"], write_handle=b"h1" + ) writer._is_stream_open = True - writer.write_obj_stream = mock_appendable_writer['mock_stream'] + writer.write_obj_stream = mock_appendable_writer["mock_stream"] # Setup mock to allow close() call - mock_appendable_writer['mock_stream'].is_stream_open = True + mock_appendable_writer["mock_stream"].is_stream_open = True async def mock_open(metadata=None): - writer.write_obj_stream = mock_appendable_writer['mock_stream'] + writer.write_obj_stream = mock_appendable_writer["mock_stream"] writer._is_stream_open = True writer.persisted_size = 5 writer.write_handle = b"h_recovered" @@ -299,20 +303,20 @@ async def mock_execute(state, policy): pass # Simulate Attempt 2 gen2 = factory([], state) - mock_appendable_writer['mock_stream'].recv.return_value = None + mock_appendable_writer["mock_stream"].recv.return_value = None async for _ in gen2: pass MockManager.return_value.execute.side_effect = mock_execute await writer.append(b"0123456789") - mock_appendable_writer['mock_stream'].close.assert_awaited() + mock_appendable_writer["mock_stream"].close.assert_awaited() mock_writer_open.assert_awaited() assert writer.persisted_size == 5 @pytest.mark.asyncio async def test_append_unimplemented_string_raises(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) with pytest.raises(NotImplementedError): await writer.append_from_string("test") @@ -322,14 +326,14 @@ async def test_append_unimplemented_string_raises(self, mock_appendable_writer): @pytest.mark.asyncio async def test_flush_resets_counters(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) writer._is_stream_open = True - writer.write_obj_stream = mock_appendable_writer['mock_stream'] + writer.write_obj_stream = mock_appendable_writer["mock_stream"] writer.bytes_appended_since_last_flush = 100 - mock_appendable_writer['mock_stream'].recv.return_value = storage_type.BidiWriteObjectResponse( - persisted_size=200 - ) + mock_appendable_writer[ + "mock_stream" + ].recv.return_value = storage_type.BidiWriteObjectResponse(persisted_size=200) await writer.flush() @@ -338,55 +342,55 @@ async def test_flush_resets_counters(self, mock_appendable_writer): @pytest.mark.asyncio async def test_simple_flush(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) writer._is_stream_open = True - writer.write_obj_stream = mock_appendable_writer['mock_stream'] + writer.write_obj_stream = mock_appendable_writer["mock_stream"] writer.bytes_appended_since_last_flush = 50 await writer.simple_flush() - mock_appendable_writer['mock_stream'].send.assert_awaited_with( + mock_appendable_writer["mock_stream"].send.assert_awaited_with( storage_type.BidiWriteObjectRequest(flush=True) ) assert writer.bytes_appended_since_last_flush == 0 @pytest.mark.asyncio async def test_close_without_finalize(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) writer._is_stream_open = True - writer.write_obj_stream = mock_appendable_writer['mock_stream'] + writer.write_obj_stream = mock_appendable_writer["mock_stream"] writer.persisted_size = 50 size = await writer.close() - mock_appendable_writer['mock_stream'].close.assert_awaited() + mock_appendable_writer["mock_stream"].close.assert_awaited() assert not writer._is_stream_open assert size == 50 @pytest.mark.asyncio async def test_finalize_lifecycle(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) writer._is_stream_open = True - writer.write_obj_stream = mock_appendable_writer['mock_stream'] + writer.write_obj_stream = mock_appendable_writer["mock_stream"] resource = storage_type.Object(size=999) - mock_appendable_writer['mock_stream'].recv.return_value = storage_type.BidiWriteObjectResponse( - resource=resource - ) + mock_appendable_writer[ + "mock_stream" + ].recv.return_value = storage_type.BidiWriteObjectResponse(resource=resource) res = await writer.finalize() assert res == resource assert writer.persisted_size == 999 - mock_appendable_writer['mock_stream'].send.assert_awaited_with( + mock_appendable_writer["mock_stream"].send.assert_awaited_with( storage_type.BidiWriteObjectRequest(finish_write=True) ) - mock_appendable_writer['mock_stream'].close.assert_awaited() + mock_appendable_writer["mock_stream"].close.assert_awaited() assert not writer._is_stream_open @pytest.mark.asyncio async def test_close_with_finalize_on_close(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) writer._is_stream_open = True writer.finalize = AsyncMock() @@ -399,7 +403,7 @@ async def test_close_with_finalize_on_close(self, mock_appendable_writer): @pytest.mark.asyncio async def test_append_from_file_integration(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) writer._is_stream_open = True writer.append = AsyncMock() @@ -410,7 +414,7 @@ async def test_append_from_file_integration(self, mock_appendable_writer): @pytest.mark.asyncio async def test_methods_require_open_stream_raises(self, mock_appendable_writer): - writer = self._make_one(mock_appendable_writer['mock_client']) + writer = self._make_one(mock_appendable_writer["mock_client"]) methods = [ writer.append(b"data"), writer.flush(), diff --git a/tests/unit/asyncio/test_async_grpc_client.py b/tests/unit/asyncio/test_async_grpc_client.py index 7321f99ad..400fb2d9d 100644 --- a/tests/unit/asyncio/test_async_grpc_client.py +++ b/tests/unit/asyncio/test_async_grpc_client.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from unittest import mock from google.auth import credentials as auth_credentials from google.auth.credentials import AnonymousCredentials diff --git a/tests/unit/asyncio/test_async_write_object_stream.py b/tests/unit/asyncio/test_async_write_object_stream.py index aec0b3794..7bfa2cea0 100644 --- a/tests/unit/asyncio/test_async_write_object_stream.py +++ b/tests/unit/asyncio/test_async_write_object_stream.py @@ -54,7 +54,9 @@ def test_init_basic(self, mock_client): assert stream.bucket_name == BUCKET assert stream.object_name == OBJECT assert stream._full_bucket_name == FULL_BUCKET_PATH - assert stream.metadata == (("x-goog-request-params", f"bucket={FULL_BUCKET_PATH}"),) + assert stream.metadata == ( + ("x-goog-request-params", f"bucket={FULL_BUCKET_PATH}"), + ) assert not stream.is_stream_open def test_init_raises_value_error(self, mock_client):