Skip to content

Commit

Permalink
feat(framework) Add node-config arg to SuperNode (#3782)
Browse files Browse the repository at this point in the history
Co-authored-by: jafermarq <[email protected]>
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
3 people authored Jul 12, 2024
1 parent 01ca846 commit ea8f940
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
16 changes: 6 additions & 10 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class `flwr.client.Client` (default: None)
event(EventType.START_CLIENT_ENTER)
_start_client_internal(
server_address=server_address,
node_config={},
load_client_app_fn=None,
client_fn=client_fn,
client=client,
Expand All @@ -181,6 +182,7 @@ class `flwr.client.Client` (default: None)
def _start_client_internal(
*,
server_address: str,
node_config: Dict[str, str],
load_client_app_fn: Optional[Callable[[str, str], ClientApp]] = None,
client_fn: Optional[ClientFnExt] = None,
client: Optional[Client] = None,
Expand All @@ -193,7 +195,6 @@ def _start_client_internal(
] = None,
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
partition_id: Optional[int] = None,
flwr_dir: Optional[Path] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.
Expand All @@ -204,6 +205,8 @@ def _start_client_internal(
The IPv4 or IPv6 address of the server. If the Flower
server runs on the same machine on port 8080, then `server_address`
would be `"[::]:8080"`.
node_config: Dict[str, str]
The configuration of the node.
load_client_app_fn : Optional[Callable[[], ClientApp]] (default: None)
A function that can be used to load a `ClientApp` instance.
client_fn : Optional[ClientFnExt]
Expand Down Expand Up @@ -238,9 +241,6 @@ class `flwr.client.Client` (default: None)
The maximum duration before the client stops trying to
connect to the server in case of connection error.
If set to None, there is no limit to the total time.
partition_id: Optional[int] (default: None)
The data partition index associated with this node. Better suited for
prototyping purposes.
flwr_dir: Optional[Path] (default: None)
The fully resolved path containing installed Flower Apps.
"""
Expand Down Expand Up @@ -319,10 +319,6 @@ def _on_backoff(retry_state: RetryState) -> None:
on_backoff=_on_backoff,
)

# Empty dict (for now)
# This will be removed once users can pass node_config via flower-supernode
node_config: Dict[str, str] = {}

# NodeState gets initialized when the first connection is established
node_state: Optional[NodeState] = None

Expand Down Expand Up @@ -353,7 +349,7 @@ def _on_backoff(retry_state: RetryState) -> None:
node_state = NodeState(
node_id=-1,
node_config={},
partition_id=partition_id,
partition_id=None,
)
else:
# Call create_node fn to register node
Expand All @@ -365,7 +361,7 @@ def _on_backoff(retry_state: RetryState) -> None:
node_state = NodeState(
node_id=node_id,
node_config=node_config,
partition_id=partition_id,
partition_id=None,
)

app_state_tracker.register_signal_handler()
Expand Down
20 changes: 13 additions & 7 deletions src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@

from flwr.client.client_app import ClientApp, LoadClientAppError
from flwr.common import EventType, event
from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
from flwr.common.config import (
get_flwr_dir,
get_project_config,
get_project_dir,
parse_config_args,
)
from flwr.common.constant import (
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_RERE,
Expand Down Expand Up @@ -67,7 +72,7 @@ def run_supernode() -> None:
authentication_keys=authentication_keys,
max_retries=args.max_retries,
max_wait_time=args.max_wait_time,
partition_id=args.partition_id,
node_config=parse_config_args(args.node_config),
flwr_dir=get_flwr_dir(args.flwr_dir),
)

Expand All @@ -93,6 +98,7 @@ def run_client_app() -> None:

_start_client_internal(
server_address=args.superlink,
node_config=parse_config_args(args.node_config),
load_client_app_fn=load_fn,
transport=args.transport,
root_certificates=root_certificates,
Expand Down Expand Up @@ -389,11 +395,11 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
help="The SuperNode's public key (as a path str) to enable authentication.",
)
parser.add_argument(
"--partition-id",
type=int,
help="The data partition index associated with this SuperNode. Better suited "
"for prototyping purposes where a SuperNode might only load a fraction of an "
"artificially partitioned dataset (e.g. using `flwr-datasets`)",
"--node-config",
type=str,
help="A comma separated list of key/value pairs (separated by `=`) to "
"configure the SuperNode. "
"E.g. --node-config 'key1=\"value1\",partition-id=0,num-partitions=100'",
)


Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,16 @@ def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, st


def parse_config_args(
config_overrides: Optional[str],
config: Optional[str],
separator: str = ",",
) -> Dict[str, str]:
"""Parse separator separated list of key-value pairs separated by '='."""
overrides: Dict[str, str] = {}

if config_overrides is None:
if config is None:
return overrides

overrides_list = config_overrides.split(separator)
overrides_list = config.split(separator)
if (
len(overrides_list) == 1
and "=" not in overrides_list
Expand Down

0 comments on commit ea8f940

Please sign in to comment.