diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py new file mode 100644 index 0000000000..1cd7656297 --- /dev/null +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -0,0 +1,139 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode +from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer +from google.cloud.spanner_v1.testing.mock_spanner import ( + start_mock_server, + SpannerServicer, +) +import google.cloud.spanner_v1.types.type as spanner_type +import google.cloud.spanner_v1.types.result_set as result_set +from google.api_core.client_options import ClientOptions +from google.auth.credentials import AnonymousCredentials +from google.cloud.spanner_v1 import Client, TypeCode, FixedSizePool +from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.instance import Instance +import grpc + + +def add_result(sql: str, result: result_set.ResultSet): + MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result) + + +def add_update_count( + sql: str, count: int, dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL +): + if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC: + stats = dict(row_count_lower_bound=count) + else: + stats = dict(row_count_exact=count) + result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats))) + add_result(sql, result) + + +def add_select1_result(): + add_single_result("select 1", "c", TypeCode.INT64, [("1",)]) + + +def add_single_result( + sql: str, column_name: str, type_code: spanner_type.TypeCode, row +): + result = result_set.ResultSet( + dict( + metadata=result_set.ResultSetMetadata( + dict( + row_type=spanner_type.StructType( + dict( + fields=[ + spanner_type.StructType.Field( + dict( + name=column_name, + type=spanner_type.Type(dict(code=type_code)), + ) + ) + ] + ) + ) + ) + ), + ) + ) + result.rows.extend(row) + MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result) + + +class MockServerTestBase(unittest.TestCase): + server: grpc.Server = None + spanner_service: SpannerServicer = None + database_admin_service: DatabaseAdminServicer = None + port: int = None + + def __init__(self, *args, **kwargs): + super(MockServerTestBase, self).__init__(*args, **kwargs) + self._client = None + self._instance = None + self._database = None + + @classmethod + def setup_class(cls): + ( + MockServerTestBase.server, + MockServerTestBase.spanner_service, + MockServerTestBase.database_admin_service, + MockServerTestBase.port, + ) = start_mock_server() + + @classmethod + def teardown_class(cls): + if MockServerTestBase.server is not None: + MockServerTestBase.server.stop(grace=None) + MockServerTestBase.server = None + + def setup_method(self, *args, **kwargs): + self._client = None + self._instance = None + self._database = None + + def teardown_method(self, *args, **kwargs): + MockServerTestBase.spanner_service.clear_requests() + MockServerTestBase.database_admin_service.clear_requests() + + @property + def client(self) -> Client: + if self._client is None: + self._client = Client( + project="p", + credentials=AnonymousCredentials(), + client_options=ClientOptions( + api_endpoint="localhost:" + str(MockServerTestBase.port), + ), + ) + return self._client + + @property + def instance(self) -> Instance: + if self._instance is None: + self._instance = self.client.instance("test-instance") + return self._instance + + @property + def database(self) -> Database: + if self._database is None: + self._database = self.instance.database( + "test-database", pool=FixedSizePool(size=10) + ) + return self._database diff --git a/tests/mockserver_tests/test_basics.py b/tests/mockserver_tests/test_basics.py index 9d6dad095e..ed0906cb9b 100644 --- a/tests/mockserver_tests/test_basics.py +++ b/tests/mockserver_tests/test_basics.py @@ -12,131 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - from google.cloud.spanner_admin_database_v1.types import spanner_database_admin from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode -from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer -from google.cloud.spanner_v1.testing.mock_spanner import ( - start_mock_server, - SpannerServicer, -) -import google.cloud.spanner_v1.types.type as spanner_type -import google.cloud.spanner_v1.types.result_set as result_set -from google.api_core.client_options import ClientOptions -from google.auth.credentials import AnonymousCredentials from google.cloud.spanner_v1 import ( - Client, - FixedSizePool, BatchCreateSessionsRequest, ExecuteSqlRequest, BeginTransactionRequest, TransactionOptions, ) -from google.cloud.spanner_v1.database import Database -from google.cloud.spanner_v1.instance import Instance -import grpc - - -class TestBasics(unittest.TestCase): - server: grpc.Server = None - spanner_service: SpannerServicer = None - database_admin_service: DatabaseAdminServicer = None - port: int = None - - def __init__(self, *args, **kwargs): - super(TestBasics, self).__init__(*args, **kwargs) - self._client = None - self._instance = None - self._database = None - @classmethod - def setUpClass(cls): - ( - TestBasics.server, - TestBasics.spanner_service, - TestBasics.database_admin_service, - TestBasics.port, - ) = start_mock_server() - - @classmethod - def tearDownClass(cls): - if TestBasics.server is not None: - TestBasics.server.stop(grace=None) - TestBasics.server = None - - def teardown_method(self, *args, **kwargs): - TestBasics.spanner_service.clear_requests() - TestBasics.database_admin_service.clear_requests() - - def _add_select1_result(self): - result = result_set.ResultSet( - dict( - metadata=result_set.ResultSetMetadata( - dict( - row_type=spanner_type.StructType( - dict( - fields=[ - spanner_type.StructType.Field( - dict( - name="c", - type=spanner_type.Type( - dict(code=spanner_type.TypeCode.INT64) - ), - ) - ) - ] - ) - ) - ) - ), - ) - ) - result.rows.extend(["1"]) - TestBasics.spanner_service.mock_spanner.add_result("select 1", result) - - def add_update_count( - self, - sql: str, - count: int, - dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL, - ): - if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC: - stats = dict(row_count_lower_bound=count) - else: - stats = dict(row_count_exact=count) - result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats))) - TestBasics.spanner_service.mock_spanner.add_result(sql, result) - - @property - def client(self) -> Client: - if self._client is None: - self._client = Client( - project="test-project", - credentials=AnonymousCredentials(), - client_options=ClientOptions( - api_endpoint="localhost:" + str(TestBasics.port), - ), - ) - return self._client - - @property - def instance(self) -> Instance: - if self._instance is None: - self._instance = self.client.instance("test-instance") - return self._instance +from tests.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_select1_result, + add_update_count, +) - @property - def database(self) -> Database: - if self._database is None: - self._database = self.instance.database( - "test-database", pool=FixedSizePool(size=10) - ) - return self._database +class TestBasics(MockServerTestBase): def test_select1(self): - self._add_select1_result() + add_select1_result() with self.database.snapshot() as snapshot: results = snapshot.execute_sql("select 1") result_list = [] @@ -171,7 +66,7 @@ def test_create_table(self): # been re-factored to use a base class for the boiler plate code. def test_dbapi_partitioned_dml(self): sql = "UPDATE singers SET foo='bar' WHERE active = true" - self.add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC) + add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC) connection = Connection(self.instance, self.database) connection.autocommit = True connection.set_autocommit_dml_mode(AutocommitDmlMode.PARTITIONED_NON_ATOMIC)