Skip to content

Commit

Permalink
Implement _run_and_rollback_retrying (#6960)
Browse files Browse the repository at this point in the history
In general we should be using retrying transactions. Our isolated
tests do retries at the outside, but non-isolated ones don't
necessarily.

Implement _run_and_rollback_retrying and use it some places that
TransactionSerializationError flakes have been seen (in #6954 and
in nightly CI).

Honestly I don't really think that these places ought to be having
TransactionSerializationErrors, but also it really is pretty required
to do retries when using SERIALIZABLE, so we should start moving
towards being more consistent about it.
  • Loading branch information
msullivan committed Mar 7, 2024
1 parent 98fad1d commit 89d8db3
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 63 deletions.
20 changes: 9 additions & 11 deletions edb/testbase/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def _make_start_query(self):
if self._state is TransactionState.STARTED:
raise errors.InterfaceError(
'cannot start; the transaction is already started')
return self._make_start_query_inner()
qry = self._make_start_query_inner()
if self._connection._top_xact is None:
self._connection._top_xact = self
return qry

@abc.abstractmethod
def _make_start_query_inner(self):
Expand All @@ -101,6 +104,9 @@ def _make_start_query_inner(self):
def _make_commit_query(self):
self.__check_state('commit')

if self._connection._top_xact is self:
self._connection._top_xact = None

return 'COMMIT;'

def _make_rollback_query(self):
Expand Down Expand Up @@ -163,9 +169,7 @@ class RawTransaction(BaseTransaction):
def _make_start_query_inner(self):
con = self._connection

if con._top_xact is None:
con._top_xact = self
else:
if con._top_xact is not None:
# Nested transaction block
self._nested = True

Expand All @@ -179,9 +183,6 @@ def _make_start_query_inner(self):
def _make_commit_query(self):
query = super()._make_commit_query()

if self._connection._top_xact is self:
self._connection._top_xact = None

if self._nested:
query = f'RELEASE SAVEPOINT {self._id};'

Expand All @@ -190,9 +191,6 @@ def _make_commit_query(self):
def _make_rollback_query(self):
query = super()._make_rollback_query()

if self._connection._top_xact is self:
self._connection._top_xact = None

if self._nested:
query = f'ROLLBACK TO SAVEPOINT {self._id};'

Expand Down Expand Up @@ -293,7 +291,7 @@ async def _ensure_transaction(self):


class Retry:
def __init__(self, connection):
def __init__(self, connection, raw=False):
self._connection = connection
self._iteration = 0
self._done = False
Expand Down
18 changes: 18 additions & 0 deletions edb/testbase/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ def __getstate__(self):
}


class RollbackException(Exception):
pass


class RollbackChanges:
def __init__(self, test):
self._conn = test.con
Expand Down Expand Up @@ -950,6 +954,20 @@ def repl(self):
def _run_and_rollback(self):
return RollbackChanges(self)

async def _run_and_rollback_retrying(self):
@contextlib.asynccontextmanager
async def cm(tx):
try:
async with tx:
await tx._ensure_transaction()
yield tx
raise RollbackException
except RollbackException:
pass

async for tx in self.con.retrying_transaction():
yield cm(tx)

def assert_data_shape(self, data, shape, message=None):
assert_data_shape.assert_data_shape(
data, shape, self.fail, message=message)
Expand Down
13 changes: 5 additions & 8 deletions tests/test_dump01.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,11 @@
class DumpTestCaseMixin:

async def ensure_schema_data_integrity(self, include_data=True):
tx = self.con.transaction()
await tx.start()
try:
await self._ensure_schema_integrity()
if include_data:
await self._ensure_data_integrity()
finally:
await tx.rollback()
async for tx in self._run_and_rollback_retrying():
async with tx:
await self._ensure_schema_integrity()
if include_data:
await self._ensure_data_integrity()

async def _ensure_schema_integrity(self):
# check that all the type annotations are in place
Expand Down
13 changes: 5 additions & 8 deletions tests/test_dump02.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,11 @@
class DumpTestCaseMixin:

async def ensure_schema_data_integrity(self, include_data=True):
tx = self.con.transaction()
await tx.start()
try:
await self._ensure_schema_integrity()
if include_data:
await self._ensure_data_integrity()
finally:
await tx.rollback()
async for tx in self._run_and_rollback_retrying():
async with tx:
await self._ensure_schema_integrity()
if include_data:
await self._ensure_data_integrity()

async def _ensure_schema_integrity(self):
# Check that index exists
Expand Down
8 changes: 3 additions & 5 deletions tests/test_dump03.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@
class DumpTestCaseMixin:

async def ensure_schema_data_integrity(self):
tx = self.con.transaction()
await tx.start()
try:
# We can't use _retrying here, since the sequences won't get reset.
# Hopefully this won't be a problem.
async with self._run_and_rollback():
await self._ensure_schema_data_integrity()
finally:
await tx.rollback()

async def _ensure_schema_data_integrity(self):
await self.assert_query_result(
Expand Down
13 changes: 5 additions & 8 deletions tests/test_dump_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,11 @@
class DumpTestCaseMixin:

async def ensure_schema_data_integrity(self, include_data=True):
tx = self.con.transaction()
await tx.start()
try:
await self._ensure_schema_integrity()
if include_data:
await self._ensure_data_integrity()
finally:
await tx.rollback()
async for tx in self._run_and_rollback_retrying():
async with tx:
await self._ensure_schema_integrity()
if include_data:
await self._ensure_data_integrity()

async def _ensure_schema_integrity(self):
# Validate access policies
Expand Down
9 changes: 3 additions & 6 deletions tests/test_dump_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,9 @@
class DumpTestCaseMixin:

async def ensure_schema_data_integrity(self):
tx = self.con.transaction()
await tx.start()
try:
await self._ensure_schema_data_integrity()
finally:
await tx.rollback()
async for tx in self._run_and_rollback_retrying():
async with tx:
await self._ensure_schema_data_integrity()

async def _ensure_schema_data_integrity(self):
await self.assert_query_result(
Expand Down
11 changes: 4 additions & 7 deletions tests/test_dump_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,10 @@
class DumpTestCaseMixin:

async def ensure_schema_data_integrity(self, include_secrets=False):
tx = self.con.transaction()
await tx.start()
try:
await self._ensure_schema_data_integrity(
include_secrets=include_secrets)
finally:
await tx.rollback()
async for tx in self._run_and_rollback_retrying():
async with tx:
await self._ensure_schema_data_integrity(
include_secrets=include_secrets)

async def _ensure_schema_data_integrity(self, include_secrets):
await self.assert_query_result(
Expand Down
25 changes: 15 additions & 10 deletions tests/test_edgeql_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,9 @@ async def test_edgeql_extensions_01(self):
};
''')
try:
async with self._run_and_rollback():
await self._extension_test_01()
async for tx in self._run_and_rollback_retrying():
async with tx:
await self._extension_test_01()
finally:
await self.con.execute('''
drop extension package ltree VERSION '1.0'
Expand Down Expand Up @@ -314,10 +315,12 @@ async def test_edgeql_extensions_02(self):
};
''')
try:
async with self._run_and_rollback():
await self._extension_test_02a()
async with self._run_and_rollback():
await self._extension_test_02b()
async for tx in self._run_and_rollback_retrying():
async with tx:
await self._extension_test_02a()
async for tx in self._run_and_rollback_retrying():
async with tx:
await self._extension_test_02b()
finally:
await self.con.execute('''
drop extension package varchar VERSION '1.0'
Expand Down Expand Up @@ -842,8 +845,9 @@ async def test_edgeql_extensions_05(self):
''')

try:
async with self._run_and_rollback():
await self._extension_test_05(in_tx=True)
async for tx in self._run_and_rollback_retrying():
async with tx:
await self._extension_test_05(in_tx=True)
try:
await self._extension_test_05(in_tx=False)
finally:
Expand Down Expand Up @@ -984,8 +988,9 @@ async def test_edgeql_extensions_06(self):
};
''')
try:
async with self._run_and_rollback():
await self._extension_test_06b()
async for tx in self._run_and_rollback_retrying():
async with tx:
await self._extension_test_06b()
finally:
await self.con.execute('''
drop extension package bar VERSION '1.0';
Expand Down

0 comments on commit 89d8db3

Please sign in to comment.