diff --git a/src/py/flwr/cli/new/templates/app/code/server.hf.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.hf.py.tpl index 039ea8619532..43fce9e481c6 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.hf.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.hf.py.tpl @@ -6,12 +6,15 @@ from flwr.server import ServerApp, ServerAppComponents, ServerConfig def server_fn(context: Context): + # Read from config + num_rounds = int(context.run_config["num-server-rounds"]) + # Define strategy strategy = FedAvg( fraction_fit=1.0, fraction_evaluate=1.0, ) - config = ServerConfig(num_rounds=3) + config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git a/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl index 122b884ab8bb..4eb7149de999 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl @@ -6,9 +6,12 @@ from flwr.server import ServerApp, ServerAppComponents, ServerConfig def server_fn(context: Context): + # Read from config + num_rounds = int(context.run_config["num-server-rounds"]) + # Define strategy strategy = FedAvg() - config = ServerConfig(num_rounds=3) + config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git a/src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl index 403c68ac3405..72aed878553d 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl @@ -6,9 +6,12 @@ from flwr.server.strategy import FedAvg def server_fn(context: Context): + # Read from config + num_rounds = int(context.run_config["num-server-rounds"]) + # Define strategy strategy = FedAvg() - config = ServerConfig(num_rounds=3) + config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git a/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl index 1ed2d36339db..d324b4f24fed 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl @@ -6,9 +6,12 @@ from flwr.server.strategy import FedAvg def server_fn(context: Context): + # Read from config + num_rounds = int(context.run_config["num-server-rounds"]) + # Define strategy strategy = FedAvg() - config = ServerConfig(num_rounds=3) + config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git a/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl index 3638b9eba7b0..7ac9508f8a25 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.pytorch.py.tpl @@ -12,6 +12,9 @@ ndarrays = get_weights(Net()) parameters = ndarrays_to_parameters(ndarrays) def server_fn(context: Context): + # Read from config + num_rounds = int(context.run_config["num-server-rounds"]) + # Define strategy strategy = FedAvg( fraction_fit=1.0, @@ -19,7 +22,7 @@ def server_fn(context: Context): min_available_clients=2, initial_parameters=parameters, ) - config = ServerConfig(num_rounds=3) + config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git a/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl index 2e463e8da09e..d8837798d5a6 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl @@ -6,13 +6,16 @@ from flwr.server.strategy import FedAvg def server_fn(context: Context): + # Read from config + num_rounds = int(context.run_config["num-server-rounds"]) + # Define strategy strategy = FedAvg( fraction_fit=1.0, fraction_evaluate=1.0, min_available_clients=2, ) - config = ServerConfig(num_rounds=3) + config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git a/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl index eee727ba9025..abd2a977b503 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl @@ -12,6 +12,9 @@ config = ServerConfig(num_rounds=3) parameters = ndarrays_to_parameters(load_model().get_weights()) def server_fn(context: Context): + # Read from config + num_rounds = int(context.run_config["num-server-rounds"]) + # Define strategy strategy = strategy = FedAvg( fraction_fit=1.0, @@ -19,7 +22,7 @@ def server_fn(context: Context): min_available_clients=2, initial_parameters=parameters, ) - config = ServerConfig(num_rounds=3) + config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl index 507b5d50b843..ca0b25f172fb 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl @@ -29,6 +29,9 @@ publisher = "$username" serverapp = "$import_name.app:server" clientapp = "$import_name.app:client" +[tool.flwr.app.config] +num-server-rounds = "3" + [tool.flwr.federations] default = "localhost" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl index 7a63e1ab5368..b39facbec5a0 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.hf.toml.tpl @@ -27,6 +27,9 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" +[tool.flwr.app.config] +num-server-rounds = "3" + [tool.flwr.federations] default = "localhost" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl index 297784a4d2d8..405decf38f16 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl @@ -24,6 +24,9 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" +[tool.flwr.app.config] +num-server-rounds = "3" + [tool.flwr.federations] default = "localhost" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl index fb55f6628cea..a2b743800595 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl @@ -24,6 +24,9 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" +[tool.flwr.app.config] +num-server-rounds = "3" + [tool.flwr.federations] default = "localhost" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl index ae88472647dc..ad074b90d24a 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl @@ -22,6 +22,9 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" +[tool.flwr.app.config] +num-server-rounds = "3" + [tool.flwr.federations] default = "localhost" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl index 2dd49a25fd90..ecd1497500ab 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl @@ -24,6 +24,9 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" +[tool.flwr.app.config] +num-server-rounds = "3" + [tool.flwr.federations] default = "localhost" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl index 8458fa64ea2d..4bc407c34262 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl @@ -23,6 +23,9 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" +[tool.flwr.app.config] +num-server-rounds = "3" + [tool.flwr.federations] default = "localhost" diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl index 2bf0e7d5642c..9dab874e50ff 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl @@ -23,6 +23,9 @@ publisher = "$username" serverapp = "$import_name.server:app" clientapp = "$import_name.client:app" +[tool.flwr.app.config] +num-server-rounds = "3" + [tool.flwr.federations] default = "localhost"