refactor(framework) Update MLX template (#4291)
jafermarq authored Oct 8, 2024
1 parent ef3646c commit 5c88e49
Showing 4 changed files with 31 additions and 39 deletions.
src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl (16 additions & 36 deletions)

@@ -3,17 +3,18 @@
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-from flwr.client import NumPyClient, ClientApp
-from flwr.common import Context
 
+from flwr.client import ClientApp, NumPyClient
+from flwr.common import Context
+from flwr.common.config import UserConfig
 from $import_name.task import (
+    MLP,
     batch_iterate,
     eval_fn,
     get_params,
     load_data,
     loss_fn,
     set_params,
-    MLP,
 )
 
 
@@ -22,48 +23,35 @@ class FlowerClient(NumPyClient):
     def __init__(
         self,
         data,
-        num_layers,
-        hidden_dim,
+        run_config: UserConfig,
         num_classes,
-        batch_size,
-        learning_rate,
-        num_epochs,
     ):
-        self.num_layers = num_layers
-        self.hidden_dim = hidden_dim
-        self.num_classes = num_classes
-        self.batch_size = batch_size
-        self.learning_rate = learning_rate
-        self.num_epochs = num_epochs
+        num_layers = run_config["num-layers"]
+        hidden_dim = run_config["hidden-dim"]
+        input_dim = run_config["input-dim"]
+        batch_size = run_config["batch-size"]
+        learning_rate = run_config["lr"]
+        self.num_epochs = run_config["local-epochs"]
 
         self.train_images, self.train_labels, self.test_images, self.test_labels = data
-        self.model = MLP(
-            num_layers, self.train_images.shape[-1], hidden_dim, num_classes
-        )
+        self.model = MLP(num_layers, input_dim, hidden_dim, num_classes)
         self.optimizer = optim.SGD(learning_rate=learning_rate)
         self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn)
-        self.num_epochs = num_epochs
         self.batch_size = batch_size
 
-    def get_parameters(self, config):
-        return get_params(self.model)
-
-    def set_parameters(self, parameters):
-        set_params(self.model, parameters)
-
     def fit(self, parameters, config):
-        self.set_parameters(parameters)
+        set_params(self.model, parameters)
         for _ in range(self.num_epochs):
             for X, y in batch_iterate(
                 self.batch_size, self.train_images, self.train_labels
             ):
                 _, grads = self.loss_and_grad_fn(self.model, X, y)
                 self.optimizer.update(self.model, grads)
                 mx.eval(self.model.parameters(), self.optimizer.state)
-        return self.get_parameters(config={}), len(self.train_images), {}
+        return get_params(self.model), len(self.train_images), {}
 
     def evaluate(self, parameters, config):
-        self.set_parameters(parameters)
+        set_params(self.model, parameters)
         accuracy = eval_fn(self.model, self.test_images, self.test_labels)
         loss = loss_fn(self.model, self.test_images, self.test_labels)
         return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}
@@ -73,18 +61,10 @@ def client_fn(context: Context):
     partition_id = context.node_config["partition-id"]
     num_partitions = context.node_config["num-partitions"]
    data = load_data(partition_id, num_partitions)
-
-    num_layers = context.run_config["num-layers"]
-    hidden_dim = context.run_config["hidden-dim"]
     num_classes = 10
-    batch_size = context.run_config["batch-size"]
-    learning_rate = context.run_config["lr"]
-    num_epochs = context.run_config["local-epochs"]
 
     # Return Client instance
-    return FlowerClient(
-        data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs
-    ).to_client()
+    return FlowerClient(data, context.run_config, num_classes).to_client()
 
 
 # Flower ClientApp
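Note on the refactor above: the client now reads its hyperparameters from the run config instead of taking seven positional constructor arguments. A minimal sketch (not part of the commit) of exercising the new constructor outside of `flwr run`; `UserConfig` is a plain mapping of the `[tool.flwr.app.config]` keys, so a dict literal with the template defaults stands in for it, and `myapp` is a placeholder for the generated app package:

# Sketch only: drive the refactored FlowerClient directly, bypassing flwr run.
from myapp.client_app import FlowerClient  # hypothetical generated-app import
from myapp.task import get_params, load_data

run_config = {
    "num-layers": 2,
    "hidden-dim": 32,
    "input-dim": 784,  # 28 * 28, the new key introduced in pyproject.toml below
    "batch-size": 256,
    "lr": 0.1,
    "local-epochs": 1,
}

data = load_data(partition_id=0, num_partitions=2)
client = FlowerClient(data, run_config, num_classes=10)

# One round of local work, seeded with the model's own initial weights
updated, num_examples, _ = client.fit(get_params(client.model), config={})
loss, num_test, metrics = client.evaluate(updated, config={})
print(f"trained on {num_examples} examples, loss={loss:.4f}, accuracy={metrics['accuracy']:.4f}")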
src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl (13 additions & 2 deletions)

@@ -1,16 +1,27 @@
 """$project_name: A Flower / $framework_str app."""
 
-from flwr.common import Context
+from flwr.common import Context, ndarrays_to_parameters
 from flwr.server import ServerApp, ServerAppComponents, ServerConfig
 from flwr.server.strategy import FedAvg
+from $import_name.task import MLP, get_params
 
 
 def server_fn(context: Context):
     # Read from config
     num_rounds = context.run_config["num-server-rounds"]
 
+    num_classes = 10
+    num_layers = context.run_config["num-layers"]
+    input_dim = context.run_config["input-dim"]
+    hidden_dim = context.run_config["hidden-dim"]
+
+    # Initialize global model
+    model = MLP(num_layers, input_dim, hidden_dim, num_classes)
+    params = get_params(model)
+    initial_parameters = ndarrays_to_parameters(params)
+
     # Define strategy
-    strategy = FedAvg()
+    strategy = FedAvg(initial_parameters=initial_parameters)
     config = ServerConfig(num_rounds=num_rounds)
 
     return ServerAppComponents(strategy=strategy, config=config)
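Passing `initial_parameters` to `FedAvg` means the server defines the round-zero global model itself; when the argument is omitted, the strategy instead requests initial weights from a sampled client. A small sketch of the serialization round trip the new server code relies on (the array shapes are illustrative, matching a 784-to-32 layer, not taken from the commit):

# Sketch: the NumPy <-> Parameters round trip used when seeding FedAvg.
import numpy as np

from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays

ndarrays = [
    np.zeros((784, 32), dtype=np.float32),  # e.g. first-layer weights
    np.zeros((32,), dtype=np.float32),      # e.g. first-layer bias
]
parameters = ndarrays_to_parameters(ndarrays)  # wire format sent to clients
restored = parameters_to_ndarrays(parameters)  # back to NumPy on the other side

assert all(np.array_equal(a, b) for a, b in zip(ndarrays, restored))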
src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl (1 addition & 1 deletion)

@@ -3,10 +3,10 @@
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from datasets.utils.logging import disable_progress_bar
 from flwr_datasets import FederatedDataset
 from flwr_datasets.partitioner import IidPartitioner
 
-from datasets.utils.logging import disable_progress_bar
 
 disable_progress_bar()
 
src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl (1 addition)

@@ -28,6 +28,7 @@ clientapp = "$import_name.client_app:app"
 num-server-rounds = 3
 local-epochs = 1
 num-layers = 2
+input-dim = 784 # 28*28
 hidden-dim = 32
 batch-size = 256
 lr = 0.1
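The new `input-dim` key replaces the old behavior of inferring the input size from the data (`self.train_images.shape[-1]`), which the server process could not do without loading data itself. The configured value must still agree with what `load_data` produces; a defensive check one might add (not part of the commit, and `myapp` is again a placeholder for the generated package):

# Sketch: verify the configured input-dim matches the flattened image size.
from myapp.task import load_data  # hypothetical generated-app import

train_images, _, _, _ = load_data(partition_id=0, num_partitions=2)
input_dim = 784  # value of the new "input-dim" run-config key (28 * 28)

assert train_images.shape[-1] == input_dim, (
    f"input-dim={input_dim} disagrees with data dimension {train_images.shape[-1]}"
)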
