diff --git a/edb/testbase/connection.py b/edb/testbase/connection.py index 658118044e1..906181cdb16 100644 --- a/edb/testbase/connection.py +++ b/edb/testbase/connection.py @@ -92,7 +92,10 @@ def _make_start_query(self): if self._state is TransactionState.STARTED: raise errors.InterfaceError( 'cannot start; the transaction is already started') - return self._make_start_query_inner() + qry = self._make_start_query_inner() + if self._connection._top_xact is None: + self._connection._top_xact = self + return qry @abc.abstractmethod def _make_start_query_inner(self): @@ -101,6 +104,9 @@ def _make_start_query_inner(self): def _make_commit_query(self): self.__check_state('commit') + if self._connection._top_xact is self: + self._connection._top_xact = None + return 'COMMIT;' def _make_rollback_query(self): @@ -163,9 +169,7 @@ class RawTransaction(BaseTransaction): def _make_start_query_inner(self): con = self._connection - if con._top_xact is None: - con._top_xact = self - else: + if con._top_xact is not None: # Nested transaction block self._nested = True @@ -179,9 +183,6 @@ def _make_start_query_inner(self): def _make_commit_query(self): query = super()._make_commit_query() - if self._connection._top_xact is self: - self._connection._top_xact = None - if self._nested: query = f'RELEASE SAVEPOINT {self._id};' @@ -190,9 +191,6 @@ def _make_commit_query(self): def _make_rollback_query(self): query = super()._make_rollback_query() - if self._connection._top_xact is self: - self._connection._top_xact = None - if self._nested: query = f'ROLLBACK TO SAVEPOINT {self._id};' @@ -222,7 +220,7 @@ def __init__(self, retry, connection, iteration): self._options = retry._options.transaction_options self.__retry = retry self.__iteration = iteration - self._started = False + self.__started = False async def __aenter__(self): if self._managed: @@ -233,7 +231,7 @@ async def __aenter__(self): async def __aexit__(self, extype, ex, tb): self._managed = False - if not self._started: + if not self.__started: return False try: @@ -283,8 +281,8 @@ async def _ensure_transaction(self): "Only managed retriable transactions are supported. " "Use `async with transaction:`" ) - if not self._started: - self._started = True + if not self.__started: + self.__started = True if self._connection.is_closed(): await self._connection.connect( single_attempt=self.__iteration != 0 @@ -292,12 +290,6 @@ async def _ensure_transaction(self): await self.start() -class RawIteration(RawTransaction, Iteration): - async def _ensure_transaction(self): - self._started = True - await super()._ensure_transaction() - - class Retry: def __init__(self, connection, raw=False): self._connection = connection @@ -305,7 +297,6 @@ def __init__(self, connection, raw=False): self._done = False self._next_backoff = 0 self._options = connection._options - self._raw = raw def _retry(self, exc): self._last_exception = exc @@ -327,8 +318,7 @@ async def __anext__(self): if self._next_backoff: await asyncio.sleep(self._next_backoff) self._done = True - cls = RawIteration if self._raw else Iteration - iteration = cls(self, self._connection, self._iteration) + iteration = Iteration(self, self._connection, self._iteration) self._iteration += 1 return iteration @@ -606,9 +596,6 @@ async def connect_addr(self): def retrying_transaction(self) -> Retry: return Retry(self) - def raw_retrying_transaction(self) -> Retry: - return Retry(self, raw=True) - def transaction(self) -> RawTransaction: return RawTransaction(self) diff --git a/edb/testbase/server.py b/edb/testbase/server.py index b45e585a73e..c9c858da608 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -959,12 +959,13 @@ async def _run_and_rollback_retrying(self): async def cm(tx): try: async with tx: + await tx._ensure_transaction() yield tx raise RollbackException except RollbackException: pass - async for tx in self.con.raw_retrying_transaction(): + async for tx in self.con.retrying_transaction(): yield cm(tx) def assert_data_shape(self, data, shape, message=None):