diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py index d01c63aff5..1f37ff2a03 100644 --- a/google/cloud/spanner_v1/testing/mock_spanner.py +++ b/google/cloud/spanner_v1/testing/mock_spanner.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import base64 +import inspect import grpc from concurrent import futures from google.protobuf import empty_pb2 +from grpc_status.rpc_status import _Status from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc import google.cloud.spanner_v1.testing.spanner_pb2_grpc as spanner_grpc @@ -28,6 +30,7 @@ class MockSpanner: def __init__(self): self.results = {} + self.errors = {} def add_result(self, sql: str, result: result_set.ResultSet): self.results[sql.lower().strip()] = result @@ -38,6 +41,15 @@ def get_result(self, sql: str) -> result_set.ResultSet: raise ValueError(f"No result found for {sql}") return result + def add_error(self, method: str, error: _Status): + self.errors[method] = error + + def pop_error(self, context): + name = inspect.currentframe().f_back.f_code.co_name + error: _Status | None = self.errors.pop(name, None) + if error: + context.abort_with_status(error) + def get_result_as_partial_result_sets( self, sql: str ) -> [result_set.PartialResultSet]: @@ -174,6 +186,7 @@ def __create_transaction( def Commit(self, request, context): self._requests.append(request) + self.mock_spanner.pop_error(context) tx = self.transactions[request.transaction_id] if tx is None: raise ValueError(f"Transaction not found: {request.transaction_id}") diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 1cd7656297..12c98bc51b 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -28,6 +28,37 @@ from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.instance import Instance import grpc +from google.rpc import code_pb2 +from google.rpc import status_pb2 +from google.rpc.error_details_pb2 import RetryInfo +from google.protobuf.duration_pb2 import Duration +from grpc_status._common import code_to_grpc_status_code +from grpc_status.rpc_status import _Status + + +# Creates an aborted status with the smallest possible retry delay. +def aborted_status() -> _Status: + error = status_pb2.Status( + code=code_pb2.ABORTED, + message="Transaction was aborted.", + ) + retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1)) + status = _Status( + code=code_to_grpc_status_code(error.code), + details=error.message, + trailing_metadata=( + ("grpc-status-details-bin", error.SerializeToString()), + ( + "google.rpc.retryinfo-bin", + retry_info.SerializeToString(), + ), + ), + ) + return status + + +def add_error(method: str, error: status_pb2.Status): + MockServerTestBase.spanner_service.mock_spanner.add_error(method, error) def add_result(sql: str, result: result_set.ResultSet): diff --git a/tests/mockserver_tests/test_aborted_transaction.py b/tests/mockserver_tests/test_aborted_transaction.py new file mode 100644 index 0000000000..ede2675ce6 --- /dev/null +++ b/tests/mockserver_tests/test_aborted_transaction.py @@ -0,0 +1,50 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.spanner_v1 import ( + BatchCreateSessionsRequest, + BeginTransactionRequest, + CommitRequest, +) +from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer +from google.cloud.spanner_v1.transaction import Transaction +from tests.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_error, + aborted_status, +) + + +class TestAbortedTransaction(MockServerTestBase): + def test_run_in_transaction_commit_aborted(self): + # Add an Aborted error for the Commit method on the mock server. + add_error(SpannerServicer.Commit.__name__, aborted_status()) + # Run a transaction. The Commit method will return Aborted the first + # time that the transaction tries to commit. It will then be retried + # and succeed. + self.database.run_in_transaction(_insert_mutations) + + # Verify that the transaction was retried. + requests = self.spanner_service.requests + self.assertEqual(5, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[2], CommitRequest)) + # The transaction is aborted and retried. + self.assertTrue(isinstance(requests[3], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[4], CommitRequest)) + + +def _insert_mutations(transaction: Transaction): + transaction.insert("my_table", ["col1", "col2"], ["value1", "value2"])