Skip to content

Commit

Permalink
refactor(framework) Update JAX template (#4294)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Oct 8, 2024
1 parent 3e45bfd commit fcd20f6
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 24 deletions.
28 changes: 11 additions & 17 deletions src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""$project_name: A Flower / $framework_str app."""

import jax
from flwr.client import NumPyClient, ClientApp
from flwr.common import Context

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from $import_name.task import (
evaluation,
get_params,
Expand All @@ -17,37 +17,31 @@ from $import_name.task import (

# Define Flower Client and client_fn
class FlowerClient(NumPyClient):
def __init__(self):
def __init__(self, input_dim):
self.train_x, self.train_y, self.test_x, self.test_y = load_data()
self.grad_fn = jax.grad(loss_fn)
model_shape = self.train_x.shape[1:]

self.params = load_model(model_shape)

def get_parameters(self, config):
return get_params(self.params)

def set_parameters(self, parameters):
set_params(self.params, parameters)
self.params = load_model((input_dim,))

def fit(self, parameters, config):
self.set_parameters(parameters)
set_params(self.params, parameters)
self.params, loss, num_examples = train(
self.params, self.grad_fn, self.train_x, self.train_y
)
parameters = self.get_parameters(config={})
return parameters, num_examples, {"loss": float(loss)}
return get_params(self.params), num_examples, {"loss": float(loss)}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
set_params(self.params, parameters)
loss, num_examples = evaluation(
self.params, self.grad_fn, self.test_x, self.test_y
)
return float(loss), num_examples, {"loss": float(loss)}


def client_fn(context: Context):
input_dim = context.run_config["input-dim"]

# Return Client instance
return FlowerClient().to_client()
return FlowerClient(input_dim).to_client()


# Flower ClientApp
Expand Down
12 changes: 9 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
"""$project_name: A Flower / $framework_str app."""

from flwr.common import Context
from flwr.server.strategy import FedAvg
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from $import_name.task import get_params, load_model


def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]
input_dim = context.run_config["input-dim"]

# Initialize global model
params = get_params(load_model((input_dim,)))
initial_parameters = ndarrays_to_parameters(params)

# Define strategy
strategy = FedAvg()
strategy = FedAvg(initial_parameters=initial_parameters)
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/task.jax.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import jax
import jax.numpy as jnp
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
import numpy as np

key = jax.random.PRNGKey(0)

Expand Down Expand Up @@ -33,7 +33,7 @@ def train(params, grad_fn, X, y):
num_examples = X.shape[0]
for epochs in range(50):
grads = grad_fn(params, X, y)
params = jax.tree_map(lambda p, g: p - 0.05 * g, params, grads)
params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
loss = loss_fn(params, X, y)
return params, loss, num_examples

Expand Down
5 changes: 3 additions & 2 deletions src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.10.0",
"jax==0.4.13",
"jaxlib==0.4.13",
"jax==0.4.30",
"jaxlib==0.4.30",
"scikit-learn==1.3.2",
]

Expand All @@ -26,6 +26,7 @@ clientapp = "$import_name.client_app:app"

[tool.flwr.app.config]
num-server-rounds = 3
input-dim = 3

[tool.flwr.federations]
default = "local-simulation"
Expand Down

0 comments on commit fcd20f6

Please sign in to comment.