diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py
index 40f30f1d8..02220fcfd 100644
--- a/google/cloud/bigtable/data/_async/client.py
+++ b/google/cloud/bigtable/data/_async/client.py
@@ -19,6 +19,7 @@
     cast,
     Any,
     AsyncIterable,
+    Callable,
     Optional,
     Set,
     Sequence,
@@ -97,18 +98,24 @@
     )
     from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE
     from google.cloud.bigtable.data._async._swappable_channel import (
-        AsyncSwappableChannel,
+        AsyncSwappableChannel as SwappableChannelType,
+    )
+    from google.cloud.bigtable.data._async.metrics_interceptor import (
+        AsyncBigtableMetricsInterceptor as MetricsInterceptorType,
     )
 else:
     from typing import Iterable  # noqa: F401
     from grpc import insecure_channel
+    from grpc import intercept_channel
     from google.cloud.bigtable_v2.services.bigtable.transports import BigtableGrpcTransport as TransportType  # type: ignore
     from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient  # type: ignore
     from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE
     from google.cloud.bigtable.data._sync_autogen._swappable_channel import (  # noqa: F401
-        SwappableChannel,
+        SwappableChannel as SwappableChannelType,
+    )
+    from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import (  # noqa: F401
+        BigtableMetricsInterceptor as MetricsInterceptorType,
     )
-
 if TYPE_CHECKING:
     from google.cloud.bigtable.data._helpers import RowKeySamples
@@ -203,7 +210,7 @@ def __init__(
             credentials = google.auth.credentials.AnonymousCredentials()
             if project is None:
                 project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT
-
+        self._metrics_interceptor = MetricsInterceptorType()
         # initialize client
         ClientWithProject.__init__(
             self,
@@ -257,12 +264,11 @@ def __init__(
                 stacklevel=2,
             )

-    @CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"})
-    def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel:
+    def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannelType:
         """
         This method is called by the gapic transport to create a grpc channel.

-        The init arguments passed down are captured in a partial used by AsyncSwappableChannel
+        The init arguments passed down are captured in a partial used by SwappableChannel
         to create new channel instances in the future, as part of the channel refresh logic

         Emulators always use an insecure channel
@@ -273,12 +279,30 @@ def _build_grpc_channel(self, *args, **kwargs) -> AsyncSwappableChannel:
         Returns:
             a custom wrapped swappable channel
         """
+        create_channel_fn: Callable[[], Channel]
         if self._emulator_host is not None:
-            # emulators use insecure channel
+            # Emulators use insecure channels
             create_channel_fn = partial(insecure_channel, self._emulator_host)
-        else:
+        elif CrossSync.is_async:
+            # For async client, use the default create_channel.
             create_channel_fn = partial(TransportType.create_channel, *args, **kwargs)
-        return AsyncSwappableChannel(create_channel_fn)
+        else:
+            # For sync client, wrap create_channel with interceptors.
+            def sync_create_channel_fn():
+                return intercept_channel(
+                    TransportType.create_channel(*args, **kwargs),
+                    self._metrics_interceptor,
+                )
+
+            create_channel_fn = sync_create_channel_fn
+
+        # Instantiate SwappableChannelType with the determined creation function.
+        new_channel = SwappableChannelType(create_channel_fn)
+        if CrossSync.is_async:
+            # Attach async interceptors to the channel instance itself.
+            new_channel._unary_unary_interceptors.append(self._metrics_interceptor)
+            new_channel._unary_stream_interceptors.append(self._metrics_interceptor)
+        return new_channel

     @property
     def universe_domain(self) -> str:
@@ -400,7 +424,7 @@ def _invalidate_channel_stubs(self):
         self.transport._stubs = {}
         self.transport._prep_wrapped_messages(self.client_info)

-    @CrossSync.convert(replace_symbols={"AsyncSwappableChannel": "SwappableChannel"})
+    @CrossSync.convert
     async def _manage_channel(
         self,
         refresh_interval_min: float = 60 * 35,
@@ -425,10 +449,10 @@ async def _manage_channel(
             grace_period: time to allow previous channel to serve existing requests
                 before closing, in seconds
         """
-        if not isinstance(self.transport.grpc_channel, AsyncSwappableChannel):
+        if not isinstance(self.transport.grpc_channel, SwappableChannelType):
             warnings.warn("Channel does not support auto-refresh.")
             return
-        super_channel: AsyncSwappableChannel = self.transport.grpc_channel
+        super_channel: SwappableChannelType = self.transport.grpc_channel
         first_refresh = self._channel_init_time + random.uniform(
             refresh_interval_min, refresh_interval_max
         )
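Note on the branching above: gRPC's sync and async APIs attach client interceptors differently, which is why `_build_grpc_channel` splits on `CrossSync.is_async`. A minimal sketch of the two public-API attachment styles (the `localhost:8086` target and the no-op interceptors are illustrative placeholders, not part of this change):

```python
import grpc
import grpc.aio


class NoopSyncInterceptor(grpc.UnaryUnaryClientInterceptor):
    def intercept_unary_unary(self, continuation, client_call_details, request):
        # pass-through, the same shape as BigtableMetricsInterceptor
        return continuation(client_call_details, request)


class NoopAsyncInterceptor(grpc.aio.UnaryUnaryClientInterceptor):
    async def intercept_unary_unary(self, continuation, client_call_details, request):
        return await continuation(client_call_details, request)


# Sync: wrap an already-created channel after the fact.
sync_channel = grpc.intercept_channel(
    grpc.insecure_channel("localhost:8086"), NoopSyncInterceptor()
)

# Async: the public API only accepts interceptors at channel-creation time.
async_channel = grpc.aio.insecure_channel(
    "localhost:8086", interceptors=[NoopAsyncInterceptor()]
)
```

The PR instead appends to the aio channel's private `_unary_unary_interceptors`/`_unary_stream_interceptors` lists, presumably because the channel object already exists by the time the interceptor is attached.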
diff --git a/google/cloud/bigtable/data/_async/metrics_interceptor.py b/google/cloud/bigtable/data/_async/metrics_interceptor.py
new file mode 100644
index 000000000..a154c0083
--- /dev/null
+++ b/google/cloud/bigtable/data/_async/metrics_interceptor.py
@@ -0,0 +1,78 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from google.cloud.bigtable.data._cross_sync import CrossSync
+
+if CrossSync.is_async:
+    from grpc.aio import UnaryUnaryClientInterceptor
+    from grpc.aio import UnaryStreamClientInterceptor
+else:
+    from grpc import UnaryUnaryClientInterceptor
+    from grpc import UnaryStreamClientInterceptor
+
+
+__CROSS_SYNC_OUTPUT__ = "google.cloud.bigtable.data._sync_autogen.metrics_interceptor"
+
+
+@CrossSync.convert_class(sync_name="BigtableMetricsInterceptor")
+class AsyncBigtableMetricsInterceptor(
+    UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor
+):
+    """
+    A gRPC interceptor for instrumenting Bigtable RPCs with client-side metrics.
+    """
+
+    @CrossSync.convert
+    async def intercept_unary_unary(self, continuation, client_call_details, request):
+        """
+        Interceptor for unary rpcs:
+          - MutateRow
+          - CheckAndMutateRow
+          - ReadModifyWriteRow
+        """
+        try:
+            call = await continuation(client_call_details, request)
+            return call
+        except Exception as rpc_error:
+            raise rpc_error
+
+    @CrossSync.convert
+    async def intercept_unary_stream(self, continuation, client_call_details, request):
+        """
+        Interceptor for streaming rpcs:
+          - ReadRows
+          - MutateRows
+          - SampleRowKeys
+        """
+        try:
+            return self._streaming_generator_wrapper(
+                await continuation(client_call_details, request)
+            )
+        except Exception as rpc_error:
+            # handle errors while initializing stream
+            raise rpc_error
+
+    @staticmethod
+    @CrossSync.convert
+    async def _streaming_generator_wrapper(call):
+        """
+        Wrapped generator to be returned by intercept_unary_stream.
+        """
+        try:
+            async for response in call:
+                yield response
+        except Exception as e:
+            # handle errors while processing stream
+            raise e
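The interceptor bodies above are deliberately pass-through: the try/except blocks and the generator wrapper mark the points where metric recording can later be added. A hypothetical sketch (not part of this change) of what a unary hook could record:

```python
import time

import grpc.aio


class LatencyInterceptor(grpc.aio.UnaryUnaryClientInterceptor):
    """Illustrative only: records wall-clock latency per unary attempt."""

    def __init__(self):
        self.latencies: list[float] = []

    async def intercept_unary_unary(self, continuation, client_call_details, request):
        start = time.monotonic()
        try:
            return await continuation(client_call_details, request)
        finally:
            # runs whether the attempt succeeded or raised
            self.latencies.append(time.monotonic() - start)
```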
+ """ + + @CrossSync.convert + async def intercept_unary_unary(self, continuation, client_call_details, request): + """ + Interceptor for unary rpcs: + - MutateRow + - CheckAndMutateRow + - ReadModifyWriteRow + """ + try: + call = await continuation(client_call_details, request) + return call + except Exception as rpc_error: + raise rpc_error + + @CrossSync.convert + async def intercept_unary_stream(self, continuation, client_call_details, request): + """ + Interceptor for streaming rpcs: + - ReadRows + - MutateRows + - SampleRowKeys + """ + try: + return self._streaming_generator_wrapper( + await continuation(client_call_details, request) + ) + except Exception as rpc_error: + # handle errors while intializing stream + raise rpc_error + + @staticmethod + @CrossSync.convert + async def _streaming_generator_wrapper(call): + """ + Wrapped generator to be returned by intercept_unary_stream. + """ + try: + async for response in call: + yield response + except Exception as e: + # handle errors while processing stream + raise e diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index 1c75823ae..e73d6e94c 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -17,7 +17,7 @@ # This file is automatically generated by CrossSync. Do not edit manually. from __future__ import annotations -from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING +from typing import cast, Any, Callable, Optional, Set, Sequence, TYPE_CHECKING import abc import time import warnings @@ -75,12 +75,18 @@ from google.cloud.bigtable.data._cross_sync import CrossSync from typing import Iterable from grpc import insecure_channel +from grpc import intercept_channel from google.cloud.bigtable_v2.services.bigtable.transports import ( BigtableGrpcTransport as TransportType, ) from google.cloud.bigtable_v2.services.bigtable import BigtableClient as GapicClient from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE -from google.cloud.bigtable.data._sync_autogen._swappable_channel import SwappableChannel +from google.cloud.bigtable.data._sync_autogen._swappable_channel import ( + SwappableChannel as SwappableChannelType, +) +from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import ( + BigtableMetricsInterceptor as MetricsInterceptorType, +) if TYPE_CHECKING: from google.cloud.bigtable.data._helpers import RowKeySamples @@ -143,6 +149,7 @@ def __init__( credentials = google.auth.credentials.AnonymousCredentials() if project is None: project = _DEFAULT_BIGTABLE_EMULATOR_CLIENT + self._metrics_interceptor = MetricsInterceptorType() ClientWithProject.__init__( self, credentials=credentials, @@ -186,7 +193,7 @@ def __init__( stacklevel=2, ) - def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannel: + def _build_grpc_channel(self, *args, **kwargs) -> SwappableChannelType: """This method is called by the gapic transport to create a grpc channel. 
diff --git a/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py
new file mode 100644
index 000000000..9e47313b0
--- /dev/null
+++ b/google/cloud/bigtable/data/_sync_autogen/metrics_interceptor.py
@@ -0,0 +1,59 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is automatically generated by CrossSync. Do not edit manually.
+
+from __future__ import annotations
+from grpc import UnaryUnaryClientInterceptor
+from grpc import UnaryStreamClientInterceptor
+
+
+class BigtableMetricsInterceptor(
+    UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor
+):
+    """
+    A gRPC interceptor for instrumenting Bigtable RPCs with client-side metrics.
+    """
+
+    def intercept_unary_unary(self, continuation, client_call_details, request):
+        """Interceptor for unary rpcs:
+        - MutateRow
+        - CheckAndMutateRow
+        - ReadModifyWriteRow"""
+        try:
+            call = continuation(client_call_details, request)
+            return call
+        except Exception as rpc_error:
+            raise rpc_error
+
+    def intercept_unary_stream(self, continuation, client_call_details, request):
+        """Interceptor for streaming rpcs:
+        - ReadRows
+        - MutateRows
+        - SampleRowKeys"""
+        try:
+            return self._streaming_generator_wrapper(
+                continuation(client_call_details, request)
+            )
+        except Exception as rpc_error:
+            raise rpc_error
+
+    @staticmethod
+    def _streaming_generator_wrapper(call):
+        """Wrapped generator to be returned by intercept_unary_stream."""
+        try:
+            for response in call:
+                yield response
+        except Exception as e:
+            raise e
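Why `intercept_unary_stream` returns a wrapper generator rather than the call object itself: iteration then flows through a frame the interceptor controls, so per-response metrics and mid-stream failures can be observed in one place. A standalone sketch of the idea (not part of this change):

```python
def wrap_stream(call):
    """Illustrative version of _streaming_generator_wrapper."""
    try:
        for response in call:
            # per-response instrumentation would go here
            yield response
    except Exception:
        # mid-stream failures surface here before propagating to the caller
        raise


def flaky_stream():
    yield 1
    raise RuntimeError("mid-stream failure")


wrapped = wrap_stream(flaky_stream())
assert next(wrapped) == 1
try:
    next(wrapped)
except RuntimeError:
    pass  # the wrapper observed the failure before re-raising it
```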
+ """ + + def intercept_unary_unary(self, continuation, client_call_details, request): + """Interceptor for unary rpcs: + - MutateRow + - CheckAndMutateRow + - ReadModifyWriteRow""" + try: + call = continuation(client_call_details, request) + return call + except Exception as rpc_error: + raise rpc_error + + def intercept_unary_stream(self, continuation, client_call_details, request): + """Interceptor for streaming rpcs: + - ReadRows + - MutateRows + - SampleRowKeys""" + try: + return self._streaming_generator_wrapper( + continuation(client_call_details, request) + ) + except Exception as rpc_error: + raise rpc_error + + @staticmethod + def _streaming_generator_wrapper(call): + """Wrapped generator to be returned by intercept_unary_stream.""" + try: + for response in call: + yield response + except Exception as e: + raise e diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index c96570b76..39c454996 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -285,23 +285,28 @@ async def test_channel_refresh(self, table_id, instance_id, temp_rows): async with client.get_table(instance_id, table_id) as table: rows = await table.read_rows({}) channel_wrapper = client.transport.grpc_channel - first_channel = client.transport.grpc_channel._channel + first_channel = channel_wrapper._channel assert len(rows) == 2 await CrossSync.sleep(2) rows_after_refresh = await table.read_rows({}) assert len(rows_after_refresh) == 2 assert client.transport.grpc_channel is channel_wrapper - assert client.transport.grpc_channel._channel is not first_channel - # ensure gapic's logging interceptor is still active + updated_channel = channel_wrapper._channel + assert updated_channel is not first_channel + # ensure interceptors are kept (gapic's logging interceptor, and metric interceptor) if CrossSync.is_async: - interceptors = ( - client.transport.grpc_channel._channel._unary_unary_interceptors - ) - assert GapicInterceptor in [type(i) for i in interceptors] + unary_interceptors = updated_channel._unary_unary_interceptors + assert len(unary_interceptors) == 2 + assert GapicInterceptor in [type(i) for i in unary_interceptors] + assert client._metrics_interceptor in unary_interceptors + stream_interceptors = updated_channel._unary_stream_interceptors + assert len(stream_interceptors) == 1 + assert client._metrics_interceptor in stream_interceptors else: assert isinstance( client.transport._logged_channel._interceptor, GapicInterceptor ) + assert updated_channel._interceptor == client._metrics_interceptor finally: await client.close() diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index a78a8eb4c..37c00f2ae 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -237,16 +237,18 @@ def test_channel_refresh(self, table_id, instance_id, temp_rows): with client.get_table(instance_id, table_id) as table: rows = table.read_rows({}) channel_wrapper = client.transport.grpc_channel - first_channel = client.transport.grpc_channel._channel + first_channel = channel_wrapper._channel assert len(rows) == 2 CrossSync._Sync_Impl.sleep(2) rows_after_refresh = table.read_rows({}) assert len(rows_after_refresh) == 2 assert client.transport.grpc_channel is channel_wrapper - assert client.transport.grpc_channel._channel is not first_channel + updated_channel = channel_wrapper._channel + assert updated_channel is not first_channel assert isinstance( 
                    client.transport._logged_channel._interceptor, GapicInterceptor
                 )
+                assert updated_channel._interceptor == client._metrics_interceptor
         finally:
             client.close()
@@ -258,7 +260,7 @@ def test_mutation_set_cell(self, target, temp_rows):
         """Ensure cells can be set properly"""
         row_key = b"bulk_mutate"
         new_value = uuid.uuid4().hex.encode()
-        row_key, mutation = self._create_row_and_mutation(
+        (row_key, mutation) = self._create_row_and_mutation(
             target, temp_rows, new_value=new_value
         )
         target.mutate_row(row_key, mutation)
@@ -312,7 +314,7 @@ def test_bulk_mutations_set_cell(self, client, target, temp_rows):
         from google.cloud.bigtable.data.mutations import RowMutationEntry

         new_value = uuid.uuid4().hex.encode()
-        row_key, mutation = self._create_row_and_mutation(
+        (row_key, mutation) = self._create_row_and_mutation(
             target, temp_rows, new_value=new_value
         )
         bulk_mutation = RowMutationEntry(row_key, [mutation])
@@ -347,11 +349,11 @@ def test_mutations_batcher_context_manager(self, client, target, temp_rows):
         """test batcher with context manager. Should flush on exit"""
         from google.cloud.bigtable.data.mutations import RowMutationEntry

-        new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)]
-        row_key, mutation = self._create_row_and_mutation(
+        (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)]
+        (row_key, mutation) = self._create_row_and_mutation(
             target, temp_rows, new_value=new_value
         )
-        row_key2, mutation2 = self._create_row_and_mutation(
+        (row_key2, mutation2) = self._create_row_and_mutation(
             target, temp_rows, new_value=new_value2
         )
         bulk_mutation = RowMutationEntry(row_key, [mutation])
@@ -372,7 +374,7 @@ def test_mutations_batcher_timer_flush(self, client, target, temp_rows):
         from google.cloud.bigtable.data.mutations import RowMutationEntry

         new_value = uuid.uuid4().hex.encode()
-        row_key, mutation = self._create_row_and_mutation(
+        (row_key, mutation) = self._create_row_and_mutation(
             target, temp_rows, new_value=new_value
         )
         bulk_mutation = RowMutationEntry(row_key, [mutation])
@@ -394,12 +396,12 @@ def test_mutations_batcher_count_flush(self, client, target, temp_rows):
         """batch should flush after flush_limit_mutation_count mutations"""
         from google.cloud.bigtable.data.mutations import RowMutationEntry

-        new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)]
-        row_key, mutation = self._create_row_and_mutation(
+        (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)]
+        (row_key, mutation) = self._create_row_and_mutation(
             target, temp_rows, new_value=new_value
         )
         bulk_mutation = RowMutationEntry(row_key, [mutation])
-        row_key2, mutation2 = self._create_row_and_mutation(
+        (row_key2, mutation2) = self._create_row_and_mutation(
             target, temp_rows, new_value=new_value2
         )
         bulk_mutation2 = RowMutationEntry(row_key2, [mutation2])
@@ -426,12 +428,12 @@ def test_mutations_batcher_bytes_flush(self, client, target, temp_rows):
         """batch should flush after flush_limit_bytes bytes"""
         from google.cloud.bigtable.data.mutations import RowMutationEntry

-        new_value, new_value2 = [uuid.uuid4().hex.encode() for _ in range(2)]
-        row_key, mutation = self._create_row_and_mutation(
+        (new_value, new_value2) = [uuid.uuid4().hex.encode() for _ in range(2)]
+        (row_key, mutation) = self._create_row_and_mutation(
             target, temp_rows, new_value=new_value
         )
         bulk_mutation = RowMutationEntry(row_key, [mutation])
-        row_key2, mutation2 = self._create_row_and_mutation(
+        (row_key2, mutation2) = self._create_row_and_mutation(
             target, temp_rows, new_value=new_value2
         )
        bulk_mutation2 = RowMutationEntry(row_key2, [mutation2])
@@ -457,11 +459,11 @@ def test_mutations_batcher_no_flush(self, client, target, temp_rows):
         new_value = uuid.uuid4().hex.encode()
         start_value = b"unchanged"
-        row_key, mutation = self._create_row_and_mutation(
+        (row_key, mutation) = self._create_row_and_mutation(
             target, temp_rows, start_value=start_value, new_value=new_value
         )
         bulk_mutation = RowMutationEntry(row_key, [mutation])
-        row_key2, mutation2 = self._create_row_and_mutation(
+        (row_key2, mutation2) = self._create_row_and_mutation(
             target, temp_rows, start_value=start_value, new_value=new_value
         )
         bulk_mutation2 = RowMutationEntry(row_key2, [mutation2])
diff --git a/tests/unit/data/_async/test_metrics_interceptor.py b/tests/unit/data/_async/test_metrics_interceptor.py
new file mode 100644
index 000000000..6ea958358
--- /dev/null
+++ b/tests/unit/data/_async/test_metrics_interceptor.py
@@ -0,0 +1,168 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from grpc import RpcError
+
+from google.cloud.bigtable.data._cross_sync import CrossSync
+
+# try/except added for compatibility with python < 3.8
+try:
+    from unittest import mock
+except ImportError:  # pragma: NO COVER
+    import mock  # type: ignore
+
+if CrossSync.is_async:
+    from google.cloud.bigtable.data._async.metrics_interceptor import (
+        AsyncBigtableMetricsInterceptor,
+    )
+else:
+    from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import (  # noqa: F401
+        BigtableMetricsInterceptor,
+    )
+
+
+__CROSS_SYNC_OUTPUT__ = "tests.unit.data._sync_autogen.test_metrics_interceptor"
+
+
+@CrossSync.convert(replace_symbols={"__aiter__": "__iter__"})
+def _make_mock_stream_call(values, exc=None):
+    """
+    Create a mock call object that can be used for streaming calls
+    """
+    call = CrossSync.Mock()
+
+    async def gen():
+        for val in values:
+            yield val
+        if exc:
+            raise exc
+
+    call.__aiter__ = mock.Mock(return_value=gen())
+    return call
+
+
+@CrossSync.convert_class(sync_name="TestMetricsInterceptor")
+class TestMetricsInterceptorAsync:
+    @staticmethod
+    @CrossSync.convert(
+        replace_symbols={
+            "AsyncBigtableMetricsInterceptor": "BigtableMetricsInterceptor"
+        }
+    )
+    def _get_target_class():
+        return AsyncBigtableMetricsInterceptor
+
+    def _make_one(self, *args, **kwargs):
+        return self._get_target_class()(*args, **kwargs)
+
+    @CrossSync.pytest
+    async def test_unary_unary_interceptor_success(self):
+        """Test that interceptor handles successful unary-unary calls"""
+        instance = self._make_one()
+        continuation = CrossSync.Mock()
+        call = continuation.return_value
+        details = mock.Mock()
+        request = mock.Mock()
+        result = await instance.intercept_unary_unary(continuation, details, request)
+        assert result == call
+        continuation.assert_called_once_with(details, request)
+
+    @CrossSync.pytest
+    async def test_unary_unary_interceptor_failure(self):
+        """Test a failed call that raises RpcError"""
+
+        instance = self._make_one()
+        exc = RpcError("test")
+        continuation = CrossSync.Mock(side_effect=exc)
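A note on the `_make_mock_stream_call` helper above: assigning `call.__aiter__` works because `MagicMock`-style mocks specially support configuring magic methods on instances, so the mock becomes async-iterable without implementing the full grpc `Call` surface. A standalone sketch of the trick (not part of this change):

```python
import asyncio
from unittest import mock


async def main():
    async def gen():
        yield 1
        yield 2

    call = mock.MagicMock()
    # MagicMock routes assigned dunders through its type, so `async for`
    # finds this __aiter__ even though dunder lookup is type-based.
    call.__aiter__ = mock.Mock(return_value=gen())
    assert [v async for v in call] == [1, 2]


asyncio.run(main())
```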
+        details = mock.Mock()
+        request = mock.Mock()
+        with pytest.raises(RpcError) as e:
+            await instance.intercept_unary_unary(continuation, details, request)
+        assert e.value == exc
+        continuation.assert_called_once_with(details, request)
+
+    @CrossSync.pytest
+    async def test_unary_unary_interceptor_failure_generic(self):
+        """Test generic exception"""
+
+        instance = self._make_one()
+        exc = ValueError("test")
+        continuation = CrossSync.Mock(side_effect=exc)
+        details = mock.Mock()
+        request = mock.Mock()
+        with pytest.raises(ValueError) as e:
+            await instance.intercept_unary_unary(continuation, details, request)
+        assert e.value == exc
+        continuation.assert_called_once_with(details, request)
+
+    @CrossSync.pytest
+    async def test_unary_stream_interceptor_success(self):
+        """Test that interceptor handles successful unary-stream calls"""
+
+        instance = self._make_one()
+
+        continuation = CrossSync.Mock(return_value=_make_mock_stream_call([1, 2]))
+        details = mock.Mock()
+        request = mock.Mock()
+        wrapper = await instance.intercept_unary_stream(continuation, details, request)
+        results = [val async for val in wrapper]
+        assert results == [1, 2]
+        continuation.assert_called_once_with(details, request)
+
+    @CrossSync.pytest
+    async def test_unary_stream_interceptor_failure_mid_stream(self):
+        """Test that interceptor handles failures mid-stream"""
+        instance = self._make_one()
+        exc = ValueError("test")
+        continuation = CrossSync.Mock(return_value=_make_mock_stream_call([1], exc=exc))
+        details = mock.Mock()
+        request = mock.Mock()
+        wrapper = await instance.intercept_unary_stream(continuation, details, request)
+        with pytest.raises(ValueError) as e:
+            [val async for val in wrapper]
+        assert e.value == exc
+        continuation.assert_called_once_with(details, request)
+
+    @CrossSync.pytest
+    async def test_unary_stream_interceptor_failure_start_stream(self):
+        """Test that interceptor handles failures at start of stream with RpcError"""
+
+        instance = self._make_one()
+        exc = RpcError("test")
+
+        continuation = CrossSync.Mock()
+        continuation.side_effect = exc
+        details = mock.Mock()
+        request = mock.Mock()
+        with pytest.raises(RpcError) as e:
+            await instance.intercept_unary_stream(continuation, details, request)
+        assert e.value == exc
+        continuation.assert_called_once_with(details, request)
+
+    @CrossSync.pytest
+    async def test_unary_stream_interceptor_failure_start_stream_generic(self):
+        """Test that interceptor handles failures at start of stream with generic exception"""
+
+        instance = self._make_one()
+        exc = ValueError("test")
+
+        continuation = CrossSync.Mock()
+        continuation.side_effect = exc
+        details = mock.Mock()
+        request = mock.Mock()
+        with pytest.raises(ValueError) as e:
+            await instance.intercept_unary_stream(continuation, details, request)
+        assert e.value == exc
+        continuation.assert_called_once_with(details, request)
diff --git a/tests/unit/data/_sync_autogen/test_metrics_interceptor.py b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py
new file mode 100644
index 000000000..56a6f3650
--- /dev/null
+++ b/tests/unit/data/_sync_autogen/test_metrics_interceptor.py
@@ -0,0 +1,140 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This file is automatically generated by CrossSync. Do not edit manually.
+
+import pytest
+from grpc import RpcError
+from google.cloud.bigtable.data._cross_sync import CrossSync
+
+try:
+    from unittest import mock
+except ImportError:
+    import mock
+from google.cloud.bigtable.data._sync_autogen.metrics_interceptor import (
+    BigtableMetricsInterceptor,
+)
+
+
+def _make_mock_stream_call(values, exc=None):
+    """Create a mock call object that can be used for streaming calls"""
+    call = CrossSync._Sync_Impl.Mock()
+
+    def gen():
+        for val in values:
+            yield val
+        if exc:
+            raise exc
+
+    call.__iter__ = mock.Mock(return_value=gen())
+    return call
+
+
+class TestMetricsInterceptor:
+    @staticmethod
+    def _get_target_class():
+        return BigtableMetricsInterceptor
+
+    def _make_one(self, *args, **kwargs):
+        return self._get_target_class()(*args, **kwargs)
+
+    def test_unary_unary_interceptor_success(self):
+        """Test that interceptor handles successful unary-unary calls"""
+        instance = self._make_one()
+        continuation = CrossSync._Sync_Impl.Mock()
+        call = continuation.return_value
+        details = mock.Mock()
+        request = mock.Mock()
+        result = instance.intercept_unary_unary(continuation, details, request)
+        assert result == call
+        continuation.assert_called_once_with(details, request)
+
+    def test_unary_unary_interceptor_failure(self):
+        """Test a failed call that raises RpcError"""
+        instance = self._make_one()
+        exc = RpcError("test")
+        continuation = CrossSync._Sync_Impl.Mock(side_effect=exc)
+        details = mock.Mock()
+        request = mock.Mock()
+        with pytest.raises(RpcError) as e:
+            instance.intercept_unary_unary(continuation, details, request)
+        assert e.value == exc
+        continuation.assert_called_once_with(details, request)
+
+    def test_unary_unary_interceptor_failure_generic(self):
+        """Test generic exception"""
+        instance = self._make_one()
+        exc = ValueError("test")
+        continuation = CrossSync._Sync_Impl.Mock(side_effect=exc)
+        details = mock.Mock()
+        request = mock.Mock()
+        with pytest.raises(ValueError) as e:
+            instance.intercept_unary_unary(continuation, details, request)
+        assert e.value == exc
+        continuation.assert_called_once_with(details, request)
+
+    def test_unary_stream_interceptor_success(self):
+        """Test that interceptor handles successful unary-stream calls"""
+        instance = self._make_one()
+        continuation = CrossSync._Sync_Impl.Mock(
+            return_value=_make_mock_stream_call([1, 2])
+        )
+        details = mock.Mock()
+        request = mock.Mock()
+        wrapper = instance.intercept_unary_stream(continuation, details, request)
+        results = [val for val in wrapper]
+        assert results == [1, 2]
+        continuation.assert_called_once_with(details, request)
+
+    def test_unary_stream_interceptor_failure_mid_stream(self):
+        """Test that interceptor handles failures mid-stream"""
+        instance = self._make_one()
+        exc = ValueError("test")
+        continuation = CrossSync._Sync_Impl.Mock(
+            return_value=_make_mock_stream_call([1], exc=exc)
+        )
+        details = mock.Mock()
+        request = mock.Mock()
+        wrapper = instance.intercept_unary_stream(continuation, details, request)
+        with pytest.raises(ValueError) as e:
+            [val for val in wrapper]
+        assert e.value == exc
+        continuation.assert_called_once_with(details, request)
+
+    def test_unary_stream_interceptor_failure_start_stream(self):
+        """Test that interceptor handles failures at start of stream with RpcError"""
+        instance = self._make_one()
+        exc = RpcError("test")
+        continuation = CrossSync._Sync_Impl.Mock()
+        continuation.side_effect = exc
+        details = mock.Mock()
+        request = mock.Mock()
+        with pytest.raises(RpcError) as e:
+            instance.intercept_unary_stream(continuation, details, request)
+        assert e.value == exc
+        continuation.assert_called_once_with(details, request)
+
+    def test_unary_stream_interceptor_failure_start_stream_generic(self):
+        """Test that interceptor handles failures at start of stream with generic exception"""
+        instance = self._make_one()
+        exc = ValueError("test")
+        continuation = CrossSync._Sync_Impl.Mock()
+        continuation.side_effect = exc
+        details = mock.Mock()
+        request = mock.Mock()
+        with pytest.raises(ValueError) as e:
+            instance.intercept_unary_stream(continuation, details, request)
+        assert e.value == exc
+        continuation.assert_called_once_with(details, request)
diff --git a/tests/unit/data/test_sync_up_to_date.py b/tests/unit/data/test_sync_up_to_date.py
index d4623a6c8..e6bce9cf6 100644
--- a/tests/unit/data/test_sync_up_to_date.py
+++ b/tests/unit/data/test_sync_up_to_date.py
@@ -90,7 +90,7 @@ def test_verify_headers(sync_file):
     \#\ distributed\ under\ the\ License\ is\ distributed\ on\ an\ \"AS\ IS\"\ BASIS,\n
     \#\ WITHOUT\ WARRANTIES\ OR\ CONDITIONS\ OF\ ANY\ KIND,\ either\ express\ or\ implied\.\n
     \#\ See\ the\ License\ for\ the\ specific\ language\ governing\ permissions\ and\n
-    \#\ limitations\ under\ the\ License\.
+    \#\ limitations\ under\ the\ License
     """
     pattern = re.compile(license_regex, re.VERBOSE)
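As a quick manual check of the wiring, mirroring the system-test assertions above (a sketch under assumptions: an emulator is listening on `localhost:8086`, and `my-project` is a placeholder; note the emulator path uses a plain insecure channel, so only the attribute added by this PR is checked here):

```python
import os

# Point the client at a local emulator so no real credentials are needed.
os.environ["BIGTABLE_EMULATOR_HOST"] = "localhost:8086"

from google.cloud.bigtable.data import BigtableDataClient

client = BigtableDataClient(project="my-project")
# _metrics_interceptor is the private attribute introduced in __init__ by this PR.
assert client._metrics_interceptor is not None
client.close()
```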