Skip to content

Commit

Permalink
Add TensorFlow template to flwr new (#3243)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Apr 19, 2024
1 parent 1d2db04 commit a47fb00
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl
Original file line number Diff line number Diff line change
@@ -1 +1,57 @@
"""$project_name: A Flower / TensorFlow app."""

import os

import tensorflow as tf
from flwr.client import ClientApp, NumPyClient
from flwr_datasets import FederatedDataset


os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Define Flower client
class FlowerClient(NumPyClient):
def __init__(self, model, train_data, test_data):
self.model = model
self.x_train, self.y_train = train_data
self.x_test, self.y_test = test_data

def get_parameters(self, config):
return self.model.get_weights()

def fit(self, parameters, config):
self.model.set_weights(parameters)
self.model.fit(self.x_train, self.y_train, epochs=1, batch_size=32, verbose=0)
return self.model.get_weights(), len(self.x_train), {}

def evaluate(self, parameters, config):
self.model.set_weights(parameters)
loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
return loss, len(self.x_test), {"accuracy": accuracy}


fds = FederatedDataset(dataset="cifar10", partitioners={"train": 2})

def client_fn(cid: str):
"""Create and return an instance of Flower `Client`."""

# Load model and data (MobileNetV2, CIFAR-10)
model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None)
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

# Download and partition dataset
partition = fds.load_partition(int(cid), "train")
partition.set_format("numpy")

# Divide data on each node: 80% train, 20% test
partition = partition.train_test_split(test_size=0.2, seed=42)
train_data = partition["train"]["img"] / 255.0, partition["train"]["label"]
test_data = partition["test"]["img"] / 255.0, partition["test"]["label"]

return FlowerClient(model, train_data, test_data).to_client()


# Flower ClientApp
app = ClientApp(
client_fn=client_fn,
)
18 changes: 18 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/server.tensorflow.py.tpl
Original file line number Diff line number Diff line change
@@ -1 +1,19 @@
"""$project_name: A Flower / TensorFlow app."""

from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg

# Define config
config = ServerConfig(num_rounds=3)

strategy = FedAvg(
fraction_fit=1.0,
fraction_evaluate=1.0,
min_available_clients=2,
)

# Flower ServerApp
app = ServerApp(
config=config,
strategy=strategy,
)

0 comments on commit a47fb00

Please sign in to comment.