diff --git a/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl index 046de57f3cf3..ffe782d274fc 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl @@ -1,9 +1,9 @@ """$project_name: A Flower / $framework_str app.""" import jax -from flwr.client import NumPyClient, ClientApp -from flwr.common import Context +from flwr.client import ClientApp, NumPyClient +from flwr.common import Context from $import_name.task import ( evaluation, get_params, @@ -17,37 +17,31 @@ from $import_name.task import ( # Define Flower Client and client_fn class FlowerClient(NumPyClient): - def __init__(self): + def __init__(self, input_dim): self.train_x, self.train_y, self.test_x, self.test_y = load_data() self.grad_fn = jax.grad(loss_fn) - model_shape = self.train_x.shape[1:] - - self.params = load_model(model_shape) - - def get_parameters(self, config): - return get_params(self.params) - - def set_parameters(self, parameters): - set_params(self.params, parameters) + self.params = load_model((input_dim,)) def fit(self, parameters, config): - self.set_parameters(parameters) + set_params(self.params, parameters) self.params, loss, num_examples = train( self.params, self.grad_fn, self.train_x, self.train_y ) - parameters = self.get_parameters(config={}) - return parameters, num_examples, {"loss": float(loss)} + return get_params(self.params), num_examples, {"loss": float(loss)} def evaluate(self, parameters, config): - self.set_parameters(parameters) + set_params(self.params, parameters) loss, num_examples = evaluation( self.params, self.grad_fn, self.test_x, self.test_y ) return float(loss), num_examples, {"loss": float(loss)} + def client_fn(context: Context): + input_dim = context.run_config["input-dim"] + # Return Client instance - return FlowerClient().to_client() + return FlowerClient(input_dim).to_client() # Flower ClientApp diff --git a/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl index 514185fde970..60bbcaf3c175 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl @@ -1,16 +1,22 @@ """$project_name: A Flower / $framework_str app.""" -from flwr.common import Context -from flwr.server.strategy import FedAvg +from flwr.common import Context, ndarrays_to_parameters from flwr.server import ServerApp, ServerAppComponents, ServerConfig +from flwr.server.strategy import FedAvg +from $import_name.task import get_params, load_model def server_fn(context: Context): # Read from config num_rounds = context.run_config["num-server-rounds"] + input_dim = context.run_config["input-dim"] + + # Initialize global model + params = get_params(load_model((input_dim,))) + initial_parameters = ndarrays_to_parameters(params) # Define strategy - strategy = FedAvg() + strategy = FedAvg(initial_parameters=initial_parameters) config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git a/src/py/flwr/cli/new/templates/app/code/task.jax.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.jax.py.tpl index fc6ef9dee3dd..428f752845c1 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.jax.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.jax.py.tpl @@ -2,9 +2,9 @@ import jax import jax.numpy as jnp +import numpy as np from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split -import numpy as np key = jax.random.PRNGKey(0) @@ -33,7 +33,7 @@ def train(params, grad_fn, X, y): num_examples = X.shape[0] for epochs in range(50): grads = grad_fn(params, X, y) - params = jax.tree_map(lambda p, g: p - 0.05 * g, params, grads) + params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads) loss = loss_fn(params, X, y) return params, loss, num_examples diff --git a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl index 31fff1c2a4c8..28cbc5bbb527 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl @@ -9,8 +9,8 @@ description = "" license = "Apache-2.0" dependencies = [ "flwr[simulation]>=1.10.0", - "jax==0.4.13", - "jaxlib==0.4.13", + "jax==0.4.30", + "jaxlib==0.4.30", "scikit-learn==1.3.2", ] @@ -26,6 +26,7 @@ clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = 3 +input-dim = 3 [tool.flwr.federations] default = "local-simulation"