diff --git a/releasenotes/notes/allow-retry-callables-ba921a2b57229540.yaml b/releasenotes/notes/allow-retry-callables-ba921a2b57229540.yaml new file mode 100644 index 0000000..e7be1cd --- /dev/null +++ b/releasenotes/notes/allow-retry-callables-ba921a2b57229540.yaml @@ -0,0 +1,7 @@ +--- +features: + - | + Allow for callables to be combined as retry values. This will only + work when used combined with their corresponding implementation + retry objects, e.g. only async functions will work when used together + with async retry strategies. diff --git a/tenacity/retry.py b/tenacity/retry.py index 9211631..69329e7 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -18,6 +18,13 @@ import re import typing +from . import _utils + +try: + import tornado +except ImportError: + tornado = None + if typing.TYPE_CHECKING: from tenacity import RetryCallState @@ -30,15 +37,29 @@ def __call__(self, retry_state: "RetryCallState") -> bool: pass def __and__(self, other: "retry_base") -> "retry_all": - return other.__rand__(self) + if isinstance(other, retry_base): + # Delegate to the other object to allow for specific + # implementations, such as asyncio + return other.__rand__(self) + return retry_all(other, self) def __rand__(self, other: "retry_base") -> "retry_all": + # This is automatically invoked for inheriting classes, + # so it helps to keep the abstraction and delegate specific + # implementations, such as asyncio return retry_all(other, self) def __or__(self, other: "retry_base") -> "retry_any": - return other.__ror__(self) + if isinstance(other, retry_base): + # Delegate to the other object to allow for specific + # implementations, such as asyncio + return other.__ror__(self) + return retry_any(other, self) def __ror__(self, other: "retry_base") -> "retry_any": + # This is automatically invoked for inheriting classes, + # so it helps to keep the abstraction and delegate specific + # implementations, such as asyncio return retry_any(other, self) @@ -269,7 +290,22 @@ def __init__(self, *retries: retry_base) -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: - return any(r(retry_state) for r in self.retries) + result = False + for r in self.retries: + if _utils.is_coroutine_callable(r) or ( + tornado + and hasattr(tornado.gen, "is_coroutine_function") + and tornado.gen.is_coroutine_function(r) + ): + raise TypeError( + "Cannot use async functions in a sync context. Make sure " + "you use the correct retrying object and the corresponding " + "async strategies" + ) + result = result or r(retry_state) + if result: + break + return result class retry_all(retry_base): @@ -279,4 +315,19 @@ def __init__(self, *retries: retry_base) -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: - return all(r(retry_state) for r in self.retries) + result = True + for r in self.retries: + if _utils.is_coroutine_callable(r) or ( + tornado + and hasattr(tornado.gen, "is_coroutine_function") + and tornado.gen.is_coroutine_function(r) + ): + raise TypeError( + "Cannot use async functions in a sync context. Make sure " + "you use the correct retrying object and the corresponding " + "async strategies" + ) + result = result and r(retry_state) + if not result: + break + return result diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 0b74476..2ed5679 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -29,7 +29,7 @@ import pytest import tenacity -from tenacity import AsyncRetrying, RetryError +from tenacity import AsyncRetrying, RetryCallState, RetryError from tenacity import asyncio as tasyncio from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt from tenacity.wait import wait_fixed @@ -309,6 +309,103 @@ def is_exc(e: BaseException) -> bool: self.assertEqual(4, result) + @asynctest + async def test_retry_with_async_result_or_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tasyncio.retry_if_result(lt_3) | should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_or_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tasyncio.retry_if_result(lt_3) | should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_or_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tenacity.retry_if_result(lt_3) | should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + try: + await test() + except TypeError as exc: + self.assertEqual( + str(exc), + "Cannot use async functions in a sync context. Make sure you use " + "the correct retrying object and the corresponding async strategies", + ) + else: + self.fail("This is an invalid retry combination that should have failed") + @asynctest async def test_retry_with_async_result_ror(self): async def test(): @@ -340,6 +437,103 @@ async def is_exc(e: BaseException) -> bool: self.assertEqual(4, result) + @asynctest + async def test_retry_with_async_result_ror_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_ror_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_ror_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tenacity.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + try: + await test() + except TypeError as exc: + self.assertEqual( + str(exc), + "Cannot use async functions in a sync context. Make sure you use " + "the correct retrying object and the corresponding async strategies", + ) + else: + self.fail("This is an invalid retry combination that should have failed") + @asynctest async def test_retry_with_async_result_and(self): async def test(): @@ -363,6 +557,94 @@ def gt_0(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result_and_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tasyncio.retry_if_result(lt_3) & should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_and_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tasyncio.retry_if_result(lt_3) & should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_and_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tenacity.retry_if_result(lt_3) & should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + try: + await test() + except TypeError as exc: + self.assertEqual( + str(exc), + "Cannot use async functions in a sync context. Make sure you use " + "the correct retrying object and the corresponding async strategies", + ) + else: + self.fail("This is an invalid retry combination that should have failed") + @asynctest async def test_retry_with_async_result_rand(self): async def test(): @@ -386,6 +668,94 @@ def gt_0(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result_rand_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_rand_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_rand_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tenacity.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + try: + await test() + except TypeError as exc: + self.assertEqual( + str(exc), + "Cannot use async functions in a sync context. Make sure you use " + "the correct retrying object and the corresponding async strategies", + ) + else: + self.fail("This is an invalid retry combination that should have failed") + @asynctest async def test_async_retying_iterator(self): thing = NoIOErrorAfterCount(5) diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index ecc0312..9f6068a 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -634,6 +634,58 @@ def r(fut): self.assertFalse(r(tenacity.Future.construct(1, 3, False))) self.assertFalse(r(tenacity.Future.construct(1, 1, True))) + async def test_retry_and_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tenacity.retry_if_result(lt_3) & should_retry # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + + async def test_retry_rand_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tenacity.retry_if_result(lt_3) # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + def test_retry_or(self): retry = tenacity.retry_if_result( lambda x: x == "foo" @@ -648,6 +700,64 @@ def r(fut): self.assertFalse(r(tenacity.Future.construct(1, 2.2, False))) self.assertFalse(r(tenacity.Future.construct(1, 42, True))) + def test_retry_or_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tenacity.retry_if_result(lt_3) | should_retry # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + + def test_retry_ror_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tenacity.retry_if_result(lt_3) # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + def _raise_try_again(self): self._attempts += 1 if self._attempts < 3: