Skip to content

Commit

Permalink
Migrate advanced tensorflow to use FDS
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Jan 17, 2024
1 parent 0daa3d7 commit 05755d1
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 21 deletions.
2 changes: 1 addition & 1 deletion examples/advanced-tensorflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. It differs from the quickstart example in the following ways:

- 10 clients (instead of just 2)
- Each client holds a local dataset of 5000 training examples and 1000 test examples (note that by default only a small subset of this data is used when running the `run.sh` script)
- Each client holds a local dataset of 1/10 of the train datasets and 80% is training examples and 20% as test examples (note that by default only a small subset of this data is used when running the `run.sh` script)
- Server-side model evaluation after parameter aggregation
- Hyperparameter schedule using config functions
- Custom return values
Expand Down
23 changes: 13 additions & 10 deletions examples/advanced-tensorflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import flwr as fl

from flwr_datasets import FederatedDataset

# Make TensorFlow logs less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

Expand Down Expand Up @@ -99,7 +101,7 @@ def main() -> None:
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

# Load a subset of CIFAR-10 to simulate the local data partition
(x_train, y_train), (x_test, y_test) = load_partition(args.partition)
x_train, y_train, x_test, y_test = load_partition(args.partition)

if args.toy:
x_train, y_train = x_train[:10], y_train[:10]
Expand All @@ -117,15 +119,16 @@ def main() -> None:

def load_partition(idx: int):
"""Load 1/10th of the training and test data to simulate a partition."""
assert idx in range(10)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
return (
x_train[idx * 5000 : (idx + 1) * 5000],
y_train[idx * 5000 : (idx + 1) * 5000],
), (
x_test[idx * 1000 : (idx + 1) * 1000],
y_test[idx * 1000 : (idx + 1) * 1000],
)
# Download and partition dataset
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
partition = fds.load_partition(idx)
partition.set_format("numpy")

# Divide data on each node: 80% train, 20% test
partition = partition.train_test_split(test_size=0.2)
x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"]
x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"]
return x_train, y_train, x_test, y_test


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions examples/advanced-tensorflow/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ authors = ["The Flower Authors <[email protected]>"]
[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" }
tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""}
tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""}
5 changes: 1 addition & 4 deletions examples/advanced-tensorflow/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ echo "Starting server"
python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start

# Ensure that the Keras dataset used in client.py is already cached.
python -c "import tensorflow as tf; tf.keras.datasets.cifar10.load_data()"

for i in `seq 0 9`; do
for i in $(seq 0 9); do
echo "Starting client $i"
python client.py --partition=${i} --toy True &
done
Expand Down
14 changes: 8 additions & 6 deletions examples/advanced-tensorflow/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import flwr as fl
import tensorflow as tf

from flwr_datasets import FederatedDataset


def main() -> None:
# Load and compile model for
Expand Down Expand Up @@ -43,11 +45,11 @@ def main() -> None:
def get_evaluate_fn(model):
"""Return an evaluation function for server-side evaluation."""

# Load data and model here to avoid the overhead of doing it in `evaluate` itself
(x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()

# Use the last 5k training examples as a validation set
x_val, y_val = x_train[45000:50000], y_train[45000:50000]
# Load data here to avoid the overhead of doing it in `evaluate` itself
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
test = fds.load_full("test")
test.set_format("numpy")
x_test, y_test = test["img"] / 255.0, test["label"]

# The `evaluate` function will be called after every round
def evaluate(
Expand All @@ -56,7 +58,7 @@ def evaluate(
config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
model.set_weights(parameters) # Update model with the latest parameters
loss, accuracy = model.evaluate(x_val, y_val)
loss, accuracy = model.evaluate(x_test, y_test)
return loss, {"accuracy": accuracy}

return evaluate
Expand Down

0 comments on commit 05755d1

Please sign in to comment.