
Commit

Minor changes
Committed Jan 27, 2024
1 parent 4bc609c commit 63a1ebf
Showing 15 changed files with 421 additions and 252 deletions.
18 changes: 14 additions & 4 deletions baselines/fedpara/README.md
@@ -2,7 +2,7 @@
title: "FedPara: Low-rank Hadamard Product for Communication-Efficient Federated Learning"
url: https://openreview.net/forum?id=d71n4ftoCBy
labels: [image classification, personalization, low-rank training, tensor decomposition]
dataset: [CIFAR-10, CIFAR-100, FEMNIST]
dataset: [CIFAR-10, CIFAR-100, MNIST]
---

# FedPara: Low-rank Hadamard Product for Communication-Efficient Federated Learning
@@ -29,10 +29,10 @@ page: https://github.com/South-hw/FedPara_ICLR22

**What’s implemented:** The code in this directory replicates the experiments in the FedPara paper, implementing the low-rank scheme for the convolution module.

Specifically, it replicates the results for CIFAR-10 and CIFAR-100 in Figure 3 and the results for Feminist in Figure 5(a).
Specifically, it replicates the results for CIFAR-10 and CIFAR-100 in Figure 3 and the results for MNIST in Figure 5(c).
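
For context, the low-rank scheme referenced above re-parameterizes each convolution kernel as a Hadamard product of two rank-`r` factorizations, W = (X1·Y1ᵀ) ⊙ (X2·Y2ᵀ), so the number of trained (and transmitted) parameters grows with r while the achievable rank grows toward r². Below is a minimal, hypothetical PyTorch sketch of that idea; the class and parameter names are assumptions, not the repository's actual low-rank convolution module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LowRankHadamardConv2d(nn.Module):
    """Convolution whose weight is a Hadamard product of two rank-r factorizations (illustrative sketch)."""

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int, rank: int):
        super().__init__()
        fan_in = in_ch * kernel_size * kernel_size
        self.weight_shape = (out_ch, in_ch, kernel_size, kernel_size)
        self.padding = kernel_size // 2
        # Two rank-r factor pairs; the Hadamard product of the two rank-r matrices can
        # reach rank r**2 while only 2 * r * (out_ch + fan_in) values are trained and sent.
        self.x1 = nn.Parameter(torch.empty(out_ch, rank))
        self.y1 = nn.Parameter(torch.empty(fan_in, rank))
        self.x2 = nn.Parameter(torch.empty(out_ch, rank))
        self.y2 = nn.Parameter(torch.empty(fan_in, rank))
        for p in (self.x1, self.y1, self.x2, self.y2):
            nn.init.xavier_normal_(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # W = (X1 Y1^T) * (X2 Y2^T), reshaped back into a conv kernel.
        weight = ((self.x1 @ self.y1.T) * (self.x2 @ self.y2.T)).view(self.weight_shape)
        return F.conv2d(x, weight, padding=self.padding)
```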


**Datasets:** CIFAR-10, CIFAR-100, FEMNIST from PyTorch's Torchvision
**Datasets:** CIFAR-10, CIFAR-100, MNIST from PyTorch's Torchvision

**Hardware Setup:** The experiments have been conducted on our server with the following specs:

@@ -62,7 +62,7 @@ On a machine with RTX 3090Ti (24GB VRAM) it takes approximately 1h to run each C

**Training Hyperparameters:**

| | Cifar10 IID | Cifar10 Non-IID | Cifar100 IID | Cifar100 Non-IID | FEMNIST |
| | Cifar10 IID | Cifar10 Non-IID | Cifar100 IID | Cifar100 Non-IID | MNIST |
|---|-------|-------|------|-------|----------|
| Fraction of client (K) | 16 | 16 | 8 | 8 | 10 |
| Total rounds (T) | 200 | 200 | 400 | 400 | 100 |
@@ -143,6 +143,12 @@ python -m fedpara.main --config-name cifar100 --multirun model.conv_type=standar
python -m fedpara.main --multirun model.conv_type=standard,lowrank num_epochs=10 dataset_config.partition=iid
# To run fedpara for iid CIFAR-100 on vgg16 for lowrank and original schemes
python -m fedpara.main --config-name cifar100 --multirun model.conv_type=standard,lowrank num_epochs=10 dataset_config.partition=iid
# To run fedavg for non-iid MNIST on FC
python -m fedpara.main --config-name mnist_fedavg
# To run fedper for non-iid MNIST on FC
python -m fedpara.main --config-name mnist_fedper
# To run pfedpara for non-iid MNIST on FC
python -m fedpara.main --config-name mnist_pfedpara
```

#### Communication Cost:
@@ -163,3 +169,7 @@ Communication costs are measured as described in the paper:
| IID | Non-IID |
|:----:|:----:|
|![CIFAR10 iid](_static/Cifar10_iid.jpeg) | ![CIFAR10 non-iid](_static/Cifar10_noniid.jpeg) |
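
As a rough, assumed accounting of that communication-cost measure (not the repository's measurement code), the totals behind plots like the ones above can be estimated from the transmitted parameter count per round:

```python
def communication_cost_gb(num_params: int, clients_per_round: int, num_rounds: int) -> float:
    """Estimate total traffic in GB, assuming float32 parameters and that every
    sampled client downloads and uploads the full (possibly low-rank) model each round."""
    bytes_per_round = 2 * clients_per_round * num_params * 4  # download + upload
    return bytes_per_round * num_rounds / 1e9


# Hypothetical example: a ~15M-parameter VGG-16, 16 clients per round, 200 rounds.
print(f"{communication_cost_gb(15_000_000, 16, 200):.1f} GB")
```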

### Non-IID MNIST (FedAvg vs FedPer vs pFedPara)
**Important note: Only the federated averaging (FedAvg) implementation replicates the results outlined in the paper; convergence issues were encountered with the pFedPara and FedPer methods.**
![Personalization algorithms](_static/non-iid_mnist_personalization.png)
108 changes: 57 additions & 51 deletions baselines/fedpara/fedpara/client.py
@@ -1,16 +1,20 @@
"""Client for FedPara."""

import copy
import os
from collections import OrderedDict
from typing import Callable, Dict, List, Tuple, Optional
import copy,os
from typing import Callable, Dict, List, Optional, Tuple

import flwr as fl
import torch
from flwr.common import NDArrays, Scalar
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from fedpara.models import train,test
import logging

from fedpara.models import test, train
from fedpara.utils import get_keys_state_dict


class FlowerClient(fl.client.NumPyClient):
"""Standard Flower client for CNN training."""
@@ -59,8 +63,10 @@ def fit(
{},
)


class PFlowerClient(fl.client.NumPyClient):
"""personalized Flower Client"""
"""Personalized Flower Client."""

def __init__(
self,
cid: int,
@@ -71,8 +77,7 @@ def __init__(
num_epochs: int,
state_path: str,
algorithm: str,
):

):
self.cid = cid
self.net = net
self.train_loader = train_loader
@@ -82,49 +87,46 @@ def __init__(
self.state_path = state_path
self.algorithm = algorithm

def get_keys_state_dict(self, mode:str="local")->list[str]:
match self.algorithm:
case "fedper":
if mode == "local":
return list(filter(lambda x: 'fc2' in x,self.net.state_dict().keys()))
elif mode == "global":
return list(filter(lambda x: 'fc1' in x,self.net.state_dict().keys()))
case "pfedpara":
if mode == "local":
return list(filter(lambda x: 'w2' in x,self.net.state_dict().keys()))
elif mode == "global":
return list(filter(lambda x: 'w1' in x,self.net.state_dict().keys()))
case _:
raise NotImplementedError(f"algorithm {self.algorithm} not implemented")


def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
"""Return the parameters of the current net."""
model_dict = self.net.state_dict()
#TODO: overwrite the server private parameters
model_dict = self.net.state_dict().copy()
# overwrite the server private parameters
for k in self.private_server_param.keys():
model_dict[k] = self.private_server_param[k]
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
return [val.cpu().numpy() for _, val in model_dict.items()]

def set_parameters(self, parameters: NDArrays) -> None:
def set_parameters(self, parameters: NDArrays, evaluate: bool = False) -> None:
self.private_server_param: Dict[str, torch.Tensor] = {}
params_dict = zip(self.net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
self.private_server_param = {k:state_dict[k] for k in self.get_keys_state_dict(mode="local")}
self.net.load_state_dict(state_dict, strict=True)
server_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
self.private_server_param = {
k: server_dict[k]
for k in get_keys_state_dict(
model=self.net, algorithm=self.algorithm, mode="local"
)
}

if evaluate:
client_dict = self.net.state_dict().copy()
else:
client_dict = copy.deepcopy(server_dict)

if os.path.isfile(self.state_path):
# only overwrite global parameters
with open(self.state_path, 'rb') as f:
model_dict = self.net.state_dict()
state_dict = torch.load(f)
for k in self.get_keys_state_dict(mode="global"):
model_dict[k] = state_dict[k]
with open(self.state_path, "rb") as f:
client_dict = torch.load(f)

for k in get_keys_state_dict(
model=self.net, algorithm=self.algorithm, mode="global"
):
client_dict[k] = server_dict[k]

self.net.load_state_dict(client_dict, strict=False)

def fit(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
"""Train the network on the training set."""
self.set_parameters(parameters)
self.set_parameters(parameters, evaluate=False)
print(f"Client {self.cid} Training...")

train(
@@ -136,55 +138,59 @@ def fit(
epoch=config["curr_round"],
)
if self.state_path is not None:
with open(self.state_path, 'wb') as f:
with open(self.state_path, "wb") as f:
torch.save(self.net.state_dict(), f)

return (
self.get_parameters({}),
len(self.train_loader),
{},
{},
)
def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[float, int, Dict]:

def evaluate(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[float, int, Dict]:
"""Evaluate the network on the test set."""
self.set_parameters(parameters)
self.set_parameters(parameters, evaluate=True)
print(f"Client {self.cid} Evaluating...")
self.net.to(self.device)
loss, accuracy = test(self.net, self.test_loader, device=self.device)
return loss, len(self.test_loader), {"accuracy": accuracy}


def gen_client_fn(
train_loaders: List[DataLoader],
model: DictConfig,
num_epochs: int,
args: Dict,
test_loader: Optional[List[DataLoader]]=None,
state_path: Optional[str]=None,
test_loader: Optional[List[DataLoader]] = None,
state_path: Optional[str] = None,
) -> Callable[[str], fl.client.NumPyClient]:
"""Return a function which creates a new FlowerClient for a given cid."""

def client_fn(cid: str) -> fl.client.NumPyClient:
"""Create a new FlowerClient for a given cid."""
cid = int(cid)
if args['algorithm'].lower() == "pfedpara" or args['algorithm'] == "fedper":
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if args["algorithm"].lower() == "pfedpara" or args["algorithm"] == "fedper":
cl_path = f"{state_path}/client_{cid}.pth"
return PFlowerClient(
cid=cid,
net=instantiate(model).to(args["device"]),
net=instantiate(model).to(device),
train_loader=train_loaders[cid],
test_loader=copy.deepcopy(test_loader),
device=args["device"],
num_epochs=num_epochs,
state_path=cl_path,
algorithm=args['algorithm'].lower(),
algorithm=args["algorithm"].lower(),
device=device,
)
else:
return FlowerClient(
cid=cid,
net=instantiate(model).to(args["device"]),
net=instantiate(model).to(device),
train_loader=train_loaders[cid],
device=args["device"],
num_epochs=num_epochs,
device=device,
)

return client_fn
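
This commit moves the per-algorithm key filtering out of `PFlowerClient` into `fedpara.utils.get_keys_state_dict`. A plausible sketch of that relocated helper, mirroring the logic of the method removed above (the actual implementation in `fedpara/utils.py` may differ):

```python
from typing import List

import torch.nn as nn


def get_keys_state_dict(model: nn.Module, algorithm: str, mode: str = "local") -> List[str]:
    """Return the state_dict keys treated as personal ("local") or shared ("global")."""
    match algorithm:
        case "fedper":
            # FedPer personalizes the last layer (fc2) and shares the first (fc1).
            marker = "fc2" if mode == "local" else "fc1"
        case "pfedpara":
            # pFedPara personalizes the w2 factors and shares the w1 factors.
            marker = "w2" if mode == "local" else "w1"
        case _:
            raise NotImplementedError(f"algorithm {algorithm} not implemented")
    return [k for k in model.state_dict().keys() if marker in k]
```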
8 changes: 4 additions & 4 deletions baselines/fedpara/fedpara/conf/cifar10.yaml
@@ -6,6 +6,7 @@ num_rounds: 200
clients_per_round: 16
num_epochs: 5
batch_size: 64
algorithm: FedPara

server_device: cuda

@@ -22,8 +23,7 @@ dataset_config:
model:
_target_: fedpara.models.VGG
num_classes: ${dataset_config.num_classes}
conv_type: lowrank # lowrank or standard
activation: relu # relu or leaky_relu
param_type: lowrank # lowrank or standard
ratio: 0.1 # lowrank ratio

hyperparams:
@@ -32,12 +32,12 @@ hyperparams:


strategy:
_target_: fedpara.strategy.FedPara
algorithm: FedPara
_target_: fedpara.strategy.FedAvg
fraction_fit: 0.00001
fraction_evaluate: 0.0
min_evaluate_clients: 0
min_fit_clients: ${clients_per_round}
min_available_clients: ${clients_per_round}
accept_failures: false

exp_id: ${algorithm}_${dataset_config.name}_${seed}_${dataset_config.partition}_${dataset_config.alpha}_${model.param_type}_${model.ratio}
8 changes: 4 additions & 4 deletions baselines/fedpara/fedpara/conf/cifar100.yaml
@@ -6,6 +6,7 @@ num_rounds: 400
clients_per_round: 8
num_epochs: 5
batch_size: 64
algorithm: FedPara

server_device: cuda

@@ -22,21 +23,20 @@ dataset_config:
model:
_target_: fedpara.models.VGG
num_classes: ${dataset_config.num_classes}
conv_type: lowrank # lowrank or standard
activation: relu # relu or leaky_relu
param_type: lowrank # lowrank or standard
ratio: 0.4 # lowrank ratio

hyperparams:
eta_l: 0.1
learning_decay: 0.992

strategy:
_target_: fedpara.strategy.FedPara
algorithm: FedPara
_target_: fedpara.strategy.FedAvg
fraction_fit: 0.00001
fraction_evaluate: 0.0
min_evaluate_clients: 0
min_fit_clients: ${clients_per_round}
min_available_clients: ${clients_per_round}
accept_failures: false

exp_id: ${algorithm}_${dataset_config.name}_${seed}_${dataset_config.partition}_${dataset_config.alpha}_${model.param_type}_${model.ratio}
38 changes: 38 additions & 0 deletions baselines/fedpara/fedpara/conf/mnist_fedavg.yaml
@@ -0,0 +1,38 @@
---
seed: 424

num_clients: 100
num_rounds: 100
clients_per_round: 10
num_epochs: 1
batch_size: 10
server_device: cuda
algorithm: fedavg

client_resources:
num_cpus: 2
num_gpus: 0.1

dataset_config:
name: MNIST
num_classes: 10
shard_size: 300

model:
_target_: fedpara.models.FC
num_classes: ${dataset_config.num_classes}
hidden_size: 200

hyperparams:
eta_l: 0.05
learning_decay: 1

strategy:
_target_: fedpara.strategy.FedAvg
fraction_fit: 0.00001
fraction_evaluate: 0
min_evaluate_clients: 0
min_fit_clients: ${clients_per_round}
min_available_clients: ${clients_per_round}

exp_id: ${algorithm}_${dataset_config.name}_${seed}
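
For reference, a hypothetical sketch of the `fedpara.models.FC` module the config above points to — assumed here to be a two-layer MLP over flattened 28×28 MNIST inputs using the `fc1`/`fc2` names that the FedPer key filtering in `client.py` relies on; the repository's implementation (and its pFedPara low-rank variant with `w1`/`w2` factors) may differ:

```python
import torch
import torch.nn as nn


class FC(nn.Module):
    """Two-layer fully connected classifier for 28x28 MNIST images (assumed layout)."""

    def __init__(self, num_classes: int = 10, hidden_size: int = 200):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)      # shared ("global") layer under FedPer
        self.fc2 = nn.Linear(hidden_size, num_classes)  # personal ("local") layer under FedPer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.flatten(x, start_dim=1)
        return self.fc2(torch.relu(self.fc1(x)))
```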