Skip to content

Commit

Permalink
test: support inline-begin in mock server (googleapis#1271)
Browse files Browse the repository at this point in the history
  • Loading branch information
olavloite authored Dec 20, 2024
1 parent f2483e1 commit 6352dd2
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 10 deletions.
46 changes: 36 additions & 10 deletions google/cloud/spanner_v1/testing/mock_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@

from google.protobuf import empty_pb2
from grpc_status.rpc_status import _Status

from google.cloud.spanner_v1 import (
TransactionOptions,
ResultSetMetadata,
ExecuteSqlRequest,
ExecuteBatchDmlRequest,
)
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc
import google.cloud.spanner_v1.testing.spanner_pb2_grpc as spanner_grpc
Expand Down Expand Up @@ -51,23 +58,25 @@ def pop_error(self, context):
context.abort_with_status(error)

def get_result_as_partial_result_sets(
self, sql: str
self, sql: str, started_transaction: transaction.Transaction
) -> [result_set.PartialResultSet]:
result: result_set.ResultSet = self.get_result(sql)
partials = []
first = True
if len(result.rows) == 0:
partial = result_set.PartialResultSet()
partial.metadata = result.metadata
partial.metadata = ResultSetMetadata(result.metadata)
partials.append(partial)
else:
for row in result.rows:
partial = result_set.PartialResultSet()
if first:
partial.metadata = result.metadata
partial.metadata = ResultSetMetadata(result.metadata)
partial.values.extend(row)
partials.append(partial)
partials[len(partials) - 1].stats = result.stats
if started_transaction:
partials[0].metadata.transaction = started_transaction
return partials


Expand Down Expand Up @@ -129,22 +138,29 @@ def DeleteSession(self, request, context):

def ExecuteSql(self, request, context):
self._requests.append(request)
return result_set.ResultSet()
self.mock_spanner.pop_error(context)
started_transaction = self.__maybe_create_transaction(request)
result: result_set.ResultSet = self.mock_spanner.get_result(request.sql)
if started_transaction:
result.metadata = ResultSetMetadata(result.metadata)
result.metadata.transaction = started_transaction
return result

def ExecuteStreamingSql(self, request, context):
self._requests.append(request)
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
self.mock_spanner.pop_error(context)
started_transaction = self.__maybe_create_transaction(request)
partials = self.mock_spanner.get_result_as_partial_result_sets(
request.sql, started_transaction
)
for result in partials:
yield result

def ExecuteBatchDml(self, request, context):
self._requests.append(request)
self.mock_spanner.pop_error(context)
response = spanner.ExecuteBatchDmlResponse()
started_transaction = None
if not request.transaction.begin == transaction.TransactionOptions():
started_transaction = self.__create_transaction(
request.session, request.transaction.begin
)
started_transaction = self.__maybe_create_transaction(request)
first = True
for statement in request.statements:
result = self.mock_spanner.get_result(statement.sql)
Expand All @@ -170,6 +186,16 @@ def BeginTransaction(self, request, context):
self._requests.append(request)
return self.__create_transaction(request.session, request.options)

def __maybe_create_transaction(
self, request: ExecuteSqlRequest | ExecuteBatchDmlRequest
):
started_transaction = None
if not request.transaction.begin == TransactionOptions():
started_transaction = self.__create_transaction(
request.session, request.transaction.begin
)
return started_transaction

def __create_transaction(
self, session: str, options: transaction.TransactionOptions
) -> transaction.Transaction:
Expand Down
69 changes: 69 additions & 0 deletions tests/mockserver_tests/test_aborted_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@
BatchCreateSessionsRequest,
BeginTransactionRequest,
CommitRequest,
ExecuteSqlRequest,
TypeCode,
ExecuteBatchDmlRequest,
)
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
from google.cloud.spanner_v1.transaction import Transaction
from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_error,
aborted_status,
add_update_count,
add_single_result,
)


Expand All @@ -45,6 +50,70 @@ def test_run_in_transaction_commit_aborted(self):
self.assertTrue(isinstance(requests[3], BeginTransactionRequest))
self.assertTrue(isinstance(requests[4], CommitRequest))

def test_run_in_transaction_update_aborted(self):
add_update_count("update my_table set my_col=1 where id=2", 1)
add_error(SpannerServicer.ExecuteSql.__name__, aborted_status())
self.database.run_in_transaction(_execute_update)

# Verify that the transaction was retried.
requests = self.spanner_service.requests
self.assertEqual(4, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[3], CommitRequest))

def test_run_in_transaction_query_aborted(self):
add_single_result(
"select value from my_table where id=1",
"value",
TypeCode.STRING,
"my-value",
)
add_error(SpannerServicer.ExecuteStreamingSql.__name__, aborted_status())
self.database.run_in_transaction(_execute_query)

# Verify that the transaction was retried.
requests = self.spanner_service.requests
self.assertEqual(4, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[3], CommitRequest))

def test_run_in_transaction_batch_dml_aborted(self):
add_update_count("update my_table set my_col=1 where id=1", 1)
add_update_count("update my_table set my_col=1 where id=2", 1)
add_error(SpannerServicer.ExecuteBatchDml.__name__, aborted_status())
self.database.run_in_transaction(_execute_batch_dml)

# Verify that the transaction was retried.
requests = self.spanner_service.requests
self.assertEqual(4, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], ExecuteBatchDmlRequest))
self.assertTrue(isinstance(requests[2], ExecuteBatchDmlRequest))
self.assertTrue(isinstance(requests[3], CommitRequest))


def _insert_mutations(transaction: Transaction):
transaction.insert("my_table", ["col1", "col2"], ["value1", "value2"])


def _execute_update(transaction: Transaction):
transaction.execute_update("update my_table set my_col=1 where id=2")


def _execute_query(transaction: Transaction):
rows = transaction.execute_sql("select value from my_table where id=1")
for _ in rows:
pass


def _execute_batch_dml(transaction: Transaction):
transaction.batch_update(
[
"update my_table set my_col=1 where id=1",
"update my_table set my_col=1 where id=2",
]
)

0 comments on commit 6352dd2

Please sign in to comment.