From ea8f940a465cd43de5293ffc8b719431f46679ae Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Fri, 12 Jul 2024 18:26:06 +0200 Subject: [PATCH] feat(framework) Add `node-config` arg to SuperNode (#3782) Co-authored-by: jafermarq Co-authored-by: Daniel J. Beutel --- src/py/flwr/client/app.py | 16 ++++++---------- src/py/flwr/client/supernode/app.py | 20 +++++++++++++------- src/py/flwr/common/config.py | 6 +++--- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index fa17ba9a8481..ffcc95489d62 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -160,6 +160,7 @@ class `flwr.client.Client` (default: None) event(EventType.START_CLIENT_ENTER) _start_client_internal( server_address=server_address, + node_config={}, load_client_app_fn=None, client_fn=client_fn, client=client, @@ -181,6 +182,7 @@ class `flwr.client.Client` (default: None) def _start_client_internal( *, server_address: str, + node_config: Dict[str, str], load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None, client_fn: Optional[ClientFnExt] = None, client: Optional[Client] = None, @@ -193,7 +195,6 @@ def _start_client_internal( ] = None, max_retries: Optional[int] = None, max_wait_time: Optional[float] = None, - partition_id: Optional[int] = None, flwr_dir: Optional[Path] = None, ) -> None: """Start a Flower client node which connects to a Flower server. @@ -204,6 +205,8 @@ def _start_client_internal( The IPv4 or IPv6 address of the server. If the Flower server runs on the same machine on port 8080, then `server_address` would be `"[::]:8080"`. + node_config: Dict[str, str] + The configuration of the node. load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None) A function that can be used to load a `ClientApp` instance. client_fn : Optional[ClientFnExt] @@ -238,9 +241,6 @@ class `flwr.client.Client` (default: None) The maximum duration before the client stops trying to connect to the server in case of connection error. If set to None, there is no limit to the total time. - partition_id: Optional[int] (default: None) - The data partition index associated with this node. Better suited for - prototyping purposes. flwr_dir: Optional[Path] (default: None) The fully resolved path containing installed Flower Apps. """ @@ -319,10 +319,6 @@ def _on_backoff(retry_state: RetryState) -> None: on_backoff=_on_backoff, ) - # Empty dict (for now) - # This will be removed once users can pass node_config via flower-supernode - node_config: Dict[str, str] = {} - # NodeState gets initialized when the first connection is established node_state: Optional[NodeState] = None @@ -353,7 +349,7 @@ def _on_backoff(retry_state: RetryState) -> None: node_state = NodeState( node_id=-1, node_config={}, - partition_id=partition_id, + partition_id=None, ) else: # Call create_node fn to register node @@ -365,7 +361,7 @@ def _on_backoff(retry_state: RetryState) -> None: node_state = NodeState( node_id=node_id, node_config=node_config, - partition_id=partition_id, + partition_id=None, ) app_state_tracker.register_signal_handler() diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 355a2a13a0e5..d61b986bc7af 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -29,7 +29,12 @@ from flwr.client.client_app import ClientApp, LoadClientAppError from flwr.common import EventType, event -from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir +from flwr.common.config import ( + get_flwr_dir, + get_project_config, + get_project_dir, + parse_config_args, +) from flwr.common.constant import ( TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, @@ -67,7 +72,7 @@ def run_supernode() -> None: authentication_keys=authentication_keys, max_retries=args.max_retries, max_wait_time=args.max_wait_time, - partition_id=args.partition_id, + node_config=parse_config_args(args.node_config), flwr_dir=get_flwr_dir(args.flwr_dir), ) @@ -93,6 +98,7 @@ def run_client_app() -> None: _start_client_internal( server_address=args.superlink, + node_config=parse_config_args(args.node_config), load_client_app_fn=load_fn, transport=args.transport, root_certificates=root_certificates, @@ -389,11 +395,11 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None: help="The SuperNode's public key (as a path str) to enable authentication.", ) parser.add_argument( - "--partition-id", - type=int, - help="The data partition index associated with this SuperNode. Better suited " - "for prototyping purposes where a SuperNode might only load a fraction of an " - "artificially partitioned dataset (e.g. using `flwr-datasets`)", + "--node-config", + type=str, + help="A comma separated list of key/value pairs (separated by `=`) to " + "configure the SuperNode. " + "E.g. --node-config 'key1=\"value1\",partition-id=0,num-partitions=100'", ) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 9770bdb4af2b..54d74353e4ed 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -121,16 +121,16 @@ def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, st def parse_config_args( - config_overrides: Optional[str], + config: Optional[str], separator: str = ",", ) -> Dict[str, str]: """Parse separator separated list of key-value pairs separated by '='.""" overrides: Dict[str, str] = {} - if config_overrides is None: + if config is None: return overrides - overrides_list = config_overrides.split(separator) + overrides_list = config.split(separator) if ( len(overrides_list) == 1 and "=" not in overrides_list