Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
KarhouTam committed Oct 10, 2024
1 parent 11f4d72 commit 96222d6
Show file tree
Hide file tree
Showing 7 changed files with 421 additions and 0 deletions.
160 changes: 160 additions & 0 deletions baselines/fedbabu/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
20 changes: 20 additions & 0 deletions baselines/fedbabu/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# fedbabu: A Flower / PyTorch app

## Install dependencies and project

```bash
pip install -e .
```

## Run with the Simulation Engine

In the `fedbabu` directory, use `flwr run` to run a local simulation:

```bash
flwr run .
```

## Run with the Deployment Engine

> \[!NOTE\]
> An update to this example will show how to run this Flower application with the Deployment Engine and TLS certificates, or with Docker.
1 change: 1 addition & 0 deletions baselines/fedbabu/fedbabu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""fedbabu: A Flower / PyTorch app."""
58 changes: 58 additions & 0 deletions baselines/fedbabu/fedbabu/client_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""fedbabu: A Flower / PyTorch app."""

import torch
from flwr.client import NumPyClient, ClientApp
from flwr.common import Context

from fedbabu.task import (
Net,
load_data,
get_weights,
set_weights,
train,
test,
)


# Define Flower Client and client_fn
class FlowerClient(NumPyClient):
def __init__(self, net, trainloader, valloader, local_epochs):
self.net = net
self.trainloader = trainloader
self.valloader = valloader
self.local_epochs = local_epochs
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.net.to(self.device)

def fit(self, parameters, config):
set_weights(self.net, parameters)
train_loss = train(
self.net,
self.trainloader,
self.local_epochs,
self.device,
)
return get_weights(self.net), len(self.trainloader.dataset), {"train_loss": train_loss}

def evaluate(self, parameters, config):
set_weights(self.net, parameters)
loss, accuracy = test(self.net, self.valloader, self.device)
return loss, len(self.valloader.dataset), {"accuracy": accuracy}


def client_fn(context: Context):
# Load model and data
net = Net()
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
trainloader, valloader = load_data(partition_id, num_partitions)
local_epochs = context.run_config["local-epochs"]

# Return Client instance
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()


# Flower ClientApp
app = ClientApp(
client_fn,
)
31 changes: 31 additions & 0 deletions baselines/fedbabu/fedbabu/server_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""fedbabu: A Flower / PyTorch app."""

from flwr.common import Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from fedbabu.task import Net, get_weights


def server_fn(context: Context):
# Read from config
num_rounds = context.run_config["num-server-rounds"]
fraction_fit = context.run_config["fraction-fit"]

# Initialize model parameters
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)

# Define strategy
strategy = FedAvg(
fraction_fit=fraction_fit,
fraction_evaluate=1.0,
min_available_clients=2,
initial_parameters=parameters,
)
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)

# Create ServerApp
app = ServerApp(server_fn=server_fn)
115 changes: 115 additions & 0 deletions baselines/fedbabu/fedbabu/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""fedbabu: A Flower / PyTorch app."""

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, ToTensor
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner
from flwr_datasets.preprocessor import Merger


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)


fds = None # Cache FederatedDataset


def load_data(partition_id: int, num_partitions: int):
"""Load partition CIFAR10 data."""
# Only initialize `FederatedDataset` once
global fds
if fds is None:
partitioner = DirichletPartitioner(
num_partitions=num_partitions, partition_by="label", alpha=0.5
)
fds = FederatedDataset(
dataset="uoft-cs/cifar10",
partitioners={"train": partitioner},
preprocessor=Merger({"train": ("train", "test")}),
)
partition = fds.load_partition(partition_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
pytorch_transforms = Compose(
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
return batch

partition_train_test = partition_train_test.with_transform(apply_transforms)
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
testloader = DataLoader(partition_train_test["test"], batch_size=32)
return trainloader, testloader


def train(net, trainloader, epochs, device):
"""Train the model on the training set."""
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
net.train()
running_loss = 0.0
for _ in range(epochs):
for batch in trainloader:
images = batch["img"]
labels = batch["label"]
optimizer.zero_grad()
loss = criterion(net(images.to(device)), labels.to(device))
loss.backward()
optimizer.step()
running_loss += loss.item()

avg_trainloss = running_loss / len(trainloader)
return avg_trainloss


def test(net, testloader, device):
"""Validate the model on the test set."""
net.to(device)
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for batch in testloader:
images = batch["img"].to(device)
labels = batch["label"].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
loss = loss / len(testloader)
return loss, accuracy


def get_weights(net):
return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_weights(net, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
Loading

0 comments on commit 96222d6

Please sign in to comment.