Skip to content

Commit

Permalink
Make Iteration able to contain subtransactions but not be one
Browse files Browse the repository at this point in the history
  • Loading branch information
msullivan committed Mar 1, 2024
1 parent c59baa4 commit 9ee1aea
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 27 deletions.
39 changes: 13 additions & 26 deletions edb/testbase/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def _make_start_query(self):
if self._state is TransactionState.STARTED:
raise errors.InterfaceError(
'cannot start; the transaction is already started')
return self._make_start_query_inner()
qry = self._make_start_query_inner()
if self._connection._top_xact is None:
self._connection._top_xact = self
return qry

@abc.abstractmethod
def _make_start_query_inner(self):
Expand All @@ -101,6 +104,9 @@ def _make_start_query_inner(self):
def _make_commit_query(self):
self.__check_state('commit')

if self._connection._top_xact is self:
self._connection._top_xact = None

return 'COMMIT;'

def _make_rollback_query(self):
Expand Down Expand Up @@ -163,9 +169,7 @@ class RawTransaction(BaseTransaction):
def _make_start_query_inner(self):
con = self._connection

if con._top_xact is None:
con._top_xact = self
else:
if con._top_xact is not None:
# Nested transaction block
self._nested = True

Expand All @@ -179,9 +183,6 @@ def _make_start_query_inner(self):
def _make_commit_query(self):
query = super()._make_commit_query()

if self._connection._top_xact is self:
self._connection._top_xact = None

if self._nested:
query = f'RELEASE SAVEPOINT {self._id};'

Expand All @@ -190,9 +191,6 @@ def _make_commit_query(self):
def _make_rollback_query(self):
query = super()._make_rollback_query()

if self._connection._top_xact is self:
self._connection._top_xact = None

if self._nested:
query = f'ROLLBACK TO SAVEPOINT {self._id};'

Expand Down Expand Up @@ -222,7 +220,7 @@ def __init__(self, retry, connection, iteration):
self._options = retry._options.transaction_options
self.__retry = retry
self.__iteration = iteration
self._started = False
self.__started = False

async def __aenter__(self):
if self._managed:
Expand All @@ -233,7 +231,7 @@ async def __aenter__(self):

async def __aexit__(self, extype, ex, tb):
self._managed = False
if not self._started:
if not self.__started:
return False

try:
Expand Down Expand Up @@ -283,29 +281,22 @@ async def _ensure_transaction(self):
"Only managed retriable transactions are supported. "
"Use `async with transaction:`"
)
if not self._started:
self._started = True
if not self.__started:
self.__started = True
if self._connection.is_closed():
await self._connection.connect(
single_attempt=self.__iteration != 0
)
await self.start()


class RawIteration(RawTransaction, Iteration):
async def _ensure_transaction(self):
self._started = True
await super()._ensure_transaction()


class Retry:
def __init__(self, connection, raw=False):
self._connection = connection
self._iteration = 0
self._done = False
self._next_backoff = 0
self._options = connection._options
self._raw = raw

def _retry(self, exc):
self._last_exception = exc
Expand All @@ -327,8 +318,7 @@ async def __anext__(self):
if self._next_backoff:
await asyncio.sleep(self._next_backoff)
self._done = True
cls = RawIteration if self._raw else Iteration
iteration = cls(self, self._connection, self._iteration)
iteration = Iteration(self, self._connection, self._iteration)
self._iteration += 1
return iteration

Expand Down Expand Up @@ -606,9 +596,6 @@ async def connect_addr(self):
def retrying_transaction(self) -> Retry:
return Retry(self)

def raw_retrying_transaction(self) -> Retry:
return Retry(self, raw=True)

def transaction(self) -> RawTransaction:
return RawTransaction(self)

Expand Down
3 changes: 2 additions & 1 deletion edb/testbase/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,12 +959,13 @@ async def _run_and_rollback_retrying(self):
async def cm(tx):
try:
async with tx:
await tx._ensure_transaction()
yield tx
raise RollbackException
except RollbackException:
pass

async for tx in self.con.raw_retrying_transaction():
async for tx in self.con.retrying_transaction():
yield cm(tx)

def assert_data_shape(self, data, shape, message=None):
Expand Down

0 comments on commit 9ee1aea

Please sign in to comment.