From bfcb4aff788db8df44f36972f78a71d384b7fc19 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Wed, 1 May 2024 18:55:50 +0200 Subject: [PATCH] fix(framework:cli): Fix tensorflow template (#3380) --- .../new/templates/app/code/client.tensorflow.py.tpl | 11 ++++------- .../new/templates/app/code/server.tensorflow.py.tpl | 7 +++++-- .../cli/new/templates/app/pyproject.numpy.toml.tpl | 2 +- .../cli/new/templates/app/pyproject.pytorch.toml.tpl | 2 +- .../cli/new/templates/app/pyproject.sklearn.toml.tpl | 2 +- .../new/templates/app/pyproject.tensorflow.toml.tpl | 2 +- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl index b8774b639fae..dc55d4ca6569 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl @@ -1,14 +1,9 @@ """$project_name: A Flower / TensorFlow app.""" -import os - from flwr.client import NumPyClient, ClientApp -from $project_name.task import load_data, load_model - +from $import_name.task import load_data, load_model -# Make TensorFlow log less verbose -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # Define Flower Client and client_fn class FlowerClient(NumPyClient): @@ -43,4 +38,6 @@ def client_fn(cid): # Flower ClientApp -app = ClientApp(client_fn) +app = ClientApp( + client_fn=client_fn, +) diff --git a/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl index 48a7a223a79d..8d092164a468 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl @@ -4,7 +4,10 @@ from flwr.common import ndarrays_to_parameters from flwr.server import ServerApp, ServerConfig from flwr.server.strategy import FedAvg -from $project_name.task import load_model +from $import_name.task import load_model + +# Define config +config = ServerConfig(num_rounds=3) parameters = ndarrays_to_parameters(load_model().get_weights()) @@ -19,6 +22,6 @@ strategy = FedAvg( # Create ServerApp app = ServerApp( - config=ServerConfig(num_rounds=3), + config=config, strategy=strategy, ) diff --git a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl index ac81c02bf6ea..bbf8463054f4 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.numpy.toml.tpl @@ -9,7 +9,7 @@ description = "" authors = [ { name = "The Flower Authors", email = "hello@flower.ai" }, ] -license = {text = "Apache License (2.0)"} +license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.8.0,<2.0", "numpy>=1.21.0", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl index cbc34fea6304..a41ce1a6a4c6 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl @@ -9,7 +9,7 @@ description = "" authors = [ { name = "The Flower Authors", email = "hello@flower.ai" }, ] -license = {text = "Apache License (2.0)"} +license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.8.0,<2.0", "flwr-datasets[vision]>=0.0.2,<1.0.0", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl index 89d5b66d2382..25645f0cde1a 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl @@ -9,7 +9,7 @@ description = "" authors = [ { name = "The Flower Authors", email = "hello@flower.ai" }, ] -license = {text = "Apache License (2.0)"} +license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.8.0,<2.0", "flwr-datasets[vision]>=0.0.2,<1.0.0", diff --git a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl index dea76d951382..3968e3aa327b 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl @@ -9,7 +9,7 @@ description = "" authors = [ { name = "The Flower Authors", email = "hello@flower.ai" }, ] -license = {text = "Apache License (2.0)"} +license = { text = "Apache License (2.0)" } dependencies = [ "flwr[simulation]>=1.8.0,<2.0", "flwr-datasets[vision]>=0.0.2,<1.0.0",