Skip to content

Commit

Permalink
Merge branch 'main' into dead-node-err-reply
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Apr 2, 2024
2 parents e7ea7cb + f95d641 commit 71f74d9
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 51 deletions.
2 changes: 1 addition & 1 deletion datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ exclude = [
[tool.poetry.dependencies]
python = "^3.8"
numpy = "^1.21.0"
datasets = "^2.14.3"
datasets = "^2.14.6"
pillow = { version = ">=6.2.1", optional = true }
soundfile = { version = ">=0.12.1", optional = true }
librosa = { version = ">=0.10.0.post2", optional = true }
Expand Down
33 changes: 24 additions & 9 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_REST,
TRANSPORT_TYPES,
ErrorCode,
)
from flwr.common.exit_handlers import register_exit_handlers
from flwr.common.logger import log, warn_deprecated_feature
Expand Down Expand Up @@ -483,35 +484,49 @@ def _load_client_app() -> ClientApp:
# Create an error reply message that will never be used to prevent
# the used-before-assignment linting error
reply_message = message.create_error_reply(
error=Error(code=0, reason="Unknown")
error=Error(code=ErrorCode.UNKNOWN, reason="Unknown")
)

# Handle app loading and task message
try:
# Load ClientApp instance
client_app: ClientApp = load_client_app_fn()

# Execute ClientApp
reply_message = client_app(message=message, context=context)
# Update node state
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)
except Exception as ex: # pylint: disable=broad-exception-caught
log(ERROR, "ClientApp raised an exception", exc_info=ex)

# Legacy grpc-bidi
if transport in ["grpc-bidi", None]:
log(ERROR, "Client raised an exception.", exc_info=ex)
# Raise exception, crash process
raise ex

# Don't update/change NodeState

# Create error message
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
# Reason example: "<class 'ZeroDivisionError'>:<'division by zero'>"
reason = str(type(ex)) + ":<'" + str(ex) + "'>"
exc_entity = "ClientApp"
if isinstance(ex, LoadClientAppError):
reason = (
"An exception was raised when attempting to load "
"`ClientApp`"
)
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
exc_entity = "SuperNode"

log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)

# Create error message
reply_message = message.create_error_reply(
error=Error(code=0, reason=reason)
error=Error(code=e_code, reason=reason)
)
else:
# No exception, update node state
node_state.update_context(
run_id=message.metadata.run_id,
context=context,
)

# Send
Expand Down
9 changes: 9 additions & 0 deletions src/py/flwr/client/client_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@
from .typing import ClientAppCallable


class ClientAppException(Exception):
    """Exception raised when an exception is raised while executing a ClientApp.

    The text handed to the constructor is wrapped so that the final message
    names the concrete exception class; a subclass therefore reports its own
    class name rather than ``ClientAppException``.

    Parameters
    ----------
    message : str
        Description of the underlying failure. It is appended verbatim after
        the ``"Message: "`` prefix.
    """

    def __init__(self, message: str) -> None:
        ex_name = type(self).__name__
        # Keep the formatted text on the instance so that it can be read
        # directly, in addition to being available through ``str(exc)``.
        self.message = f"\nException {ex_name} occurred. Message: " + message
        super().__init__(self.message)


class ClientApp:
"""Flower ClientApp.
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
PING_CALL_TIMEOUT = 5
PING_BASE_MULTIPLIER = 0.8
PING_RANDOM_RANGE = (-0.1, 0.1)
PING_MAX_INTERVAL = 1e300


class MessageType:
Expand Down Expand Up @@ -80,7 +81,8 @@ class ErrorCode:
"""Error codes for Message's Error."""

UNKNOWN = 0
CLIENT_APP_RAISED_EXCEPTION = 1
LOAD_CLIENT_APP_EXCEPTION = 1
CLIENT_APP_RAISED_EXCEPTION = 2

def __new__(cls) -> ErrorCode:
"""Prevent instantiation."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def create_node(
) -> CreateNodeResponse:
"""."""
# Create node
node_id = state.create_node()
node_id = state.create_node(ping_interval=request.ping_interval)
return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False))


Expand Down
18 changes: 14 additions & 4 deletions src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from logging import DEBUG, ERROR, INFO, WARN
from typing import Callable, Dict, List, Optional

from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
from flwr.client.node_state import NodeState
from flwr.common.constant import PING_MAX_INTERVAL, ErrorCode
from flwr.common.logger import log
from flwr.common.message import Error
from flwr.common.object_ref import load_app
Expand All @@ -43,7 +44,7 @@ def _register_nodes(
nodes_mapping: NodeToPartitionMapping = {}
state = state_factory.state()
for i in range(num_nodes):
node_id = state.create_node()
node_id = state.create_node(ping_interval=PING_MAX_INTERVAL)
nodes_mapping[node_id] = i
log(INFO, "Registered %i nodes", len(nodes_mapping))
return nodes_mapping
Expand Down Expand Up @@ -93,9 +94,18 @@ async def worker(
except Exception as ex: # pylint: disable=broad-exception-caught
log(ERROR, ex)
log(ERROR, traceback.format_exc())

if isinstance(ex, ClientAppException):
e_code = ErrorCode.CLIENT_APP_RAISED_EXCEPTION
elif isinstance(ex, LoadClientAppError):
e_code = ErrorCode.LOAD_CLIENT_APP_EXCEPTION
else:
e_code = ErrorCode.UNKNOWN

reason = str(type(ex)) + ":<'" + str(ex) + "'>"
error = Error(code=0, reason=reason)
out_mssg = message.create_error_reply(error=error)
out_mssg = message.create_error_reply(
error=Error(code=e_code, reason=reason)
)

finally:
if out_mssg:
Expand Down
6 changes: 2 additions & 4 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,14 @@ def num_task_res(self) -> int:
"""
return len(self.task_res_store)

def create_node(self) -> int:
def create_node(self, ping_interval: float) -> int:
"""Create, store in state, and return `node_id`."""
# Sample a random int64 as node_id
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)

with self.lock:
if node_id not in self.node_ids:
# Default ping interval is 30s
# TODO: change 1e9 to 30s # pylint: disable=W0511
self.node_ids[node_id] = (time.time() + 1e9, 1e9)
self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
return node_id
log(ERROR, "Unexpected node registration failure.")
return 0
Expand Down
6 changes: 2 additions & 4 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None:

return None

def create_node(self) -> int:
def create_node(self, ping_interval: float) -> int:
"""Create, store in state, and return `node_id`."""
# Sample a random int64 as node_id
node_id: int = int.from_bytes(os.urandom(8), "little", signed=True)
Expand All @@ -528,9 +528,7 @@ def create_node(self) -> int:
)

try:
# Default ping interval is 30s
# TODO: change 1e9 to 30s # pylint: disable=W0511
self.query(query, (node_id, time.time() + 1e9, 1e9))
self.query(query, (node_id, time.time() + ping_interval, ping_interval))
except sqlite3.IntegrityError:
log(ERROR, "Unexpected node registration failure.")
return 0
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None:
"""Delete all delivered TaskIns/TaskRes pairs."""

@abc.abstractmethod
def create_node(self) -> int:
def create_node(self, ping_interval: float) -> int:
"""Create, store in state, and return `node_id`."""

@abc.abstractmethod
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def test_create_node_and_get_nodes(self) -> None:

# Execute
for _ in range(10):
node_ids.append(state.create_node())
node_ids.append(state.create_node(ping_interval=10))
retrieved_node_ids = state.get_nodes(run_id)

# Assert
Expand All @@ -331,7 +331,7 @@ def test_delete_node(self) -> None:
# Prepare
state: State = self.state_factory()
run_id = state.create_run()
node_id = state.create_node()
node_id = state.create_node(ping_interval=10)

# Execute
state.delete_node(node_id)
Expand All @@ -346,7 +346,7 @@ def test_get_nodes_invalid_run_id(self) -> None:
state: State = self.state_factory()
state.create_run()
invalid_run_id = 61016
state.create_node()
state.create_node(ping_interval=10)

# Execute
retrieved_node_ids = state.get_nodes(invalid_run_id)
Expand Down Expand Up @@ -399,7 +399,7 @@ def test_acknowledge_ping(self) -> None:
# Prepare
state: State = self.state_factory()
run_id = state.create_run()
node_ids = [state.create_node() for _ in range(100)]
node_ids = [state.create_node(ping_interval=10) for _ in range(100)]
for node_id in node_ids[:70]:
state.acknowledge_ping(node_id, ping_interval=30)
for node_id in node_ids[70:]:
Expand Down
24 changes: 2 additions & 22 deletions src/py/flwr/simulation/ray_transport/ray_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import asyncio
import threading
import traceback
from abc import ABC
from logging import DEBUG, ERROR, WARNING
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
Expand All @@ -25,22 +24,13 @@
from ray import ObjectRef
from ray.util.actor_pool import ActorPool

from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
from flwr.common import Context, Message
from flwr.common.logger import log

ClientAppFn = Callable[[], ClientApp]


class ClientException(Exception):
    """Raised when client side logic crashes with an exception.

    The final message is a newline, a seven-character ``>`` marker, the fixed
    text ``"A ClientException occurred."`` and then the supplied message, with
    no separator inserted between the fixed text and the message.

    Parameters
    ----------
    message : str
        Details of the failure; appended verbatim to the fixed prefix.
    """

    def __init__(self, message: str) -> None:
        # Marker that makes the start of the error easy to spot in log output.
        div = ">" * 7
        self.message = "\n" + div + "A ClientException occurred." + message
        super().__init__(self.message)


class VirtualClientEngineActor(ABC):
"""Abstract base class for VirtualClientEngine Actors."""

Expand Down Expand Up @@ -71,17 +61,7 @@ def run(
raise load_ex

except Exception as ex:
client_trace = traceback.format_exc()
mssg = (
"\n\tSomething went wrong when running your client run."
"\n\tClient "
+ cid
+ " crashed when the "
+ self.__class__.__name__
+ " was running its run."
"\n\tException triggered on the client side: " + client_trace,
)
raise ClientException(str(mssg)) from ex
raise ClientAppException(str(ex)) from ex

return cid, out_message, context

Expand Down

0 comments on commit 71f74d9

Please sign in to comment.