From 89d8db3fb7a427b613634c744336eb8f8af09f41 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 1 Mar 2024 16:18:06 -0800 Subject: [PATCH] Implement _run_and_rollback_retrying (#6960) In general we should be using retrying transactions. Our isolated tests do retries at the outside, but non-isolated ones don't necessarily. Implement _run_and_rollback_retrying and use it some places that TransactionSerializationError flakes have been seen (in #6954 and in nightly CI). Honestly I don't really think that these places ought to be having TransactionSerializationErrors, but also it really is pretty required to do retries when using SERIALIZABLE, so we should start moving towards being more consistent about it. --- edb/testbase/connection.py | 20 +++++++++----------- edb/testbase/server.py | 18 ++++++++++++++++++ tests/test_dump01.py | 13 +++++-------- tests/test_dump02.py | 13 +++++-------- tests/test_dump03.py | 8 +++----- tests/test_dump_v2.py | 13 +++++-------- tests/test_dump_v3.py | 9 +++------ tests/test_dump_v4.py | 11 ++++------- tests/test_edgeql_extensions.py | 25 +++++++++++++++---------- 9 files changed, 67 insertions(+), 63 deletions(-) diff --git a/edb/testbase/connection.py b/edb/testbase/connection.py index 1df65bd744c..906181cdb16 100644 --- a/edb/testbase/connection.py +++ b/edb/testbase/connection.py @@ -92,7 +92,10 @@ def _make_start_query(self): if self._state is TransactionState.STARTED: raise errors.InterfaceError( 'cannot start; the transaction is already started') - return self._make_start_query_inner() + qry = self._make_start_query_inner() + if self._connection._top_xact is None: + self._connection._top_xact = self + return qry @abc.abstractmethod def _make_start_query_inner(self): @@ -101,6 +104,9 @@ def _make_start_query_inner(self): def _make_commit_query(self): self.__check_state('commit') + if self._connection._top_xact is self: + self._connection._top_xact = None + return 'COMMIT;' def _make_rollback_query(self): @@ -163,9 +169,7 @@ class RawTransaction(BaseTransaction): def _make_start_query_inner(self): con = self._connection - if con._top_xact is None: - con._top_xact = self - else: + if con._top_xact is not None: # Nested transaction block self._nested = True @@ -179,9 +183,6 @@ def _make_start_query_inner(self): def _make_commit_query(self): query = super()._make_commit_query() - if self._connection._top_xact is self: - self._connection._top_xact = None - if self._nested: query = f'RELEASE SAVEPOINT {self._id};' @@ -190,9 +191,6 @@ def _make_commit_query(self): def _make_rollback_query(self): query = super()._make_rollback_query() - if self._connection._top_xact is self: - self._connection._top_xact = None - if self._nested: query = f'ROLLBACK TO SAVEPOINT {self._id};' @@ -293,7 +291,7 @@ async def _ensure_transaction(self): class Retry: - def __init__(self, connection): + def __init__(self, connection, raw=False): self._connection = connection self._iteration = 0 self._done = False diff --git a/edb/testbase/server.py b/edb/testbase/server.py index f7e170a78d7..a92ba8e1918 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -359,6 +359,10 @@ def __getstate__(self): } +class RollbackException(Exception): + pass + + class RollbackChanges: def __init__(self, test): self._conn = test.con @@ -950,6 +954,20 @@ def repl(self): def _run_and_rollback(self): return RollbackChanges(self) + async def _run_and_rollback_retrying(self): + @contextlib.asynccontextmanager + async def cm(tx): + try: + async with tx: + await tx._ensure_transaction() + yield tx + raise RollbackException + except RollbackException: + pass + + async for tx in self.con.retrying_transaction(): + yield cm(tx) + def assert_data_shape(self, data, shape, message=None): assert_data_shape.assert_data_shape( data, shape, self.fail, message=message) diff --git a/tests/test_dump01.py b/tests/test_dump01.py index 32425326b8f..44fe38d1677 100644 --- a/tests/test_dump01.py +++ b/tests/test_dump01.py @@ -27,14 +27,11 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self, include_data=True): - tx = self.con.transaction() - await tx.start() - try: - await self._ensure_schema_integrity() - if include_data: - await self._ensure_data_integrity() - finally: - await tx.rollback() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._ensure_schema_integrity() + if include_data: + await self._ensure_data_integrity() async def _ensure_schema_integrity(self): # check that all the type annotations are in place diff --git a/tests/test_dump02.py b/tests/test_dump02.py index b2eed4b1064..393b51a9504 100644 --- a/tests/test_dump02.py +++ b/tests/test_dump02.py @@ -25,14 +25,11 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self, include_data=True): - tx = self.con.transaction() - await tx.start() - try: - await self._ensure_schema_integrity() - if include_data: - await self._ensure_data_integrity() - finally: - await tx.rollback() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._ensure_schema_integrity() + if include_data: + await self._ensure_data_integrity() async def _ensure_schema_integrity(self): # Check that index exists diff --git a/tests/test_dump03.py b/tests/test_dump03.py index 33751317dd3..4675cc737e6 100644 --- a/tests/test_dump03.py +++ b/tests/test_dump03.py @@ -25,12 +25,10 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self): - tx = self.con.transaction() - await tx.start() - try: + # We can't use _retrying here, since the sequences won't get reset. + # Hopefully this won't be a problem. + async with self._run_and_rollback(): await self._ensure_schema_data_integrity() - finally: - await tx.rollback() async def _ensure_schema_data_integrity(self): await self.assert_query_result( diff --git a/tests/test_dump_v2.py b/tests/test_dump_v2.py index ca42c579770..1c617f046f7 100644 --- a/tests/test_dump_v2.py +++ b/tests/test_dump_v2.py @@ -27,14 +27,11 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self, include_data=True): - tx = self.con.transaction() - await tx.start() - try: - await self._ensure_schema_integrity() - if include_data: - await self._ensure_data_integrity() - finally: - await tx.rollback() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._ensure_schema_integrity() + if include_data: + await self._ensure_data_integrity() async def _ensure_schema_integrity(self): # Validate access policies diff --git a/tests/test_dump_v3.py b/tests/test_dump_v3.py index 4c3682290b1..9a24a5adbea 100644 --- a/tests/test_dump_v3.py +++ b/tests/test_dump_v3.py @@ -25,12 +25,9 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self): - tx = self.con.transaction() - await tx.start() - try: - await self._ensure_schema_data_integrity() - finally: - await tx.rollback() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._ensure_schema_data_integrity() async def _ensure_schema_data_integrity(self): await self.assert_query_result( diff --git a/tests/test_dump_v4.py b/tests/test_dump_v4.py index 3fb6a08c217..f4e81e0b83b 100644 --- a/tests/test_dump_v4.py +++ b/tests/test_dump_v4.py @@ -25,13 +25,10 @@ class DumpTestCaseMixin: async def ensure_schema_data_integrity(self, include_secrets=False): - tx = self.con.transaction() - await tx.start() - try: - await self._ensure_schema_data_integrity( - include_secrets=include_secrets) - finally: - await tx.rollback() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._ensure_schema_data_integrity( + include_secrets=include_secrets) async def _ensure_schema_data_integrity(self, include_secrets): await self.assert_query_result( diff --git a/tests/test_edgeql_extensions.py b/tests/test_edgeql_extensions.py index f5e493c1bba..b13cc207040 100644 --- a/tests/test_edgeql_extensions.py +++ b/tests/test_edgeql_extensions.py @@ -117,8 +117,9 @@ async def test_edgeql_extensions_01(self): }; ''') try: - async with self._run_and_rollback(): - await self._extension_test_01() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._extension_test_01() finally: await self.con.execute(''' drop extension package ltree VERSION '1.0' @@ -314,10 +315,12 @@ async def test_edgeql_extensions_02(self): }; ''') try: - async with self._run_and_rollback(): - await self._extension_test_02a() - async with self._run_and_rollback(): - await self._extension_test_02b() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._extension_test_02a() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._extension_test_02b() finally: await self.con.execute(''' drop extension package varchar VERSION '1.0' @@ -842,8 +845,9 @@ async def test_edgeql_extensions_05(self): ''') try: - async with self._run_and_rollback(): - await self._extension_test_05(in_tx=True) + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._extension_test_05(in_tx=True) try: await self._extension_test_05(in_tx=False) finally: @@ -984,8 +988,9 @@ async def test_edgeql_extensions_06(self): }; ''') try: - async with self._run_and_rollback(): - await self._extension_test_06b() + async for tx in self._run_and_rollback_retrying(): + async with tx: + await self._extension_test_06b() finally: await self.con.execute(''' drop extension package bar VERSION '1.0';