Skip to content

Commit

Permalink
Support custom wait function in RetryInvoker (#3183)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Apr 1, 2024
1 parent 531e0e3 commit 7ab8df0
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _load_client_app() -> ClientApp:
)

retry_invoker = RetryInvoker(
wait_factory=exponential,
wait_gen_factory=exponential,
recoverable_exceptions=connection_error_type,
max_tries=max_retries,
max_time=max_wait_time,
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def run_client() -> int:
server_address=f"[::]:{port}",
insecure=True,
retry_invoker=RetryInvoker(
wait_factory=exponential,
wait_gen_factory=exponential,
recoverable_exceptions=grpc.RpcError,
max_tries=1,
max_time=None,
Expand Down
37 changes: 24 additions & 13 deletions src/py/flwr/common/retry_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class RetryInvoker:
Parameters
----------
wait_factory: Callable[[], Generator[float, None, None]]
wait_gen_factory: Callable[[], Generator[float, None, None]]
A callable returning a generator that yields successive wait times in seconds.
If the generator is finite, the giveup event will be triggered when the
generator raises `StopIteration`.
Expand All @@ -129,19 +129,26 @@ class RetryInvoker:
data class object detailing the invocation.
on_giveup: Optional[Callable[[RetryState], None]] (default: None)
A callable to be executed in the event that `max_tries` or `max_time` is
exceeded, `should_giveup` returns True, or `wait_factory()` generator raises
exceeded, `should_giveup` returns True, or `wait_gen_factory()` generator raises
`StopIteration`. The parameter is a data class object detailing the
invocation.
jitter: Optional[Callable[[float], float]] (default: full_jitter)
A function of the value yielded by `wait_factory()` returning the actual time
to wait. This function helps distribute wait times stochastically to avoid
A function of the value yielded by `wait_gen_factory()` returning the actual
time to wait. This function helps distribute wait times stochastically to avoid
timing collisions across concurrent clients. Wait times are jittered by
default using the `full_jitter` function. To disable jittering, pass
`jitter=None`.
should_giveup: Optional[Callable[[Exception], bool]] (default: None)
A function accepting an exception instance, returning whether or not
to give up prematurely before other give-up conditions are evaluated.
If set to None, the strategy is to never give up prematurely.
wait_function: Optional[Callable[[float], None]] (default: None)
A function that defines how to wait between retry attempts. It accepts
one argument, the wait time in seconds, allowing the use of various waiting
mechanisms (e.g., asynchronous waits or event-based synchronization) suitable
for different execution environments. If set to `None`, the `wait_function`
defaults to `time.sleep`, which is ideal for synchronous operations. Custom
functions should manage execution flow to prevent blocking or interference.
Examples
--------
Expand All @@ -159,7 +166,7 @@ class RetryInvoker:
# pylint: disable-next=too-many-arguments
def __init__(
self,
wait_factory: Callable[[], Generator[float, None, None]],
wait_gen_factory: Callable[[], Generator[float, None, None]],
recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
max_tries: Optional[int],
max_time: Optional[float],
Expand All @@ -169,8 +176,9 @@ def __init__(
on_giveup: Optional[Callable[[RetryState], None]] = None,
jitter: Optional[Callable[[float], float]] = full_jitter,
should_giveup: Optional[Callable[[Exception], bool]] = None,
wait_function: Optional[Callable[[float], None]] = None,
) -> None:
self.wait_factory = wait_factory
self.wait_gen_factory = wait_gen_factory
self.recoverable_exceptions = recoverable_exceptions
self.max_tries = max_tries
self.max_time = max_time
Expand All @@ -179,6 +187,9 @@ def __init__(
self.on_giveup = on_giveup
self.jitter = jitter
self.should_giveup = should_giveup
if wait_function is None:
wait_function = time.sleep
self.wait_function = wait_function

# pylint: disable-next=too-many-locals
def invoke(
Expand Down Expand Up @@ -212,13 +223,13 @@ def invoke(
Raises
------
Exception
If the number of tries exceeds `max_tries`, if the total time
exceeds `max_time`, if `wait_factory()` generator raises `StopIteration`,
If the number of tries exceeds `max_tries`, if the total time exceeds
`max_time`, if `wait_gen_factory()` generator raises `StopIteration`,
or if the `should_giveup` returns True for a raised exception.
Notes
-----
The time between retries is determined by the provided `wait_factory()`
The time between retries is determined by the provided `wait_gen_factory()`
generator and can optionally be jittered using the `jitter` function.
The recoverable exceptions that trigger a retry, as well as conditions to
stop retries, are also determined by the class's initialization parameters.
Expand All @@ -231,13 +242,13 @@ def try_call_event_handler(
handler(cast(RetryState, ref_state[0]))

try_cnt = 0
wait_generator = self.wait_factory()
start = time.time()
wait_generator = self.wait_gen_factory()
start = time.monotonic()
ref_state: List[Optional[RetryState]] = [None]

while True:
try_cnt += 1
elapsed_time = time.time() - start
elapsed_time = time.monotonic() - start
state = RetryState(
target=target,
args=args,
Expand Down Expand Up @@ -282,7 +293,7 @@ def giveup_check(_exception: Exception) -> bool:
try_call_event_handler(self.on_backoff)

# Sleep
time.sleep(wait_time)
self.wait_function(state.actual_wait)
else:
# Trigger success event
try_call_event_handler(self.on_success)
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/common/retry_invoker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def failing_function() -> None:

@pytest.fixture(name="mock_time")
def fixture_mock_time() -> Generator[MagicMock, None, None]:
"""Mock time.time for controlled testing."""
with patch("time.time") as mock_time:
"""Mock time.monotonic for controlled testing."""
with patch("time.monotonic") as mock_time:
yield mock_time


Expand Down

0 comments on commit 7ab8df0

Please sign in to comment.