2525from dataclasses import dataclass
2626from dataclasses import field
2727from grpc import StatusCode
28+ from grpc import RpcError
29+ from grpc .aio import AioRpcError
2830
31+ from google .api_core .exceptions import GoogleAPICallError
32+ from google .api_core .retry import RetryFailureReason
2933import google .cloud .bigtable .data .exceptions as bt_exceptions
3034from google .cloud .bigtable_v2 .types .response_params import ResponseParams
3135from google .cloud .bigtable .data ._helpers import TrackedBackoffGenerator
36+ from google .cloud .bigtable .data .exceptions import _MutateRowsIncomplete
37+ from google .cloud .bigtable .data .exceptions import RetryExceptionGroup
3238from google .protobuf .message import DecodeError
3339
3440if TYPE_CHECKING :
@@ -280,7 +286,7 @@ def _parse_response_metadata_blob(blob: bytes) -> Tuple[str, str] | None:
280286 # failed to parse metadata
281287 return None
282288
283- def end_attempt_with_status (self , status : StatusCode | Exception ) -> None :
289+ def end_attempt_with_status (self , status : StatusCode | BaseException ) -> None :
284290 """
285291 Called to mark the end of an attempt for the operation.
286292
@@ -297,7 +303,7 @@ def end_attempt_with_status(self, status: StatusCode | Exception) -> None:
297303 return self ._handle_error (
298304 INVALID_STATE_ERROR .format ("end_attempt_with_status" , self .state )
299305 )
300- if isinstance (status , Exception ):
306+ if isinstance (status , BaseException ):
301307 status = self ._exc_to_status (status )
302308 complete_attempt = CompletedAttemptMetric (
303309 duration_ns = time .monotonic_ns () - self .active_attempt .start_time_ns ,
@@ -312,7 +318,7 @@ def end_attempt_with_status(self, status: StatusCode | Exception) -> None:
312318 for handler in self .handlers :
313319 handler .on_attempt_complete (complete_attempt , self )
314320
315- def end_with_status (self , status : StatusCode | Exception ) -> None :
321+ def end_with_status (self , status : StatusCode | BaseException ) -> None :
316322 """
317323 Called to mark the end of the operation. If there is an active attempt,
318324 end_attempt_with_status will be called with the same status.
@@ -329,7 +335,7 @@ def end_with_status(self, status: StatusCode | Exception) -> None:
329335 INVALID_STATE_ERROR .format ("end_with_status" , self .state )
330336 )
331337 final_status = (
332- self ._exc_to_status (status ) if isinstance (status , Exception ) else status
338+ self ._exc_to_status (status ) if isinstance (status , BaseException ) else status
333339 )
334340 if self .state == OperationState .ACTIVE_ATTEMPT :
335341 self .end_attempt_with_status (final_status )
@@ -367,7 +373,7 @@ def cancel(self):
367373 handler .on_operation_cancelled (self )
368374
369375 @staticmethod
370- def _exc_to_status (exc : Exception ) -> StatusCode :
376+ def _exc_to_status (exc : BaseException ) -> StatusCode :
371377 """
372378 Extracts the grpc status code from an exception.
373379
@@ -389,8 +395,71 @@ def _exc_to_status(exc: Exception) -> StatusCode:
389395 and exc .__cause__ .grpc_status_code is not None
390396 ):
391397 return exc .__cause__ .grpc_status_code
398+ if isinstance (exc , AioRpcError ) or isinstance (exc , RpcError ):
399+ return exc .code ()
392400 return StatusCode .UNKNOWN
393401
402+ def track_retryable_error (self , exc : Exception ) -> None :
403+ """
404+ Used as input to api_core.Retry classes, to track when retryable errors are encountered
405+
406+ Should be passed as on_error callback
407+ """
408+ try :
409+ # record metadata from failed rpc
410+ if (
411+ isinstance (exc , GoogleAPICallError )
412+ and exc .errors
413+ ):
414+ rpc_error = exc .errors [- 1 ]
415+ metadata = list (rpc_error .trailing_metadata ()) + list (
416+ rpc_error .initial_metadata ()
417+ )
418+ self .add_response_metadata ({k : v for k , v in metadata })
419+ except Exception :
420+ # ignore errors in metadata collection
421+ pass
422+ if isinstance (exc , _MutateRowsIncomplete ):
423+ # _MutateRowsIncomplete represents a successful rpc with some failed mutations
424+ # mark the attempt as successful
425+ self .end_attempt_with_status (StatusCode .OK )
426+ else :
427+ self .end_attempt_with_status (exc )
428+
429+ def track_terminal_error (self , exception_factory :callable [
430+ [list [Exception ], RetryFailureReason , float | None ],tuple [Exception , Exception | None ],
431+ ]) -> callable [[list [Exception ], RetryFailureReason , float | None ], None ]:
432+ """
433+ Used as input to api_core.Retry classes, to track when terminal errors are encountered
434+
435+ Should be used as a wrapper over an exception_factory callback
436+ """
437+ def wrapper (
438+ exc_list : list [Exception ], reason : RetryFailureReason , timeout_val : float | None
439+ ) -> tuple [Exception , Exception | None ]:
440+ source_exc , cause_exc = exception_factory (exc_list , reason , timeout_val )
441+ try :
442+ # record metadata from failed rpc
443+ if (
444+ isinstance (source_exc , GoogleAPICallError )
445+ and source_exc .errors
446+ ):
447+ rpc_error = source_exc .errors [- 1 ]
448+ metadata = list (rpc_error .trailing_metadata ()) + list (
449+ rpc_error .initial_metadata ()
450+ )
451+ self .add_response_metadata ({k : v for k , v in metadata })
452+ except Exception :
453+ # ignore errors in metadata collection
454+ pass
455+ if reason == RetryFailureReason .TIMEOUT and self .state == OperationState .ACTIVE_ATTEMPT and exc_list :
456+ # record ending attempt for timeout failures
457+ attempt_exc = exc_list [- 1 ]
458+ self .track_retryable_error (attempt_exc )
459+ self .end_with_status (source_exc )
460+ return source_exc , cause_exc
461+ return wrapper
462+
394463 @staticmethod
395464 def _handle_error (message : str ) -> None :
396465 """
@@ -418,8 +487,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
418487
419488 The operation is automatically ended on exit, with the status determined
420489 by the exception type and value.
490+
491+ If operation was already ended manually, do nothing.
421492 """
422- if exc_val is None :
423- self .end_with_success ()
424- else :
425- self .end_with_status (exc_val )
493+ if not self .state == OperationState .COMPLETED :
494+ if exc_val is None :
495+ self .end_with_success ()
496+ else :
497+ self .end_with_status (exc_val )
0 commit comments