From da49c357b54fba9d0f288c6a9777c4867c063129 Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 17 Jul 2024 15:46:19 +0200 Subject: [PATCH 01/10] refactor(framework) Update launch of simulation from executor plugin (#3829) --- src/py/flwr/superexec/simulation.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py index e71e39e38e36..63c6b3270917 100644 --- a/src/py/flwr/superexec/simulation.py +++ b/src/py/flwr/superexec/simulation.py @@ -111,11 +111,6 @@ def start_run( "Config extracted from FAB's pyproject.toml is not valid" ) - # Get ClientApp and SeverApp components - flower_components = config["tool"]["flwr"]["app"]["components"] - clientapp = flower_components["clientapp"] - serverapp = flower_components["serverapp"] - # In Simulation there is no SuperLink, still we create a run_id run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) log(INFO, "Created run %s", str(run_id)) @@ -123,16 +118,17 @@ def start_run( # Prepare commnand command = [ "flower-simulation", - "--client-app", - f"{clientapp}", - "--server-app", - f"{serverapp}", + "--app", + f"{str(fab_path)}", "--num-supernodes", f"{self.num_supernodes}", "--run-id", str(run_id), ] + if override_config: + command.extend(["--run-config", f"{override_config}"]) + # Start Simulation proc = subprocess.Popen( # pylint: disable=consider-using-with command, From e737515c051092c445ace07735307c9c46c8fffd Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Wed, 17 Jul 2024 16:48:59 +0200 Subject: [PATCH 02/10] fix(framework:skip) Pass list of strings to parsing func (#3836) --- src/py/flwr/client/supernode/app.py | 4 ++-- src/py/flwr/common/config.py | 25 ++++++++++++------------ src/py/flwr/simulation/run_simulation.py | 2 +- src/py/flwr/superexec/app.py | 2 +- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/py/flwr/client/supernode/app.py b/src/py/flwr/client/supernode/app.py index 0ef0a145b1b6..f3fb0e97805a 100644 --- a/src/py/flwr/client/supernode/app.py +++ b/src/py/flwr/client/supernode/app.py @@ -77,7 +77,7 @@ def run_supernode() -> None: authentication_keys=authentication_keys, max_retries=args.max_retries, max_wait_time=args.max_wait_time, - node_config=parse_config_args(args.node_config), + node_config=parse_config_args([args.node_config]), flwr_path=get_flwr_dir(args.flwr_dir), ) @@ -107,7 +107,7 @@ def run_client_app() -> None: _start_client_internal( server_address=args.superlink, - node_config=parse_config_args(args.node_config), + node_config=parse_config_args([args.node_config]), load_client_app_fn=load_fn, transport=args.transport, root_certificates=root_certificates, diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 3bd7a103f4c2..789433a287e7 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -140,17 +140,18 @@ def parse_config_args( return overrides for config_line in config: - overrides_list = config_line.split(separator) - if ( - len(overrides_list) == 1 - and "=" not in overrides_list - and overrides_list[0].endswith(".toml") - ): - with Path(overrides_list[0]).open("rb") as config_file: - overrides = flatten_dict(tomli.load(config_file)) - else: - for kv_pair in overrides_list: - key, value = kv_pair.split("=") - overrides[key] = value + if config_line: + overrides_list = config_line.split(separator) + if ( + len(overrides_list) == 1 + and "=" not in overrides_list + and overrides_list[0].endswith(".toml") + ): + with Path(overrides_list[0]).open("rb") as config_file: + overrides = flatten_dict(tomli.load(config_file)) + else: + for kv_pair in overrides_list: + key, value = kv_pair.split("=") + overrides[key] = value return overrides diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 8a3c7739e595..e5d82207e08b 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -132,7 +132,7 @@ def run_simulation_from_cli() -> None: client_app_attr = app_components["clientapp"] server_app_attr = app_components["serverapp"] - override_config = parse_config_args(args.run_config) + override_config = parse_config_args([args.run_config]) fused_config = get_fused_config_from_dir(app_path, override_config) app_dir = args.app is_app = True diff --git a/src/py/flwr/superexec/app.py b/src/py/flwr/superexec/app.py index b51c3e6821dc..9f1753ce041b 100644 --- a/src/py/flwr/superexec/app.py +++ b/src/py/flwr/superexec/app.py @@ -56,7 +56,7 @@ def run_superexec() -> None: address=address, executor=_load_executor(args), certificates=certificates, - config=parse_config_args(args.executor_config), + config=parse_config_args([args.executor_config]), ) grpc_servers = [superexec_server] From d10d6b8cbff4d496eaa6707b1a9cc60ad310525c Mon Sep 17 00:00:00 2001 From: Robert Steiner Date: Wed, 17 Jul 2024 16:57:49 +0200 Subject: [PATCH 03/10] ci(*:skip) Upgrade pip and setuptools (#3835) Signed-off-by: Robert Steiner --- .devcontainer/Dockerfile | 4 ++-- .github/actions/bootstrap/action.yml | 4 ++-- dev/bootstrap.sh | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index ce4f8a1a5b8d..7ab0812e7cff 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -29,8 +29,8 @@ RUN apt-get install -y curl wget gnupg python3 python-is-python3 python3-pip git build-essential tmux vim RUN python -m pip install \ - pip==24.0.0 \ - setuptools==69.5.1 \ + pip==24.1.2 \ + setuptools==70.3.0 \ poetry==1.7.1 USER $USERNAME diff --git a/.github/actions/bootstrap/action.yml b/.github/actions/bootstrap/action.yml index bee90beffa7d..4cde8dddfa3f 100644 --- a/.github/actions/bootstrap/action.yml +++ b/.github/actions/bootstrap/action.yml @@ -6,10 +6,10 @@ inputs: default: 3.8 pip-version: description: "Version of pip to be installed using pip" - default: 24.0.0 + default: 24.1.2 setuptools-version: description: "Version of setuptools to be installed using pip" - default: 69.5.1 + default: 70.3.0 poetry-version: description: "Version of poetry to be installed using pip" default: 1.7.1 diff --git a/dev/bootstrap.sh b/dev/bootstrap.sh index 154fe0f1cbaf..bfcdc8a4369e 100755 --- a/dev/bootstrap.sh +++ b/dev/bootstrap.sh @@ -9,8 +9,8 @@ cd "$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"/../ ./dev/rm-caches.sh # Upgrade/install spcific versions of `pip`, `setuptools`, and `poetry` -python -m pip install -U pip==24.0.0 -python -m pip install -U setuptools==69.5.1 +python -m pip install -U pip==24.1.2 +python -m pip install -U setuptools==70.3.0 python -m pip install -U poetry==1.7.1 # Use `poetry` to install project dependencies From e0cb149e9a3b0997d093be5bb656ac6e1ea26260 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Wed, 17 Jul 2024 17:30:56 +0200 Subject: [PATCH 04/10] feat(framework:skip) Add `app` suffix to `client` and `server` (#3828) --- src/py/flwr/cli/new/new.py | 7 ++++--- src/py/flwr/cli/new/new_test.py | 4 ++-- .../flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl | 4 ++-- .../cli/new/templates/app/code/flwr_tune/server.py.tpl | 2 +- src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl | 4 ++-- src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl | 4 ++-- src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl | 4 ++-- src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl | 4 ++-- .../flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl | 4 ++-- .../flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl | 4 ++-- .../cli/new/templates/app/pyproject.tensorflow.toml.tpl | 4 ++-- 11 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index a0a2dc98556d..4bde009742f8 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -136,6 +136,7 @@ def new( framework_str = framework_str.lower() + llm_challenge_str = None if framework_str == "flowertune": llm_challenge_value = prompt_options( "Please select LLM challenge by typing in the number", @@ -171,7 +172,7 @@ def new( } # List of files to render - if framework_str == "flowertune": + if llm_challenge_str: files = { ".gitignore": {"template": "app/.gitignore.tpl"}, "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"}, @@ -228,10 +229,10 @@ def new( "README.md": {"template": "app/README.md.tpl"}, "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"}, f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"}, - f"{import_name}/server.py": { + f"{import_name}/server_app.py": { "template": f"app/code/server.{framework_str}.py.tpl" }, - f"{import_name}/client.py": { + f"{import_name}/client_app.py": { "template": f"app/code/client.{framework_str}.py.tpl" }, } diff --git a/src/py/flwr/cli/new/new_test.py b/src/py/flwr/cli/new/new_test.py index 33ad745efa93..7f22bd5f9825 100644 --- a/src/py/flwr/cli/new/new_test.py +++ b/src/py/flwr/cli/new/new_test.py @@ -86,8 +86,8 @@ def test_new_correct_name(tmp_path: str) -> None: } expected_files_module = { "__init__.py", - "server.py", - "client.py", + "server_app.py", + "client_app.py", "task.py", } diff --git a/src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl b/src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl index ecb87bd71e3f..a0f781df04a1 100644 --- a/src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl @@ -12,10 +12,10 @@ from flwr.client import ClientApp from flwr.common import ndarrays_to_parameters from flwr.server import ServerApp, ServerConfig -from $import_name.client import gen_client_fn, get_parameters +from $import_name.client_app import gen_client_fn, get_parameters from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting from $import_name.models import get_model -from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config +from $import_name.server_app import fit_weighted_average, get_evaluate_fn, get_on_fit_config # Avoid warnings warnings.filterwarnings("ignore", category=UserWarning) diff --git a/src/py/flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl b/src/py/flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl index 19223148bca5..5dd4d881f2f1 100644 --- a/src/py/flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl @@ -1,6 +1,6 @@ """$project_name: A Flower / FlowerTune app.""" -from $import_name.client import set_parameters +from $import_name.client_app import set_parameters from $import_name.models import get_model diff --git a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl index b39facbec5a0..92c954e754cf 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl @@ -24,8 +24,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl index 405decf38f16..e899f48f4c5c 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl @@ -21,8 +21,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl index a2b743800595..6004c076cf87 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl @@ -21,8 +21,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl index ad074b90d24a..543936ed4a89 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl @@ -19,8 +19,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl index ecd1497500ab..8a92cf0eca9a 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl @@ -21,8 +21,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl index 4bc407c34262..5c1ffa09aed2 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl @@ -20,8 +20,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl index 9dab874e50ff..de1a445e33f9 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl @@ -20,8 +20,8 @@ packages = ["."] publisher = "$username" [tool.flwr.app.components] -serverapp = "$import_name.server:app" -clientapp = "$import_name.client:app" +serverapp = "$import_name.server_app:app" +clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" From c6547efdf71ee2ef258552923b267fc8e61f253b Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 17 Jul 2024 18:35:46 +0200 Subject: [PATCH 05/10] fix(framework) Enable overriding of run configs with simulation plugin (#3839) --- src/py/flwr/superexec/simulation.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py index 63c6b3270917..732d749c5c53 100644 --- a/src/py/flwr/superexec/simulation.py +++ b/src/py/flwr/superexec/simulation.py @@ -82,11 +82,6 @@ def start_run( ) -> Optional[RunTracker]: """Start run using the Flower Simulation Engine.""" try: - if override_config: - raise ValueError( - "Overriding the run config is not yet supported with the " - "simulation executor.", - ) # Install FAB to flwr dir fab_path = install_from_fab(fab_file, None, True) From a20e4c9a13681587aba2760bfdad23e5343af1f7 Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 17 Jul 2024 19:06:39 +0200 Subject: [PATCH 06/10] feat(framework) Update subprocess launch mechanism for simulation plugin (#3826) Signed-off-by: Danny Heinrich Co-authored-by: Danny Heinrich --- src/py/flwr/superexec/simulation.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py index 732d749c5c53..58cc194a16d4 100644 --- a/src/py/flwr/superexec/simulation.py +++ b/src/py/flwr/superexec/simulation.py @@ -125,10 +125,9 @@ def start_run( command.extend(["--run-config", f"{override_config}"]) # Start Simulation - proc = subprocess.Popen( # pylint: disable=consider-using-with + proc = subprocess.run( # pylint: disable=consider-using-with command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + check=True, text=True, ) @@ -136,7 +135,7 @@ def start_run( return RunTracker( run_id=run_id, - proc=proc, + proc=proc, # type:ignore ) # pylint: disable-next=broad-except From 89436dfe74bb6cfd437f730be0f5462bf63a9022 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Thu, 18 Jul 2024 12:55:09 +0200 Subject: [PATCH 07/10] feat(framework) Add run_config to templates (#3845) --- .../new/templates/app/code/client.hf.py.tpl | 10 +++++++--- .../new/templates/app/code/client.mlx.py.tpl | 18 ++++++++++-------- .../templates/app/code/client.pytorch.py.tpl | 8 +++++++- .../templates/app/code/client.sklearn.py.tpl | 3 ++- .../app/code/client.tensorflow.py.tpl | 11 +++++++++-- .../new/templates/app/pyproject.hf.toml.tpl | 1 + .../new/templates/app/pyproject.mlx.toml.tpl | 5 +++++ .../templates/app/pyproject.pytorch.toml.tpl | 1 + .../app/pyproject.tensorflow.toml.tpl | 3 +++ 9 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl index 56bac8543c50..13b071013076 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl @@ -30,7 +30,11 @@ class FlowerClient(NumPyClient): def fit(self, parameters, config): self.set_parameters(parameters) - train(self.net, self.trainloader, epochs=1) + train( + self.net, + self.trainloader, + epochs=int(self.context.run_config["local-epochs"]), + ) return self.get_parameters(config={}), len(self.trainloader), {} def evaluate(self, parameters, config): @@ -45,8 +49,8 @@ def client_fn(context: Context): CHECKPOINT, num_labels=2 ).to(DEVICE) - partition_id = int(context.node_config['partition-id']) - num_partitions = int(context.node_config['num-partitions]) + partition_id = int(context.node_config["partition-id"]) + num_partitions = int(context.node_config["num-partitions"]) trainloader, valloader = load_data(partition_id, num_partitions) # Return Client instance diff --git a/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl index 37207c940d83..fe1f4041a076 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl @@ -20,17 +20,19 @@ from $import_name.task import ( # Define Flower Client and client_fn class FlowerClient(NumPyClient): def __init__(self, data): - num_layers = 2 - hidden_dim = 32 + num_layers = int(self.context.run_config["num-layers"]) + hidden_dim = int(self.context.run_config["hidden-dim"]) num_classes = 10 - batch_size = 256 - num_epochs = 1 - learning_rate = 1e-1 + batch_size = int(self.context.run_config["batch-size"]) + learning_rate = float(self.context.run_config["lr"]) + num_epochs = int(self.context.run_config["local-epochs"]) self.train_images, self.train_labels, self.test_images, self.test_labels = data - self.model = MLP(num_layers, self.train_images.shape[-1], hidden_dim, num_classes) - self.optimizer = optim.SGD(learning_rate=learning_rate) - self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn) + self.model = MLP( + num_layers, self.train_images.shape[-1], hidden_dim, num_classes + ) + self.optimizer = optim.SGD(learning_rate=learning_rate) + self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn) self.num_epochs = num_epochs self.batch_size = batch_size diff --git a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl index addc71023a09..3635843ba0be 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl @@ -23,7 +23,13 @@ class FlowerClient(NumPyClient): def fit(self, parameters, config): set_weights(self.net, parameters) - results = train(self.net, self.trainloader, self.valloader, 1, DEVICE) + results = train( + self.net, + self.trainloader, + self.valloader, + int(self.context.run_config["local-epochs"]), + DEVICE, + ) return get_weights(self.net), len(self.trainloader.dataset), results def evaluate(self, parameters, config): diff --git a/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl index a1eefa034e7b..9642ae490155 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl @@ -67,10 +67,11 @@ class FlowerClient(NumPyClient): return loss, len(self.X_test), {"accuracy": accuracy} -fds = FederatedDataset(dataset="mnist", partitioners={"train": 2}) def client_fn(context: Context): partition_id = int(context.node_config["partition-id"]) + num_partitions = int(context.node_config["num-partitions"]) + fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions}) dataset = fds.load_partition(partition_id, "train").with_format("numpy") X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"] diff --git a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl index 0fe1c405a110..5702f5b9c0d0 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl @@ -20,7 +20,13 @@ class FlowerClient(NumPyClient): def fit(self, parameters, config): self.model.set_weights(parameters) - self.model.fit(self.x_train, self.y_train, epochs=1, batch_size=32, verbose=0) + self.model.fit( + self.x_train, + self.y_train, + epochs=int(self.context.run_config["local-epochs"]), + batch_size=int(self.context.run_config["batch-size"]), + verbose=bool(self.context.run_config.get("verbose")), + ) return self.model.get_weights(), len(self.x_train), {} def evaluate(self, parameters, config): @@ -34,7 +40,8 @@ def client_fn(context: Context): net = load_model() partition_id = int(context.node_config["partition-id"]) - x_train, y_train, x_test, y_test = load_data(partition_id, 2) + num_partitions = int(context.node_config["num-partitions"]) + x_train, y_train, x_test, y_test = load_data(partition_id, num_partitions) # Return Client instance return FlowerClient(net, x_train, y_train, x_test, y_test).to_client() diff --git a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl index 92c954e754cf..7b7ea9d01200 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl @@ -29,6 +29,7 @@ clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" +local-epochs = "1" [tool.flwr.federations] default = "localhost" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl index 6004c076cf87..fde693e6c3de 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl @@ -26,6 +26,11 @@ clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" +local-epochs = "1" +num-layers = "2" +hidden-dim = "32" +batch-size = "256" +lr = "0.1" [tool.flwr.federations] default = "localhost" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl index 8a92cf0eca9a..d7991c05daf7 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl @@ -26,6 +26,7 @@ clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" +local-epochs = "1" [tool.flwr.federations] default = "localhost" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl index de1a445e33f9..400689cc541a 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl @@ -25,6 +25,9 @@ clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = "3" +local-epochs = "1" +batch-size = "32" +verbose = "" # Empty string means False [tool.flwr.federations] default = "localhost" From 445a18e14ac2a4dcb3e544f66b2b72d0f09a699a Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Thu, 18 Jul 2024 14:35:05 +0200 Subject: [PATCH 08/10] fix(framework:skip) Send correct override_dict to simulation (#3849) --- src/py/flwr/cli/run/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index 5dedb701fea9..a8dd5a59a627 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -199,7 +199,7 @@ def _run_without_superexec( ] if config_overrides: - command.extend(["--run-config", f"{config_overrides}"]) + command.extend(["--run-config", f"{','.join(config_overrides)}"]) # Run the simulation subprocess.run( From 247012ed77c08071035e65dad3b394e5b6904a4e Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Thu, 18 Jul 2024 16:39:18 +0200 Subject: [PATCH 09/10] feat(datasets) Enable passing kwargs to load_dataset in FederatedDataset (#3827) --- datasets/flwr_datasets/federated_dataset.py | 20 +++++++++-- .../flwr_datasets/federated_dataset_test.py | 34 +++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index accfb783f368..23d98e9c7bb0 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -15,7 +15,7 @@ """FederatedDataset.""" -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import datasets from datasets import Dataset, DatasetDict @@ -65,6 +65,12 @@ class FederatedDataset: Seed used for dataset shuffling. It has no effect if `shuffle` is False. The seed cannot be set in the later stages. If `None`, then fresh, unpredictable entropy will be pulled from the OS. Defaults to 42. + load_dataset_kwargs : Any + Additional keyword arguments passed to `datasets.load_dataset` function. + Currently used paramters used are dataset => path (in load_dataset), + subset => name (in load_dataset). You can pass e.g., `num_proc=4`, + `trust_remote_code=True`. Do not pass any parameters that modify the + return type such as another type than DatasetDict is returned. Examples -------- @@ -73,7 +79,7 @@ class FederatedDataset: >>> from flwr_datasets import FederatedDataset >>> >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) - >>> # Load partition for client with ID 10. + >>> # Load partition for a client with ID 10. >>> partition = fds.load_partition(10) >>> # Use test split for centralized evaluation. >>> centralized = fds.load_split("test") @@ -107,6 +113,7 @@ def __init__( partitioners: Dict[str, Union[Partitioner, int]], shuffle: bool = True, seed: Optional[int] = 42, + **load_dataset_kwargs: Any, ) -> None: _check_if_dataset_tested(dataset) self._dataset_name: str = dataset @@ -127,6 +134,7 @@ def __init__( self._event = { "load_partition": {split: False for split in self._partitioners}, } + self._load_dataset_kwargs = load_dataset_kwargs def load_partition( self, @@ -289,8 +297,14 @@ def _prepare_dataset(self) -> None: happen before the resplitting. """ self._dataset = datasets.load_dataset( - path=self._dataset_name, name=self._subset + path=self._dataset_name, name=self._subset, **self._load_dataset_kwargs ) + if not isinstance(self._dataset, datasets.DatasetDict): + raise ValueError( + "Probably one of the specified parameter in `load_dataset_kwargs` " + "change the return type of the datasets.load_dataset function. " + "Make sure to use parameter such that the return type is DatasetDict." + ) if self._shuffle: # Note it shuffles all the splits. The self._dataset is DatasetDict # so e.g. {"train": train_data, "test": test_data}. All splits get shuffled. diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index f65aa6346f3a..100e9943c530 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -216,6 +216,23 @@ def resplit(dataset: DatasetDict) -> DatasetDict: dataset_length = sum([len(ds) for ds in dataset.values()]) self.assertEqual(len(full), dataset_length) + def test_use_load_dataset_kwargs(self) -> None: + """Test if the FederatedDataset works correctly with load_dataset_kwargs.""" + try: + fds = FederatedDataset( + dataset=self.dataset_name, + shuffle=False, + partitioners={"train": 10}, + num_proc=2, + ) + _ = fds.load_partition(0) + # Try to catch as broad as possible + except Exception as e: # pylint: disable=broad-except + self.fail( + f"Error when using load_dataset_kwargs: {e}. " + f"This code should not raise any exceptions." + ) + class ShufflingResplittingOnArtificialDatasetTest(unittest.TestCase): """Test shuffling and resplitting using small artificial dataset. @@ -416,6 +433,23 @@ def test_cannot_use_the_old_split_names(self) -> None: with self.assertRaises(ValueError): fds.load_partition(0, "train") + def test_use_load_dataset_kwargs(self) -> None: + """Test if the FederatedDataset raises with incorrect load_dataset_kwargs. + + The FederatedDataset should throw an error when the load_dataset_kwargs make the + return type different from a DatasetDict. + + Use split which makes the load_dataset return a Dataset. + """ + fds = FederatedDataset( + dataset="mnist", + shuffle=False, + partitioners={"train": 10}, + split="train", + ) + with self.assertRaises(ValueError): + _ = fds.load_partition(0) + def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool: """Check if two Datasets have the same values.""" From 02b1959aa094e742fc75c282ea45053047861fba Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Thu, 18 Jul 2024 17:50:31 +0200 Subject: [PATCH 10/10] fix(datasets:skip) Update tests for multiple partitioners (#3830) --- datasets/flwr_datasets/federated_dataset_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 100e9943c530..bb6d46f266e9 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -144,10 +144,10 @@ def test_multiple_partitioners(self) -> None: dataset_test_partition0 = dataset_fds.load_partition(0, self.test_split) dataset = datasets.load_dataset(self.dataset_name) - self.assertEqual( - len(dataset_test_partition0), - len(dataset[self.test_split]) // num_test_partitions, - ) + expected_len = len(dataset[self.test_split]) // num_test_partitions + mod = len(dataset[self.test_split]) % num_test_partitions + expected_len += 1 if 0 < mod else 0 + self.assertEqual(len(dataset_test_partition0), expected_len) def test_no_need_for_split_keyword_if_one_partitioner(self) -> None: """Test if partitions got with and without split args are the same."""