diff --git a/edb/testbase/connection.py b/edb/testbase/connection.py index 1df65bd744c..906181cdb16 100644 --- a/edb/testbase/connection.py +++ b/edb/testbase/connection.py @@ -92,7 +92,10 @@ def _make_start_query(self): if self._state is TransactionState.STARTED: raise errors.InterfaceError( 'cannot start; the transaction is already started') - return self._make_start_query_inner() + qry = self._make_start_query_inner() + if self._connection._top_xact is None: + self._connection._top_xact = self + return qry @abc.abstractmethod def _make_start_query_inner(self): @@ -101,6 +104,9 @@ def _make_start_query_inner(self): def _make_commit_query(self): self.__check_state('commit') + if self._connection._top_xact is self: + self._connection._top_xact = None + return 'COMMIT;' def _make_rollback_query(self): @@ -163,9 +169,7 @@ class RawTransaction(BaseTransaction): def _make_start_query_inner(self): con = self._connection - if con._top_xact is None: - con._top_xact = self - else: + if con._top_xact is not None: # Nested transaction block self._nested = True @@ -179,9 +183,6 @@ def _make_start_query_inner(self): def _make_commit_query(self): query = super()._make_commit_query() - if self._connection._top_xact is self: - self._connection._top_xact = None - if self._nested: query = f'RELEASE SAVEPOINT {self._id};' @@ -190,9 +191,6 @@ def _make_commit_query(self): def _make_rollback_query(self): query = super()._make_rollback_query() - if self._connection._top_xact is self: - self._connection._top_xact = None - if self._nested: query = f'ROLLBACK TO SAVEPOINT {self._id};' @@ -293,7 +291,7 @@ async def _ensure_transaction(self): class Retry: - def __init__(self, connection): + def __init__(self, connection, raw=False): self._connection = connection self._iteration = 0 self._done = False diff --git a/edb/testbase/server.py b/edb/testbase/server.py index f7e170a78d7..a92ba8e1918 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -359,6 +359,10 @@ def __getstate__(self): } +class RollbackException(Exception): + pass + + class RollbackChanges: def __init__(self, test): self._conn = test.con @@ -950,6 +954,20 @@ def repl(self): def _run_and_rollback(self): return RollbackChanges(self) + async def _run_and_rollback_retrying(self): + @contextlib.asynccontextmanager + async def cm(tx): + try: + async with tx: + await tx._ensure_transaction() + yield tx + raise RollbackException + except RollbackException: + pass + + async for tx in self.con.retrying_transaction(): + yield cm(tx) + def assert_data_shape(self, data, shape, message=None): assert_data_shape.assert_data_shape( data, shape, self.fail, message=message) diff --git a/tests/test_dump01.py b/tests/test_dump01.py index 32425326b8f..44fe38d1677 100644 --- a/tests/test_dump01.py +++ b/tests/test_dump01.py @@ -27,14 +27,11 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self, include_data=True): - tx = self.con.transaction() - await tx.start() - try: - await self._ensure_schema_integrity() - if include_data: - await self._ensure_data_integrity() - finally: - await tx.rollback() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._ensure_schema_integrity() + if include_data: + await self._ensure_data_integrity() async def _ensure_schema_integrity(self): # check that all the type annotations are in place diff --git a/tests/test_dump02.py b/tests/test_dump02.py index b2eed4b1064..393b51a9504 100644 --- a/tests/test_dump02.py +++ b/tests/test_dump02.py @@ -25,14 +25,11 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self, include_data=True): - tx = self.con.transaction() - await tx.start() - try: - await self._ensure_schema_integrity() - if include_data: - await self._ensure_data_integrity() - finally: - await tx.rollback() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._ensure_schema_integrity() + if include_data: + await self._ensure_data_integrity() async def _ensure_schema_integrity(self): # Check that index exists diff --git a/tests/test_dump03.py b/tests/test_dump03.py index 33751317dd3..4675cc737e6 100644 --- a/tests/test_dump03.py +++ b/tests/test_dump03.py @@ -25,12 +25,10 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self): - tx = self.con.transaction() - await tx.start() - try: + # We can't use _retrying here, since the sequences won't get reset. + # Hopefully this won't be a problem. + async with self._run_and_rollback(): await self._ensure_schema_data_integrity() - finally: - await tx.rollback() async def _ensure_schema_data_integrity(self): await self.assert_query_result( diff --git a/tests/test_dump_v2.py b/tests/test_dump_v2.py index ca42c579770..1c617f046f7 100644 --- a/tests/test_dump_v2.py +++ b/tests/test_dump_v2.py @@ -27,14 +27,11 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self, include_data=True): - tx = self.con.transaction() - await tx.start() - try: - await self._ensure_schema_integrity() - if include_data: - await self._ensure_data_integrity() - finally: - await tx.rollback() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._ensure_schema_integrity() + if include_data: + await self._ensure_data_integrity() async def _ensure_schema_integrity(self): # Validate access policies diff --git a/tests/test_dump_v3.py b/tests/test_dump_v3.py index 4c3682290b1..9a24a5adbea 100644 --- a/tests/test_dump_v3.py +++ b/tests/test_dump_v3.py @@ -25,12 +25,9 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self): - tx = self.con.transaction() - await tx.start() - try: - await self._ensure_schema_data_integrity() - finally: - await tx.rollback() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._ensure_schema_data_integrity() async def _ensure_schema_data_integrity(self): await self.assert_query_result( diff --git a/tests/test_dump_v4.py b/tests/test_dump_v4.py index 3fb6a08c217..f4e81e0b83b 100644 --- a/tests/test_dump_v4.py +++ b/tests/test_dump_v4.py @@ -25,13 +25,10 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self, include_secrets=False): - tx = self.con.transaction() - await tx.start() - try: - await self._ensure_schema_data_integrity( - include_secrets=include_secrets) - finally: - await tx.rollback() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._ensure_schema_data_integrity( + include_secrets=include_secrets) async def _ensure_schema_data_integrity(self, include_secrets): await self.assert_query_result( diff --git a/tests/test_edgeql_extensions.py b/tests/test_edgeql_extensions.py index f5e493c1bba..b13cc207040 100644 --- a/tests/test_edgeql_extensions.py +++ b/tests/test_edgeql_extensions.py @@ -117,8 +117,9 @@ async def test_edgeql_extensions_01(self): }; ''') try: - async with self._run_and_rollback(): - await self._extension_test_01() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._extension_test_01() finally: await self.con.execute(''' drop extension package ltree VERSION '1.0' @@ -314,10 +315,12 @@ async def test_edgeql_extensions_02(self): }; ''') try: - async with self._run_and_rollback(): - await self._extension_test_02a() - async with self._run_and_rollback(): - await self._extension_test_02b() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._extension_test_02a() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._extension_test_02b() finally: await self.con.execute(''' drop extension package varchar VERSION '1.0' @@ -842,8 +845,9 @@ async def test_edgeql_extensions_05(self): ''') try: - async with self._run_and_rollback(): - await self._extension_test_05(in_tx=True) + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._extension_test_05(in_tx=True) try: await self._extension_test_05(in_tx=False) finally: @@ -984,8 +988,9 @@ async def test_edgeql_extensions_06(self): }; ''') try: - async with self._run_and_rollback(): - await self._extension_test_06b() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._extension_test_06b() finally: await self.con.execute(''' drop extension package bar VERSION '1.0';