Skip to content

Commit

Permalink
Retry calls to execute as well as to query (#577)
Browse files Browse the repository at this point in the history
I think originally `execute` didn't support retries because we only retried
read-only queries. But now we also retry transaction errors, so we should
support it on both.

(I modeled the implementation after how I did this in the edgedb test
suite's hacked up client: edgedb/edgedb#8249)
  • Loading branch information
msullivan authored Feb 13, 2025
1 parent d0eec6c commit e263f2b
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 26 deletions.
8 changes: 7 additions & 1 deletion gel/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def lower(
class ExecuteContext(typing.NamedTuple):
query: QueryWithArgs
cache: QueryCache
retry_options: typing.Optional[options.RetryOptions]
state: typing.Optional[options.State]
warning_handler: options.WarningHandler
annotations: typing.Dict[str, str]
Expand Down Expand Up @@ -187,8 +188,9 @@ class BaseReadOnlyExecutor(abc.ABC):
def _get_query_cache(self) -> QueryCache:
...

@abc.abstractmethod
def _get_retry_options(self) -> typing.Optional[options.RetryOptions]:
return None
...

@abc.abstractmethod
def _get_state(self) -> options.State:
Expand Down Expand Up @@ -303,6 +305,7 @@ def execute(self, commands: str, *args, **kwargs) -> None:
self._execute(ExecuteContext(
query=QueryWithArgs(commands, args, kwargs),
cache=self._get_query_cache(),
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
Expand All @@ -317,6 +320,7 @@ def execute_sql(self, commands: str, *args, **kwargs) -> None:
input_language=protocol.InputLanguage.SQL,
),
cache=self._get_query_cache(),
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
Expand Down Expand Up @@ -438,6 +442,7 @@ async def execute(self, commands: str, *args, **kwargs) -> None:
await self._execute(ExecuteContext(
query=QueryWithArgs(commands, args, kwargs),
cache=self._get_query_cache(),
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
Expand All @@ -452,6 +457,7 @@ async def execute_sql(self, commands: str, *args, **kwargs) -> None:
input_language=protocol.InputLanguage.SQL,
),
cache=self._get_query_cache(),
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
Expand Down
58 changes: 37 additions & 21 deletions gel/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,32 +197,18 @@ def is_in_transaction(self) -> bool:
def get_settings(self) -> typing.Dict[str, typing.Any]:
return self._protocol.get_settings()

async def raw_query(self, query_context: abstract.QueryContext):
if self.is_closed():
await self.connect()

async def _retry_operation(self, func, retry_options, ctx):
reconnect = False
i = 0
if self._protocol.is_legacy:
allow_capabilities = enums.Capability.LEGACY_EXECUTE
else:
allow_capabilities = enums.Capability.EXECUTE
ctx = query_context.lower(allow_capabilities=allow_capabilities)
while True:
i += 1
try:
if reconnect:
await self.connect(single_attempt=True)
if self._protocol.is_legacy:
return await self._protocol.legacy_execute_anonymous(ctx)
else:
res = await self._protocol.query(ctx)
if ctx.warnings:
res = query_context.warning_handler(ctx.warnings, res)
return res
return await func()

except errors.EdgeDBError as e:
if query_context.retry_options is None:
if retry_options is None:
raise
if not e.has_tag(errors.SHOULD_RETRY):
raise e
Expand All @@ -234,12 +220,37 @@ async def raw_query(self, query_context: abstract.QueryContext):
and not isinstance(e, errors.TransactionConflictError)
):
raise e
rule = query_context.retry_options.get_rule_for_exception(e)
rule = retry_options.get_rule_for_exception(e)
if i >= rule.attempts:
raise e
await self.sleep(rule.backoff(i))
reconnect = self.is_closed()

async def raw_query(self, query_context: abstract.QueryContext):
if self.is_closed():
await self.connect()

reconnect = False
i = 0
if self._protocol.is_legacy:
allow_capabilities = enums.Capability.LEGACY_EXECUTE
else:
allow_capabilities = enums.Capability.EXECUTE
ctx = query_context.lower(allow_capabilities=allow_capabilities)

async def _inner():
if self._protocol.is_legacy:
return await self._protocol.legacy_execute_anonymous(ctx)
else:
res = await self._protocol.query(ctx)
if ctx.warnings:
res = query_context.warning_handler(ctx.warnings, res)
return res

return await self._retry_operation(
_inner, query_context.retry_options, ctx
)

async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
if self._protocol.is_legacy:
if execute_context.query.args or execute_context.query.kwargs:
Expand All @@ -253,9 +264,14 @@ async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
ctx = execute_context.lower(
allow_capabilities=enums.Capability.EXECUTE
)
res = await self._protocol.execute(ctx)
if ctx.warnings:
res = execute_context.warning_handler(ctx.warnings, res)
async def _inner():
res = await self._protocol.execute(ctx)
if ctx.warnings:
res = execute_context.warning_handler(ctx.warnings, res)

return await self._retry_operation(
_inner, execute_context.retry_options, ctx
)

async def describe(
self, describe_context: abstract.DescribeContext
Expand Down
4 changes: 4 additions & 0 deletions gel/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ async def _exit(self, extype, ex):
def _get_query_cache(self) -> abstract.QueryCache:
return self._client._get_query_cache()

def _get_retry_options(self) -> typing.Optional[options.RetryOptions]:
return None

def _get_state(self) -> options.State:
return self._client._get_state()

Expand All @@ -206,6 +209,7 @@ async def _privileged_execute(self, query: str) -> None:
query=abstract.QueryWithArgs(query, (), {}),
cache=self._get_query_cache(),
state=self._get_state(),
retry_options=self._get_retry_options(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))
Expand Down
71 changes: 67 additions & 4 deletions tests/test_async_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ class TestAsyncRetry(tb.AsyncQueryTestCase):
};
'''

TEARDOWN = '''
DROP TYPE test::Counter;
'''

async def test_async_retry_01(self):
async for tx in self.client.transaction():
async with tx:
Expand Down Expand Up @@ -206,6 +202,73 @@ async def transaction1(client):
self.assertEqual(set(results), {1, 2})
self.assertEqual(iterations, 3)

async def test_async_retry_conflict_nontx_01(self):
await self.execute_nontx_conflict(
'counter_nontx_01',
lambda client, *args, **kwargs: client.query(*args, **kwargs)
)

async def test_async_retry_conflict_nontx_02(self):
await self.execute_nontx_conflict(
'counter_nontx_02',
lambda client, *args, **kwargs: client.execute(*args, **kwargs)
)

async def execute_nontx_conflict(self, name, func):
# Test retries on conflicts in a non-tx setting. We do this
# by having conflicting upserts that are made long-running by
# adding a sys::_sleep call.
#
# Unlike for the tx ones, we don't assert that a retry
# actually was necessary, since that feels fragile in a
# timing-based test like this.

client1 = self.client
client2 = self.make_test_client(database=self.get_database_name())
self.addCleanup(client2.aclose)

await client1.query("SELECT 1")
await client2.query("SELECT 1")

query = '''
SELECT (
INSERT test::Counter {
name := <str>$name,
value := 1,
} UNLESS CONFLICT ON .name
ELSE (
UPDATE test::Counter
SET { value := .value + 1 }
)
).value
ORDER BY sys::_sleep(<int64>$sleep)
THEN <int64>$nonce
'''

await func(client1, query, name=name, sleep=0, nonce=0)

task1 = asyncio.create_task(
func(client1, query, name=name, sleep=5, nonce=1)
)
task2 = asyncio.create_task(
func(client2, query, name=name, sleep=5, nonce=2)
)

results = await asyncio.wait_for(asyncio.gather(
task1,
task2,
return_exceptions=True,
), 20)

excs = [e for e in results if isinstance(e, BaseException)]
if excs:
raise excs[0]
val = await client1.query_single('''
select (select test::Counter filter .name = <str>$name).value
''', name=name)

self.assertEqual(val, 3)

async def test_async_transaction_interface_errors(self):
with self.assertRaisesRegex(
AttributeError,
Expand Down

0 comments on commit e263f2b

Please sign in to comment.