diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py index 1f37ff2a03..6b50d9a6d1 100644 --- a/google/cloud/spanner_v1/testing/mock_spanner.py +++ b/google/cloud/spanner_v1/testing/mock_spanner.py @@ -18,6 +18,13 @@ from google.protobuf import empty_pb2 from grpc_status.rpc_status import _Status + +from google.cloud.spanner_v1 import ( + TransactionOptions, + ResultSetMetadata, + ExecuteSqlRequest, + ExecuteBatchDmlRequest, +) from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc import google.cloud.spanner_v1.testing.spanner_pb2_grpc as spanner_grpc @@ -51,23 +58,25 @@ def pop_error(self, context): context.abort_with_status(error) def get_result_as_partial_result_sets( - self, sql: str + self, sql: str, started_transaction: transaction.Transaction ) -> [result_set.PartialResultSet]: result: result_set.ResultSet = self.get_result(sql) partials = [] first = True if len(result.rows) == 0: partial = result_set.PartialResultSet() - partial.metadata = result.metadata + partial.metadata = ResultSetMetadata(result.metadata) partials.append(partial) else: for row in result.rows: partial = result_set.PartialResultSet() if first: - partial.metadata = result.metadata + partial.metadata = ResultSetMetadata(result.metadata) partial.values.extend(row) partials.append(partial) partials[len(partials) - 1].stats = result.stats + if started_transaction: + partials[0].metadata.transaction = started_transaction return partials @@ -129,22 +138,29 @@ def DeleteSession(self, request, context): def ExecuteSql(self, request, context): self._requests.append(request) - return result_set.ResultSet() + self.mock_spanner.pop_error(context) + started_transaction = self.__maybe_create_transaction(request) + result: result_set.ResultSet = self.mock_spanner.get_result(request.sql) + if started_transaction: + result.metadata = ResultSetMetadata(result.metadata) + result.metadata.transaction = started_transaction + return result def ExecuteStreamingSql(self, request, context): self._requests.append(request) - partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql) + self.mock_spanner.pop_error(context) + started_transaction = self.__maybe_create_transaction(request) + partials = self.mock_spanner.get_result_as_partial_result_sets( + request.sql, started_transaction + ) for result in partials: yield result def ExecuteBatchDml(self, request, context): self._requests.append(request) + self.mock_spanner.pop_error(context) response = spanner.ExecuteBatchDmlResponse() - started_transaction = None - if not request.transaction.begin == transaction.TransactionOptions(): - started_transaction = self.__create_transaction( - request.session, request.transaction.begin - ) + started_transaction = self.__maybe_create_transaction(request) first = True for statement in request.statements: result = self.mock_spanner.get_result(statement.sql) @@ -170,6 +186,16 @@ def BeginTransaction(self, request, context): self._requests.append(request) return self.__create_transaction(request.session, request.options) + def __maybe_create_transaction( + self, request: ExecuteSqlRequest | ExecuteBatchDmlRequest + ): + started_transaction = None + if not request.transaction.begin == TransactionOptions(): + started_transaction = self.__create_transaction( + request.session, request.transaction.begin + ) + return started_transaction + def __create_transaction( self, session: str, options: transaction.TransactionOptions ) -> transaction.Transaction: diff --git a/tests/mockserver_tests/test_aborted_transaction.py b/tests/mockserver_tests/test_aborted_transaction.py index ede2675ce6..89b30a0875 100644 --- a/tests/mockserver_tests/test_aborted_transaction.py +++ b/tests/mockserver_tests/test_aborted_transaction.py @@ -16,6 +16,9 @@ BatchCreateSessionsRequest, BeginTransactionRequest, CommitRequest, + ExecuteSqlRequest, + TypeCode, + ExecuteBatchDmlRequest, ) from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer from google.cloud.spanner_v1.transaction import Transaction @@ -23,6 +26,8 @@ MockServerTestBase, add_error, aborted_status, + add_update_count, + add_single_result, ) @@ -45,6 +50,70 @@ def test_run_in_transaction_commit_aborted(self): self.assertTrue(isinstance(requests[3], BeginTransactionRequest)) self.assertTrue(isinstance(requests[4], CommitRequest)) + def test_run_in_transaction_update_aborted(self): + add_update_count("update my_table set my_col=1 where id=2", 1) + add_error(SpannerServicer.ExecuteSql.__name__, aborted_status()) + self.database.run_in_transaction(_execute_update) + + # Verify that the transaction was retried. + requests = self.spanner_service.requests + self.assertEqual(4, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[3], CommitRequest)) + + def test_run_in_transaction_query_aborted(self): + add_single_result( + "select value from my_table where id=1", + "value", + TypeCode.STRING, + "my-value", + ) + add_error(SpannerServicer.ExecuteStreamingSql.__name__, aborted_status()) + self.database.run_in_transaction(_execute_query) + + # Verify that the transaction was retried. + requests = self.spanner_service.requests + self.assertEqual(4, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[3], CommitRequest)) + + def test_run_in_transaction_batch_dml_aborted(self): + add_update_count("update my_table set my_col=1 where id=1", 1) + add_update_count("update my_table set my_col=1 where id=2", 1) + add_error(SpannerServicer.ExecuteBatchDml.__name__, aborted_status()) + self.database.run_in_transaction(_execute_batch_dml) + + # Verify that the transaction was retried. + requests = self.spanner_service.requests + self.assertEqual(4, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], ExecuteBatchDmlRequest)) + self.assertTrue(isinstance(requests[2], ExecuteBatchDmlRequest)) + self.assertTrue(isinstance(requests[3], CommitRequest)) + def _insert_mutations(transaction: Transaction): transaction.insert("my_table", ["col1", "col2"], ["value1", "value2"]) + + +def _execute_update(transaction: Transaction): + transaction.execute_update("update my_table set my_col=1 where id=2") + + +def _execute_query(transaction: Transaction): + rows = transaction.execute_sql("select value from my_table where id=1") + for _ in rows: + pass + + +def _execute_batch_dml(transaction: Transaction): + transaction.batch_update( + [ + "update my_table set my_col=1 where id=1", + "update my_table set my_col=1 where id=2", + ] + )