Skip to content

Commit

Permalink
feat(framework) Allow clients to exit gracefully (#3090)
Browse files Browse the repository at this point in the history
Co-authored-by: Heng Pan <[email protected]>
Co-authored-by: Javier <[email protected]>
  • Loading branch information
3 people authored Jun 10, 2024
1 parent 8d08e12 commit e182983
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
28 changes: 20 additions & 8 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@ def _load_client_app() -> ClientApp:
transport, server_address
)

run_tracker = _RunTracker()
_ = run_tracker
app_state_tracker = _AppStateTracker()

def _on_sucess(retry_state: RetryState) -> None:
app_state_tracker.is_connected = True
if retry_state.tries > 1:
log(
INFO,
Expand All @@ -278,6 +278,7 @@ def _on_sucess(retry_state: RetryState) -> None:
)

def _on_backoff(retry_state: RetryState) -> None:
app_state_tracker.is_connected = False
if retry_state.tries == 1:
log(WARN, "Connection attempt failed, retrying...")
else:
Expand Down Expand Up @@ -308,7 +309,7 @@ def _on_backoff(retry_state: RetryState) -> None:

node_state = NodeState()

while True:
while not app_state_tracker.interrupt:
sleep_duration: int = 0
with connection(
address,
Expand All @@ -325,8 +326,9 @@ def _on_backoff(retry_state: RetryState) -> None:
if create_node is not None:
create_node() # pylint: disable=not-callable

while True:
if True: # pylint: disable=using-constant-test
app_state_tracker.register_signal_handler()
while not app_state_tracker.interrupt:
try:
# Receive
message = receive()
if message is None:
Expand Down Expand Up @@ -397,7 +399,10 @@ def _on_backoff(retry_state: RetryState) -> None:
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
exc_entity = "SuperNode"

log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
if not app_state_tracker.interrupt:
log(
ERROR, "%s raised an exception", exc_entity, exc_info=ex
)

# Create error message
reply_message = message.create_error_reply(
Expand All @@ -414,13 +419,19 @@ def _on_backoff(retry_state: RetryState) -> None:
send(reply_message)
log(INFO, "Sent reply")

except StopIteration:
sleep_duration = 0
break

# Unregister node
if delete_node is not None:
if delete_node is not None and app_state_tracker.is_connected:
delete_node() # pylint: disable=not-callable

if sleep_duration == 0:
log(INFO, "Disconnect and shut down")
del app_state_tracker
break

# Sleep and reconnect afterwards
log(
INFO,
Expand Down Expand Up @@ -592,8 +603,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[


@dataclass
class _RunTracker:
class _AppStateTracker:
interrupt: bool = False
is_connected: bool = False

def register_signal_handler(self) -> None:
"""Register handlers for exit signals."""
Expand Down
2 changes: 0 additions & 2 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,6 @@ def delete_node() -> None:

# Stop the ping-loop thread
ping_stop_event.set()
if ping_thread is not None:
ping_thread.join()

# Call FleetAPI
delete_node_request = DeleteNodeRequest(node=node)
Expand Down

0 comments on commit e182983

Please sign in to comment.