Commit 157938c ("update")
KarhouTam committed Oct 11, 2024 · 1 parent 96222d6
Showing 3 changed files with 125 additions and 35 deletions.
30 changes: 17 additions & 13 deletions baselines/fedbabu/fedbabu/client_app.py
@@ -5,7 +5,7 @@
from flwr.common import Context

from fedbabu.task import (
-    Net,
+    MobileNetCifar,
load_data,
get_weights,
set_weights,
@@ -26,33 +26,37 @@ def __init__(self, net, trainloader, valloader, local_epochs):

def fit(self, parameters, config):
set_weights(self.net, parameters)
-        train_loss = train(
-            self.net,
-            self.trainloader,
-            self.local_epochs,
-            self.device,
+        train_loss = train(self.net, self.trainloader, self.local_epochs, self.device)
+        return (
+            get_weights(self.net),
+            len(self.trainloader.dataset),
+            {"train_loss": train_loss},
        )
-        return get_weights(self.net), len(self.trainloader.dataset), {"train_loss": train_loss}

def evaluate(self, parameters, config):
set_weights(self.net, parameters)
-        loss, accuracy = test(self.net, self.valloader, self.device)
+        loss, accuracy = test(
+            self.net,
+            self.valloader,
+            self.trainloader,
+            self.device,
+            config["finetune-epochs"],
+        )
return loss, len(self.valloader.dataset), {"accuracy": accuracy}


def client_fn(context: Context):
# Load model and data
-    net = Net()
+    net = MobileNetCifar()
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
-    trainloader, valloader = load_data(partition_id, num_partitions)
+    alpha = context.run_config["alpha"]
+    trainloader, valloader = load_data(partition_id, num_partitions, alpha)
local_epochs = context.run_config["local-epochs"]

# Return Client instance
return FlowerClient(net, trainloader, valloader, local_epochs).to_client()


# Flower ClientApp
-app = ClientApp(
-    client_fn,
-)
+app = ClientApp(client_fn)
128 changes: 106 additions & 22 deletions baselines/fedbabu/fedbabu/task.py
@@ -11,38 +11,95 @@
from flwr_datasets.partitioner import DirichletPartitioner
from flwr_datasets.preprocessor import Merger

+'''
+MobileNet in PyTorch.
+See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
+for more details.
+'''


+class Block(nn.Module):
+    '''Depthwise conv + Pointwise conv'''
+
+    def __init__(self, in_planes, out_planes, stride=1):
+        super(Block, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_planes,
+            in_planes,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            groups=in_planes,
+            bias=False,
+        )
+        self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=False)
+        self.conv2 = nn.Conv2d(
+            in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(out_planes, track_running_stats=False)

-class Net(nn.Module):
-    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(self.bn2(self.conv2(out)))
+        return out


+class MobileNetCifar(nn.Module):
+    # (128,2) means conv planes=128, conv stride=2, by default conv stride=1
+    cfg = [
+        64,
+        (128, 2),
+        128,
+        (256, 2),
+        256,
+        (512, 2),
+        512,
+        512,
+        512,
+        512,
+        512,
+        (1024, 2),
+        1024,
+    ]
+
+    def __init__(self, num_classes=10):
+        super(MobileNetCifar, self).__init__()
+        self.feature_extractor = nn.Sequential(
+            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(32, track_running_stats=False),
+            nn.ReLU(inplace=True),
+            *self._make_layers(in_planes=32),
+            nn.AvgPool2d(2),
+            nn.Flatten(),
+        )
+        self.classifier = nn.Linear(1024, num_classes)

-    def __init__(self):
-        super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(3, 6, 5)
-        self.pool = nn.MaxPool2d(2, 2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
-        self.fc1 = nn.Linear(16 * 5 * 5, 120)
-        self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 10)
+    def _make_layers(self, in_planes):
+        layers = []
+        for x in self.cfg:
+            out_planes = x if isinstance(x, int) else x[0]
+            stride = 1 if isinstance(x, int) else x[1]
+            layers.append(Block(in_planes, out_planes, stride))
+            in_planes = out_planes
+        return layers

def forward(self, x):
-        x = self.pool(F.relu(self.conv1(x)))
-        x = self.pool(F.relu(self.conv2(x)))
-        x = x.view(-1, 16 * 5 * 5)
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        return self.fc3(x)
+        return self.classifier(self.feature_extractor(x))
+
+    def extract_features(self, x):
+        return self.feature_extractor(x)
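For reference, a minimal shape check of the new architecture (not part of the commit; it assumes fedbabu.task is importable): with the stride pattern in cfg, a 32x32 CIFAR input is downsampled to a 2x2 map with 1024 channels, which AvgPool2d(2) and Flatten reduce to a 1024-dim feature vector.

import torch

from fedbabu.task import MobileNetCifar

net = MobileNetCifar(num_classes=10)
x = torch.randn(4, 3, 32, 32)        # a CIFAR-10-sized batch
feats = net.extract_features(x)      # torch.Size([4, 1024])
logits = net(x)                      # torch.Size([4, 10])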


fds = None # Cache FederatedDataset


-def load_data(partition_id: int, num_partitions: int):
+def load_data(partition_id: int, num_partitions: int, alpha: float):
"""Load partition CIFAR10 data."""
# Only initialize `FederatedDataset` once
global fds
if fds is None:
partitioner = DirichletPartitioner(
-            num_partitions=num_partitions, partition_by="label", alpha=0.5
+            num_partitions=num_partitions, partition_by="label", alpha=alpha
)
fds = FederatedDataset(
dataset="uoft-cs/cifar10",
@@ -67,12 +124,21 @@ def apply_transforms(batch):
return trainloader, testloader
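The hard-coded alpha=0.5 is now a parameter (defaulting to 0.1 via pyproject.toml below); smaller Dirichlet alpha yields more heterogeneous, label-skewed client partitions. A standalone sketch of the partitioning step, where num_partitions=10 is an assumed value for illustration:

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner

partitioner = DirichletPartitioner(num_partitions=10, partition_by="label", alpha=0.1)
fds = FederatedDataset(dataset="uoft-cs/cifar10", partitioners={"train": partitioner})
partition = fds.load_partition(0)
print(sorted(partition["label"])[:20])  # heavily skewed toward a few classes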


-def train(net, trainloader, epochs, device):
+def train(net: MobileNetCifar, trainloader, epochs, device):
"""Train the model on the training set."""
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
-    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
+    optimizer = torch.optim.SGD(
+        [
+            {"params": net.feature_extractor.parameters()},
+            {"params": net.classifier.parameters(), "lr": 0},
+        ],
+        lr=0.1,
+        momentum=0.9,
+    )
    net.train()
+    # FedBABU does not update the classifier weights while training.
+    net.classifier.requires_grad_(False)
running_loss = 0.0
for _ in range(epochs):
for batch in trainloader:
@@ -88,23 +154,41 @@ def train(net, trainloader, epochs, device):
return avg_trainloss
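The zero learning rate on the classifier parameter group, combined with requires_grad_(False), keeps the classifier head fixed during local training, which is the defining property of FedBABU: only the feature extractor is updated locally. A quick sanity check, as a sketch; the toy DataLoader below is an assumption that mimics the {"img", "label"} batch format used in this module:

import copy

import torch
from torch.utils.data import DataLoader

from fedbabu.task import MobileNetCifar, train

samples = [{"img": torch.randn(3, 32, 32), "label": i % 10} for i in range(32)]
loader = DataLoader(samples, batch_size=8)

net = MobileNetCifar()
head_before = copy.deepcopy(net.classifier.weight)
train(net, loader, epochs=1, device="cpu")
assert torch.equal(net.classifier.weight, head_before)  # head untouched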


-def test(net, testloader, device):
+def test(net: MobileNetCifar, testloader, trainloader, device, finetune_epochs: int):
    """Validate the model on the test set."""
+    finetune(net, trainloader, finetune_epochs, device)
net.to(device)
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
net.eval()
with torch.no_grad():
for batch in testloader:
images = batch["img"].to(device)
labels = batch["label"].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
-            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+            correct += (torch.argmax(outputs, 1) == labels).sum().item()
accuracy = correct / len(testloader.dataset)
loss = loss / len(testloader)
return loss, accuracy


+def finetune(net: MobileNetCifar, trainloader, epochs, device):
+    """Finetune the full model on the client's training set."""
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss().to(device)
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
+    net.train()
+    # Re-enable the classifier gradients that train() disabled, so that
+    # personalization updates the whole network.
+    net.classifier.requires_grad_(True)
+    for _ in range(epochs):
+        for batch in trainloader:
+            images = batch["img"]
+            labels = batch["label"]
+            optimizer.zero_grad()
+            loss = criterion(net(images.to(device)), labels.to(device))
+            loss.backward()
+            optimizer.step()
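Evaluation now follows FedBABU's personalization protocol: each client briefly fine-tunes the whole model on its own training data before measuring test accuracy. Continuing the toy setup from the sketch above, with the same data standing in for both splits purely for illustration:

from fedbabu.task import test

loss, accuracy = test(net, loader, loader, "cpu", finetune_epochs=1)
print(loss, accuracy)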


def get_weights(net):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

2 changes: 2 additions & 0 deletions baselines/fedbabu/pyproject.toml
@@ -28,6 +28,8 @@ clientapp = "fedbabu.client_app:app"
num-server-rounds = 3
fraction-fit = 0.5
local-epochs = 1
+finetune-epochs = 1
+alpha = 0.1

[tool.flwr.federations]
default = "local-simulation"
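Of the new keys, alpha (like local-epochs) is read from context.run_config in client_fn, while finetune-epochs reaches FlowerClient.evaluate through the config dict, which the server-side strategy must populate. The server app is not part of this diff, but the wiring could look roughly like the following sketch, assuming a FedAvg-style strategy:

from flwr.server.strategy import FedAvg

def build_strategy(run_config):
    # Forward the run-config value into every client's evaluate() config,
    # where client_app.py reads config["finetune-epochs"].
    def evaluate_config(server_round: int):
        return {"finetune-epochs": run_config["finetune-epochs"]}

    return FedAvg(
        fraction_fit=run_config["fraction-fit"],
        on_evaluate_config_fn=evaluate_config,
    )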
