Skip to content

Commit

Permalink
Migrate tutorial-series-get-started-with-flower-pytorch to FDS (#2480)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak authored Dec 21, 2023
1 parent 18aebd7 commit 34fd0b3
Showing 1 changed file with 52 additions and 72 deletions.
124 changes: 52 additions & 72 deletions doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"\n",
"Welcome to the Flower federated learning tutorial!\n",
"\n",
"In this notebook, we'll build a federated learning system using Flower and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.\n",
"In this notebook, we'll build a federated learning system using Flower, [Flower Datasets](https://flower.dev/docs/datasets/) and PyTorch. In part 1, we use PyTorch for the model training pipeline and data loading. In part 2, we continue to federate the PyTorch-based pipeline using Flower.\n",
"\n",
"> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.\n",
"\n",
Expand All @@ -31,7 +31,7 @@
"source": [
"### Installing dependencies\n",
"\n",
"Next, we install the necessary packages for PyTorch (`torch` and `torchvision`) and Flower (`flwr`):"
"Next, we install the necessary packages for PyTorch (`torch` and `torchvision`), Flower Datasets (`flwr-datasets`) and Flower (`flwr`):"
]
},
{
Expand All @@ -40,7 +40,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -q flwr[simulation] torch torchvision matplotlib"
"!pip install -q flwr[simulation] flwr_datasets[vision] torch torchvision matplotlib"
]
},
{
Expand All @@ -64,18 +64,19 @@
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from torch.utils.data import DataLoader, random_split\n",
"from torchvision.datasets import CIFAR10\n",
"from datasets.utils.logging import disable_progress_bar\n",
"from torch.utils.data import DataLoader\n",
"\n",
"import flwr as fl\n",
"from flwr.common import Metrics\n",
"from flwr_datasets import FederatedDataset\n",
"\n",
"DEVICE = torch.device(\"cpu\") # Try \"cuda\" to train on GPU\n",
"print(\n",
" f\"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}\"\n",
")"
")\n",
"disable_progress_bar()"
]
},
{
Expand All @@ -92,27 +93,7 @@
"\n",
"### Loading the data\n",
"\n",
"Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"CLASSES = (\n",
" \"plane\",\n",
" \"car\",\n",
" \"bird\",\n",
" \"cat\",\n",
" \"deer\",\n",
" \"dog\",\n",
" \"frog\",\n",
" \"horse\",\n",
" \"ship\",\n",
" \"truck\",\n",
")"
"Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', and 'truck'."
]
},
{
Expand All @@ -121,24 +102,15 @@
"source": [
"We simulate having multiple datasets from multiple organizations (also called the \"cross-silo\" setting in federated learning) by splitting the original CIFAR-10 dataset into multiple partitions. Each partition will represent the data from a single organization. We're doing this purely for experimentation purposes, in the real world there's no need for data splitting because each organization already has their own data (so the data is naturally partitioned).\n",
"\n",
"Each organization will act as a client in the federated learning system. So having ten organizations participate in a federation means having ten clients connected to the federated learning server:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"NUM_CLIENTS = 10"
"Each organization will act as a client in the federated learning system. So having ten organizations participate in a federation means having ten clients connected to the federated learning server.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Let's now load the CIFAR-10 training and test set, partition them into ten smaller datasets (each split into training and validation set), and wrap the resulting partitions by creating a PyTorch `DataLoader` for each of them:"
"Let's now create the Federated Dataset abstraction that from `flwr-datasets` that partitions the CIFAR-10. We will create small training and test set for each edge device and wrap each of them into a PyTorch `DataLoader`:"
]
},
{
Expand All @@ -147,32 +119,36 @@
"metadata": {},
"outputs": [],
"source": [
"NUM_CLIENTS = 10\n",
"BATCH_SIZE = 32\n",
"\n",
"\n",
"def load_datasets():\n",
" # Download and transform CIFAR-10 (train and test)\n",
" transform = transforms.Compose(\n",
" [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
" )\n",
" trainset = CIFAR10(\"./dataset\", train=True, download=True, transform=transform)\n",
" testset = CIFAR10(\"./dataset\", train=False, download=True, transform=transform)\n",
"\n",
" # Split training set into 10 partitions to simulate the individual dataset\n",
" partition_size = len(trainset) // NUM_CLIENTS\n",
" lengths = [partition_size] * NUM_CLIENTS\n",
" datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))\n",
"\n",
" # Split each partition into train/val and create DataLoader\n",
" fds = FederatedDataset(dataset=\"cifar10\", partitioners={\"train\": NUM_CLIENTS})\n",
"\n",
" def apply_transforms(batch):\n",
" # Instead of passing transforms to CIFAR10(..., transform=transform)\n",
" # we will use this function to dataset.with_transform(apply_transforms)\n",
" # The transforms object is exactly the same\n",
" transform = transforms.Compose(\n",
" [\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
" ]\n",
" )\n",
" batch[\"img\"] = [transform(img) for img in batch[\"img\"]]\n",
" return batch\n",
"\n",
" # Create train/val for each partition and wrap it into DataLoader\n",
" trainloaders = []\n",
" valloaders = []\n",
" for ds in datasets:\n",
" len_val = len(ds) // 10 # 10 % validation set\n",
" len_train = len(ds) - len_val\n",
" lengths = [len_train, len_val]\n",
" ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))\n",
" trainloaders.append(DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True))\n",
" valloaders.append(DataLoader(ds_val, batch_size=BATCH_SIZE))\n",
" for partition_id in range(NUM_CLIENTS):\n",
" partition = fds.load_partition(partition_id, \"train\")\n",
" partition = partition.with_transform(apply_transforms)\n",
" partition = partition.train_test_split(train_size=0.8)\n",
" trainloaders.append(DataLoader(partition[\"train\"], batch_size=BATCH_SIZE))\n",
" valloaders.append(DataLoader(partition[\"test\"], batch_size=BATCH_SIZE))\n",
" testset = fds.load_full(\"test\").with_transform(apply_transforms)\n",
" testloader = DataLoader(testset, batch_size=BATCH_SIZE)\n",
" return trainloaders, valloaders, testloader\n",
"\n",
Expand All @@ -195,8 +171,8 @@
"metadata": {},
"outputs": [],
"source": [
"images, labels = next(iter(trainloaders[0]))\n",
"\n",
"batch = next(iter(trainloaders[0]))\n",
"images, labels = batch[\"img\"], batch[\"label\"]\n",
"# Reshape and convert images to a NumPy array\n",
"# matplotlib requires images with the shape (height, width, 3)\n",
"images = images.permute(0, 2, 3, 1).numpy()\n",
Expand All @@ -209,7 +185,7 @@
"# Loop over the images and plot them\n",
"for i, ax in enumerate(axs.flat):\n",
" ax.imshow(images[i])\n",
" ax.set_title(CLASSES[labels[i]])\n",
" ax.set_title(trainloaders[0].dataset.features[\"label\"].int2str([labels[i]])[0])\n",
" ax.axis(\"off\")\n",
"\n",
"# Show the plot\n",
Expand Down Expand Up @@ -294,8 +270,8 @@
" net.train()\n",
" for epoch in range(epochs):\n",
" correct, total, epoch_loss = 0, 0, 0.0\n",
" for images, labels in trainloader:\n",
" images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
" for batch in trainloader:\n",
" images, labels = batch[\"img\"].to(DEVICE), batch[\"label\"].to(DEVICE)\n",
" optimizer.zero_grad()\n",
" outputs = net(images)\n",
" loss = criterion(outputs, labels)\n",
Expand All @@ -317,8 +293,8 @@
" correct, total, loss = 0, 0, 0.0\n",
" net.eval()\n",
" with torch.no_grad():\n",
" for images, labels in testloader:\n",
" images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
" for batch in testloader:\n",
" images, labels = batch[\"img\"].to(DEVICE), batch[\"label\"].to(DEVICE)\n",
" outputs = net(images)\n",
" loss += criterion(outputs, labels).item()\n",
" _, predicted = torch.max(outputs.data, 1)\n",
Expand Down Expand Up @@ -477,7 +453,7 @@
" valloader = valloaders[int(cid)]\n",
"\n",
" # Create a single Flower client representing a single organization\n",
" return FlowerClient(net, trainloader, valloader)"
" return FlowerClient(net, trainloader, valloader).to_client()"
]
},
{
Expand Down Expand Up @@ -508,10 +484,14 @@
" min_available_clients=10, # Wait until all 10 clients are available\n",
")\n",
"\n",
"# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)\n",
"client_resources = None\n",
"# Specify the resources each of your clients need. By default, each \n",
"# client will be allocated 1x CPU and 0x CPUs\n",
"client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n",
"if DEVICE.type == \"cuda\":\n",
" client_resources = {\"num_gpus\": 1}\n",
" # here we are asigning an entire GPU for each client.\n",
" client_resources = {\"num_cpus\": 1, \"num_gpus\": 1.0}\n",
" # Refer to our documentation for more details about Flower Simulations\n",
" # and how to setup these `client_resources`.\n",
"\n",
"# Start simulation\n",
"fl.simulation.start_simulation(\n",
Expand Down Expand Up @@ -629,7 +609,7 @@
"\n",
"There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!\n",
"\n",
"The [Flower Federated Learning Tutorial - Part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them."
"The [Flower Federated Learning Tutorial - Part 2](https://flower.dev/docs/framework/tutorial-use-a-federated-learning-strategy-pytorch.html) goes into more depth about strategies and all the advanced things you can build with them.\n"
]
}
],
Expand All @@ -640,7 +620,7 @@
"toc_visible": true
},
"kernelspec": {
"display_name": "flower-3.7.12",
"display_name": "flwr",
"language": "python",
"name": "python3"
}
Expand Down

0 comments on commit 34fd0b3

Please sign in to comment.