Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(framework) Update templates #4083

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/py/flwr/cli/new/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ def new(
MlFramework.HUGGINGFACE.value,
MlFramework.MLX.value,
MlFramework.TENSORFLOW.value,
MlFramework.NUMPY.value,
MlFramework.SKLEARN.value,
]
if framework_str in frameworks_with_tasks:
files[f"{import_name}/task.py"] = {
Expand Down
24 changes: 9 additions & 15 deletions src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,31 @@ from $import_name.task import (

# Define Flower Client and client_fn
class FlowerClient(NumPyClient):
def __init__(self):
def __init__(self, input_dim):
self.train_x, self.train_y, self.test_x, self.test_y = load_data()
self.grad_fn = jax.grad(loss_fn)
model_shape = self.train_x.shape[1:]

self.params = load_model(model_shape)

def get_parameters(self, config):
return get_params(self.params)

def set_parameters(self, parameters):
set_params(self.params, parameters)
self.params = load_model((input_dim,))

def fit(self, parameters, config):
self.set_parameters(parameters)
set_params(self.params, parameters)
self.params, loss, num_examples = train(
self.params, self.grad_fn, self.train_x, self.train_y
)
parameters = self.get_parameters(config={})
return parameters, num_examples, {"loss": float(loss)}
return get_params(self.params), num_examples, {"loss": float(loss)}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
set_params(self.params, parameters)
loss, num_examples = evaluation(
self.params, self.grad_fn, self.test_x, self.test_y
)
return float(loss), num_examples, {"loss": float(loss)}

def client_fn(context: Context):

input_dim = context.run_config["input-dim"]

# Return Client instance
return FlowerClient().to_client()
return FlowerClient(input_dim).to_client()


# Flower ClientApp
Expand Down
48 changes: 13 additions & 35 deletions src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from flwr.client import NumPyClient, ClientApp
from flwr.common import Context
from flwr.common.config import UserConfig

from $import_name.task import (
batch_iterate,
Expand All @@ -22,48 +22,35 @@ class FlowerClient(NumPyClient):
def __init__(
self,
data,
num_layers,
hidden_dim,
run_config: UserConfig,
num_classes,
batch_size,
learning_rate,
num_epochs,
):
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.num_classes = num_classes
self.batch_size = batch_size
self.learning_rate = learning_rate
self.num_epochs = num_epochs
num_layers = run_config["num-layers"]
hidden_dim = run_config["hidden-dim"]
input_dim = run_config["input-dim"]
batch_size = run_config["batch-size"]
learning_rate = run_config["lr"]
self.num_epochs = run_config["local-epochs"]

self.train_images, self.train_labels, self.test_images, self.test_labels = data
self.model = MLP(
num_layers, self.train_images.shape[-1], hidden_dim, num_classes
)
self.model = MLP(num_layers, input_dim, hidden_dim, num_classes)
self.optimizer = optim.SGD(learning_rate=learning_rate)
self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
self.num_epochs = num_epochs
self.batch_size = batch_size

def get_parameters(self, config):
return get_params(self.model)

def set_parameters(self, parameters):
set_params(self.model, parameters)

def fit(self, parameters, config):
self.set_parameters(parameters)
set_params(self.model, parameters)
for _ in range(self.num_epochs):
for X, y in batch_iterate(
self.batch_size, self.train_images, self.train_labels
):
_, grads = self.loss_and_grad_fn(self.model, X, y)
self.optimizer.update(self.model, grads)
mx.eval(self.model.parameters(), self.optimizer.state)
return self.get_parameters(config={}), len(self.train_images), {}
return get_params(self.model), len(self.train_images), {}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
set_params(self.model, parameters)
accuracy = eval_fn(self.model, self.test_images, self.test_labels)
loss = loss_fn(self.model, self.test_images, self.test_labels)
return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
Expand All @@ -73,19 +60,10 @@ def client_fn(context: Context):
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
data = load_data(partition_id, num_partitions)

num_layers = context.run_config["num-layers"]
hidden_dim = context.run_config["hidden-dim"]
num_classes = 10
batch_size = context.run_config["batch-size"]
learning_rate = context.run_config["lr"]
num_epochs = context.run_config["local-epochs"]

# Return Client instance
return FlowerClient(
data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
).to_client()

return FlowerClient(data, context.run_config, num_classes).to_client()

# Flower ClientApp
app = ClientApp(
Expand Down
7 changes: 4 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ from flwr.client import NumPyClient, ClientApp
from flwr.common import Context
import numpy as np

from $import_name.task import get_dummy_model


class FlowerClient(NumPyClient):
def get_parameters(self, config):
return [np.ones((1, 1))]

def fit(self, parameters, config):
return ([np.ones((1, 1))], 1, {})
model = get_dummy_model()
return ([model], 1, {})

def evaluate(self, parameters, config):
return float(0.0), 1, {"accuracy": float(1.0)}
Expand Down
59 changes: 12 additions & 47 deletions src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,17 @@

import warnings

import numpy as np
from flwr.client import NumPyClient, ClientApp
from flwr.common import Context
from flwr_datasets import FederatedDataset
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss


def get_model_parameters(model):
if model.fit_intercept:
params = [
model.coef_,
model.intercept_,
]
else:
params = [model.coef_]
return params


def set_model_params(model, params):
model.coef_ = params[0]
if model.fit_intercept:
model.intercept_ = params[1]
return model


def set_initial_params(model):
n_classes = 10 # MNIST has 10 classes
n_features = 784 # Number of features in dataset
model.classes_ = np.array([i for i in range(10)])

model.coef_ = np.zeros((n_classes, n_features))
if model.fit_intercept:
model.intercept_ = np.zeros((n_classes,))

from $import_name.task import (
get_model,
get_model_params,
load_data,
set_initial_params,
set_model_params,
)

class FlowerClient(NumPyClient):
def __init__(self, model, X_train, X_test, y_train, y_test):
Expand All @@ -46,9 +22,6 @@ class FlowerClient(NumPyClient):
self.y_train = y_train
self.y_test = y_test

def get_parameters(self, config):
return get_model_parameters(self.model)

def fit(self, parameters, config):
set_model_params(self.model, parameters)

Expand All @@ -57,7 +30,7 @@ class FlowerClient(NumPyClient):
warnings.simplefilter("ignore")
self.model.fit(self.X_train, self.y_train)

return get_model_parameters(self.model), len(self.X_train), {}
return get_model_params(self.model), len(self.X_train), {}

def evaluate(self, parameters, config):
set_model_params(self.model, parameters)
Expand All @@ -71,21 +44,13 @@ class FlowerClient(NumPyClient):
def client_fn(context: Context):
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
dataset = fds.load_partition(partition_id, "train").with_format("numpy")

X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]

# Split the on edge data: 80% train, 20% test
X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
X_train, X_test, y_train, y_test = load_data(partition_id, num_partitions)

# Create LogisticRegression Model
model = LogisticRegression(
penalty="l2",
max_iter=1, # local epoch
warm_start=True, # prevent refreshing weights when fitting
)
penalty = context.run_config["penalty"]
local_epochs = context.run_config["local-epochs"]
model = get_model(penalty, local_epochs)

# Setting initial parameters, akin to model.compile for keras models
set_initial_params(model)
Expand Down
10 changes: 8 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
"""$project_name: A Flower / $framework_str app."""

from flwr.common import Context
from flwr.common import Context, ndarrays_to_parameters
from flwr.server.strategy import FedAvg
from flwr.server import ServerApp, ServerAppComponents, ServerConfig

from $import_name.task import get_params, load_model

def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]
input_dim = context.run_config["input-dim"]

# Initialize global model
params = get_params(load_model((input_dim,)))
initial_parameters = ndarrays_to_parameters(params)

# Define strategy
strategy = FedAvg()
strategy = FedAvg(initial_parameters=initial_parameters)
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)
Expand Down
16 changes: 14 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
"""$project_name: A Flower / $framework_str app."""

from flwr.common import Context
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from $import_name.task import get_params, MLP


def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]

num_classes = 10
num_layers = context.run_config["num-layers"]
input_dim = context.run_config["input-dim"]
hidden_dim = context.run_config["hidden-dim"]

# Initialize global model
model = MLP(num_layers, input_dim, hidden_dim, num_classes)
params = get_params(model)
initial_parameters = ndarrays_to_parameters(params)

# Define strategy
strategy = FedAvg()
strategy = FedAvg(initial_parameters=initial_parameters)
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)
Expand Down
10 changes: 8 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
"""$project_name: A Flower / $framework_str app."""

from flwr.common import Context
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from $import_name.task import get_dummy_model


def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]

# Initial model
model = get_dummy_model()
dummy_parameters = ndarrays_to_parameters([model])

# Define strategy
strategy = FedAvg()
strategy = FedAvg(initial_parameters=dummy_parameters)
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)
Expand Down
15 changes: 14 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
"""$project_name: A Flower / $framework_str app."""

from flwr.common import Context
from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from $import_name.task import get_model, get_model_params, set_initial_params


def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]

# Create LogisticRegression Model
penalty = context.run_config["penalty"]
local_epochs = context.run_config["local-epochs"]
model = get_model(penalty, local_epochs)

# Setting initial parameters, akin to model.compile for keras models
set_initial_params(model)

initial_parameters = ndarrays_to_parameters(get_model_params(model))

# Define strategy
strategy = FedAvg(
fraction_fit=1.0,
fraction_evaluate=1.0,
min_available_clients=2,
initial_parameters=initial_parameters,
)
config = ServerConfig(num_rounds=num_rounds)

Expand Down
7 changes: 7 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/task.numpy.py.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""$project_name: A Flower / $framework_str app."""

import numpy as np

def get_dummy_model():
return np.ones((1,1))

Loading