Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MONAI quickstart example #2773

Merged
merged 29 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
3b112af
First draft
charlesbvll Oct 12, 2023
364678d
Working version
charlesbvll Oct 13, 2023
cb9453c
Merge branch 'main' of https://github.com/adap/flower into add-monai-…
charlesbvll Dec 20, 2023
7ebbd8b
Fix imports
charlesbvll Dec 20, 2023
6245f64
Formatting
charlesbvll Dec 20, 2023
423e51e
Update README
charlesbvll Dec 20, 2023
0b127de
Update requirements
charlesbvll Dec 20, 2023
e6248b0
Add README
charlesbvll Dec 22, 2023
1d94f25
Add partitioning
charlesbvll Dec 22, 2023
182a9f8
Merge branch 'main' of https://github.com/adap/flower into add-monai-…
charlesbvll Jan 4, 2024
550e817
Update examples/quickstart-monai/requirements.txt
charlesbvll Jan 5, 2024
2bcddda
Update examples/quickstart-monai/pyproject.toml
charlesbvll Jan 5, 2024
4434cf4
Update examples/quickstart-monai/README.md
charlesbvll Jan 5, 2024
ab3f57e
Merge branch 'main' into add-monai-quickstart
charlesbvll Jan 5, 2024
a235e3b
Use num_workers=0
charlesbvll Jan 5, 2024
ef3052a
Merge branch 'main' into add-monai-quickstart
charlesbvll Jan 5, 2024
932be8b
PIL wasn't really used, removing also from env setups
jafermarq Jan 5, 2024
82823fd
Update examples/quickstart-monai/data.py
charlesbvll Jan 5, 2024
f70ec64
Merge branch 'main' into add-monai-quickstart
charlesbvll Jan 17, 2024
43d7531
Merge branch 'main' into add-monai-quickstart
charlesbvll Jan 17, 2024
026471a
consistency with other examples
jafermarq Jan 17, 2024
5b4e20f
Merge branch 'main' into add-monai-quickstart
jafermarq Jan 17, 2024
cd04f6c
formatting
jafermarq Jan 17, 2024
af97e03
in top-level `README.md`
jafermarq Jan 17, 2024
ab482a9
Merge branch 'main' into add-monai-quickstart
jafermarq Jan 17, 2024
3809f1f
Merge branch 'main' into add-monai-quickstart
charlesbvll Jan 18, 2024
5dabf85
merge w/ main and updates
jafermarq Feb 28, 2024
68c80a3
format
jafermarq Feb 28, 2024
868bad5
Merge branch 'main' into add-monai-quickstart
danieljanes Feb 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ design of Flower is based on a few guiding principles:

- **Framework-agnostic**: Different machine learning frameworks have different
strengths. Flower can be used with any machine learning framework, for
example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/)
example, [PyTorch](https://pytorch.org), [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [MONAI](https://docs.monai.io/en/latest/index.html), [fastai](https://www.fast.ai/), [MLX](https://ml-explore.github.io/mlx/build/html/index.html), [XGBoost](https://xgboost.readthedocs.io/en/stable/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/)
for users who enjoy computing gradients by hand.

- **Understandable**: Flower is written with maintainability in mind. The
Expand Down Expand Up @@ -130,6 +130,7 @@ Quickstart examples:
- [Quickstart (fastai)](https://github.com/adap/flower/tree/main/examples/quickstart-fastai)
- [Quickstart (Pandas)](https://github.com/adap/flower/tree/main/examples/quickstart-pandas)
- [Quickstart (JAX)](https://github.com/adap/flower/tree/main/examples/quickstart-jax)
- [Quickstart (MONAI)](https://github.com/adap/flower/tree/main/examples/quickstart-monai)
- [Quickstart (scikit-learn)](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)
- [Quickstart (XGBoost)](https://github.com/adap/flower/tree/main/examples/xgboost-quickstart)
- [Quickstart (Android [TFLite])](https://github.com/adap/flower/tree/main/examples/android)
Expand Down
1 change: 1 addition & 0 deletions examples/quickstart-monai/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
MedNIST*
85 changes: 85 additions & 0 deletions examples/quickstart-monai/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Flower Example using MONAI

This introductory example to Flower uses MONAI, but deep knowledge of MONAI is not necessarily required to run the example. However, it will help you understand how to adapt Flower to your use case.
Running this example in itself is quite easy.

[MONAI](https://docs.monai.io/en/latest/index.html)(Medical Open Network for AI) is a PyTorch-based, open-source framework for deep learning in healthcare imaging, part of the PyTorch Ecosystem.

Its ambitions are:

- developing a community of academic, industrial and clinical researchers collaborating on a common foundation;

- creating state-of-the-art, end-to-end training workflows for healthcare imaging;

- providing researchers with an optimized and standardized way to create and evaluate deep learning models.

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/quickstart-monai . && rm -rf _tmp && cd quickstart-monai
```

This will create a new directory called `quickstart-monai` containing the following files:

```shell
-- pyproject.toml
-- requirements.txt
-- client.py
-- data.py
-- model.py
-- server.py
-- README.md
```

### Installing Dependencies

Project dependencies (such as `monai` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

#### pip

Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt.

```shell
pip install -r requirements.txt
```

## Run Federated Learning with MONAI and Flower

Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:

```shell
python3 server.py
```

Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminal windows and run the following commands. Clients will train a [DenseNet121](https://docs.monai.io/en/stable/networks.html#densenet121) from MONAI. If a GPU is present in your system, clients will use it.

Start client 1 in the first terminal:

```shell
python3 client.py --partition-id 0
```

Start client 2 in the second terminal:

```shell
python3 client.py --partition-id 1
```

You will see that the federated training is starting. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-monai) for a detailed explanation.
61 changes: 61 additions & 0 deletions examples/quickstart-monai/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import argparse
import warnings
from collections import OrderedDict

import torch
from data import load_data
from model import test, train
from monai.networks.nets.densenet import DenseNet121

import flwr as fl

warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
def __init__(self, net, trainloader, testloader, device):
self.net = net
self.trainloader = trainloader
self.testloader = testloader
self.device = device

def get_parameters(self, config):
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

def set_parameters(self, parameters):
params_dict = zip(self.net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
self.net.load_state_dict(state_dict, strict=True)

def fit(self, parameters, config):
self.set_parameters(parameters)
train(self.net, self.trainloader, epoch_num=1, device=self.device)
return self.get_parameters(config={}), len(self.trainloader), {}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, accuracy = test(self.net, self.testloader, self.device)
return loss, len(self.testloader), {"accuracy": accuracy}


if __name__ == "__main__":
total_partitions = 10
parser = argparse.ArgumentParser()
parser.add_argument(
"--partition-id", type=int, choices=range(total_partitions), required=True
)
args = parser.parse_args()

# Load model and data (simple CNN, CIFAR-10)
trainloader, _, testloader, num_class = load_data(
total_partitions, args.partition_id
)
net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_class).to(DEVICE)

# Start Flower client
fl.client.start_numpy_client(
server_address="127.0.0.1:8080",
client=FlowerClient(net, trainloader, testloader, DEVICE),
)
158 changes: 158 additions & 0 deletions examples/quickstart-monai/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import os
import tarfile
from urllib import request

import numpy as np
from monai.data import DataLoader, Dataset
from monai.transforms import (
Compose,
EnsureChannelFirst,
LoadImage,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
ToTensor,
)


def _partition(files_list, labels_list, num_shards, index):
total_size = len(files_list)
assert total_size == len(
labels_list
), f"List of datapoints and labels must be of the same length"
shard_size = total_size // num_shards

# Calculate start and end indices for the shard
start_idx = index * shard_size
if index == num_shards - 1:
# Last shard takes the remainder
end_idx = total_size
else:
end_idx = start_idx + shard_size

# Create a subset for the shard
files = files_list[start_idx:end_idx]
labels = labels_list[start_idx:end_idx]
return files, labels


def load_data(num_shards, index):
image_file_list, image_label_list, _, num_class = _download_data()

# Get partition given index
files_list, labels_list = _partition(
image_file_list, image_label_list, num_shards, index
)

trainX, trainY, valX, valY, testX, testY = _split_data(
files_list, labels_list, len(files_list)
)
train_transforms, val_transforms = _get_transforms()

train_ds = MedNISTDataset(trainX, trainY, train_transforms)
train_loader = DataLoader(train_ds, batch_size=300, shuffle=True)

val_ds = MedNISTDataset(valX, valY, val_transforms)
val_loader = DataLoader(val_ds, batch_size=300)

test_ds = MedNISTDataset(testX, testY, val_transforms)
test_loader = DataLoader(test_ds, batch_size=300)

return train_loader, val_loader, test_loader, num_class


class MedNISTDataset(Dataset):
def __init__(self, image_files, labels, transforms):
self.image_files = image_files
self.labels = labels
self.transforms = transforms

def __len__(self):
return len(self.image_files)

def __getitem__(self, index):
return self.transforms(self.image_files[index]), self.labels[index]


def _download_data():
data_dir = "./MedNIST/"
_download_and_extract(
"https://dl.dropboxusercontent.com/s/5wwskxctvcxiuea/MedNIST.tar.gz",
os.path.join(data_dir),
)

class_names = sorted(
[x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x))]
)
num_class = len(class_names)
image_files = [
[
os.path.join(data_dir, class_name, x)
for x in os.listdir(os.path.join(data_dir, class_name))
]
for class_name in class_names
]
image_file_list = []
image_label_list = []
for i, class_name in enumerate(class_names):
image_file_list.extend(image_files[i])
image_label_list.extend([i] * len(image_files[i]))
num_total = len(image_label_list)
return image_file_list, image_label_list, num_total, num_class


def _split_data(image_file_list, image_label_list, num_total):
valid_frac, test_frac = 0.1, 0.1
trainX, trainY = [], []
valX, valY = [], []
testX, testY = [], []

for i in range(num_total):
rann = np.random.random()
if rann < valid_frac:
valX.append(image_file_list[i])
valY.append(image_label_list[i])
elif rann < test_frac + valid_frac:
testX.append(image_file_list[i])
testY.append(image_label_list[i])
else:
trainX.append(image_file_list[i])
trainY.append(image_label_list[i])

return trainX, trainY, valX, valY, testX, testY


def _get_transforms():
train_transforms = Compose(
[
LoadImage(image_only=True),
EnsureChannelFirst(),
ScaleIntensity(),
RandRotate(range_x=15, prob=0.5, keep_size=True),
RandFlip(spatial_axis=0, prob=0.5),
RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5, keep_size=True),
ToTensor(),
]
)

val_transforms = Compose(
[LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity(), ToTensor()]
)

return train_transforms, val_transforms


def _download_and_extract(url, dest_folder):
if not os.path.isdir(dest_folder):
# Download the tar.gz file
tar_gz_filename = url.split("/")[-1]
if not os.path.isfile(tar_gz_filename):
with request.urlopen(url) as response, open(
tar_gz_filename, "wb"
) as out_file:
out_file.write(response.read())

# Extract the tar.gz file
with tarfile.open(tar_gz_filename, "r:gz") as tar_ref:
tar_ref.extractall()
33 changes: 33 additions & 0 deletions examples/quickstart-monai/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch


def train(model, train_loader, epoch_num, device):
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
for _ in range(epoch_num):
model.train()
for inputs, labels in train_loader:
optimizer.zero_grad()
loss_function(model(inputs.to(device)), labels.to(device)).backward()
optimizer.step()


def test(model, test_loader, device):
model.eval()
loss = 0.0
y_true = list()
y_pred = list()
loss_function = torch.nn.CrossEntropyLoss()
with torch.no_grad():
for test_images, test_labels in test_loader:
out = model(test_images.to(device))
test_labels = test_labels.to(device)
loss += loss_function(out, test_labels).item()
pred = out.argmax(dim=1)
for i in range(len(pred)):
y_true.append(test_labels[i].item())
y_pred.append(pred[i].item())
accuracy = sum([1 if t == p else 0 for t, p in zip(y_true, y_pred)]) / len(
test_loader.dataset
)
return loss, accuracy
19 changes: 19 additions & 0 deletions examples/quickstart-monai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "quickstart-monai"
version = "0.1.0"
description = "MONAI Federated Learning Quickstart with Flower"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
torch = "1.13.1"
tqdm = "4.65.0"
scikit-learn = "1.3.1"
monai = { version = "1.3.0", extras=["gdown", "nibabel", "tqdm", "itk"] }
numpy = "1.24.4"
pillow = "10.2.0"
7 changes: 7 additions & 0 deletions examples/quickstart-monai/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
flwr>=1.0, <2.0
torch==1.13.1
tqdm==4.65.0
scikit-learn==1.3.1
monai[gdown,nibabel,tqdm,itk]==1.3.0
numpy==1.24.4
pillow==10.2.0
Loading