Commit 0db9cad

Minor changes

committed Jan 27, 2024
1 parent 4bc609c commit 0db9cad
Showing 14 changed files with 408 additions and 251 deletions.
8 changes: 4 additions & 4 deletions baselines/fedpara/README.md
@@ -2,7 +2,7 @@
 title: "FedPara: Low-rank Hadamard Product for Communication-Efficient Federated Learning"
 url: https://openreview.net/forum?id=d71n4ftoCBy
 labels: [image classification, personalization, low-rank training, tensor decomposition]
-dataset: [CIFAR-10, CIFAR-100, FEMNIST]
+dataset: [CIFAR-10, CIFAR-100, MNIST]
 ---

 # FedPara: Low-rank Hadamard Product for Communication-Efficient Federated Learning
@@ -29,10 +29,10 @@ page: https://github.com/South-hw/FedPara_ICLR22

 **What’s implemented:** The code in this directory replicates the experiments in the FedPara paper, implementing the low-rank scheme for the convolution module.

-Specifically, it replicates the results for CIFAR-10 and CIFAR-100 in Figure 3 and the results for Feminist in Figure 5(a).
+Specifically, it replicates the results for CIFAR-10 and CIFAR-100 in Figure 3 and the results for MNIST in Figure 5(c).


-**Datasets:** CIFAR-10, CIFAR-100, FEMNIST from PyTorch's Torchvision
+**Datasets:** CIFAR-10, CIFAR-100, MNIST from PyTorch's Torchvision

 **Hardware Setup:** The experiments have been conducted on our server with the following specs:

@@ -62,7 +62,7 @@ On a machine with RTX 3090Ti (24GB VRAM) it takes approximately 1h to run each C

 **Training Hyperparameters:**

-| | Cifar10 IID | Cifar10 Non-IID | Cifar100 IID | Cifar100 Non-IID | FEMNIST |
+| | Cifar10 IID | Cifar10 Non-IID | Cifar100 IID | Cifar100 Non-IID | MNIST |
 |---|-------|-------|------|-------|----------|
 | Fraction of client (K) | 16 | 16 | 8 | 8 | 10 |
 | Total rounds (T) | 200 | 200 | 400 | 400 | 100 |
109 changes: 58 additions & 51 deletions baselines/fedpara/fedpara/client.py
@@ -1,16 +1,20 @@
 """Client for FedPara."""

+import copy
+import os
 from collections import OrderedDict
-from typing import Callable, Dict, List, Tuple, Optional
-import copy,os
+from typing import Callable, Dict, List, Optional, Tuple

 import flwr as fl
 import torch
 from flwr.common import NDArrays, Scalar
 from hydra.utils import instantiate
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
-from fedpara.models import train,test
-import logging
+
+from fedpara.models import test, train
+from fedpara.utils import get_keys_state_dict
+
+
 class FlowerClient(fl.client.NumPyClient):
     """Standard Flower client for CNN training."""
@@ -59,8 +63,10 @@ def fit(
             {},
         )

+
 class PFlowerClient(fl.client.NumPyClient):
-    """personalized Flower Client"""
+    """Personalized Flower Client."""
+
     def __init__(
         self,
         cid: int,
@@ -71,8 +77,8 @@ def __init__(
         num_epochs: int,
         state_path: str,
         algorithm: str,
-    ):
+    ):

         self.cid = cid
         self.net = net
         self.train_loader = train_loader
@@ -82,49 +88,46 @@ def __init__(
         self.state_path = state_path
         self.algorithm = algorithm

-    def get_keys_state_dict(self, mode:str="local")->list[str]:
-        match self.algorithm:
-            case "fedper":
-                if mode == "local":
-                    return list(filter(lambda x: 'fc2' in x,self.net.state_dict().keys()))
-                elif mode == "global":
-                    return list(filter(lambda x: 'fc1' in x,self.net.state_dict().keys()))
-            case "pfedpara":
-                if mode == "local":
-                    return list(filter(lambda x: 'w2' in x,self.net.state_dict().keys()))
-                elif mode == "global":
-                    return list(filter(lambda x: 'w1' in x,self.net.state_dict().keys()))
-            case _:
-                raise NotImplementedError(f"algorithm {self.algorithm} not implemented")
-
-
     def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
         """Return the parameters of the current net."""
-        model_dict = self.net.state_dict()
-        #TODO: overwrite the server private parameters
+        model_dict = self.net.state_dict().copy()
+        # overwrite the server private parameters
         for k in self.private_server_param.keys():
             model_dict[k] = self.private_server_param[k]
-        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
+        return [val.cpu().numpy() for _, val in model_dict.items()]

-    def set_parameters(self, parameters: NDArrays) -> None:
+    def set_parameters(self, parameters: NDArrays, evaluate: bool = False) -> None:
         self.private_server_param: Dict[str, torch.Tensor] = {}
         params_dict = zip(self.net.state_dict().keys(), parameters)
-        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
-        self.private_server_param = {k:state_dict[k] for k in self.get_keys_state_dict(mode="local")}
-        self.net.load_state_dict(state_dict, strict=True)
+        server_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+        self.private_server_param = {
+            k: server_dict[k]
+            for k in get_keys_state_dict(
+                model=self.net, algorithm=self.algorithm, mode="local"
+            )
+        }
+
+        if evaluate:
+            client_dict = self.net.state_dict().copy()
+        else:
+            client_dict = copy.deepcopy(server_dict)

         if os.path.isfile(self.state_path):
-            # only overwrite global parameters
-            with open(self.state_path, 'rb') as f:
-                model_dict = self.net.state_dict()
-                state_dict = torch.load(f)
-                for k in self.get_keys_state_dict(mode="global"):
-                    model_dict[k] = state_dict[k]
+            with open(self.state_path, "rb") as f:
+                client_dict = torch.load(f)
+
+        for k in get_keys_state_dict(
+            model=self.net, algorithm=self.algorithm, mode="global"
+        ):
+            client_dict[k] = server_dict[k]
+
+        self.net.load_state_dict(client_dict, strict=False)

     def fit(
         self, parameters: NDArrays, config: Dict[str, Scalar]
     ) -> Tuple[NDArrays, int, Dict]:
         """Train the network on the training set."""
-        self.set_parameters(parameters)
+        self.set_parameters(parameters, evaluate=False)
         print(f"Client {self.cid} Training...")

         train(
@@ -136,55 +139,59 @@ def fit(
             epoch=config["curr_round"],
         )
         if self.state_path is not None:
-            with open(self.state_path, 'wb') as f:
+            with open(self.state_path, "wb") as f:
                 torch.save(self.net.state_dict(), f)

         return (
             self.get_parameters({}),
             len(self.train_loader),
-            {},
+            {},
         )
-    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[float, int, Dict]:
+
+    def evaluate(
+        self, parameters: NDArrays, config: Dict[str, Scalar]
+    ) -> Tuple[float, int, Dict]:
         """Evaluate the network on the test set."""
-        self.set_parameters(parameters)
+        self.set_parameters(parameters, evaluate=True)
         print(f"Client {self.cid} Evaluating...")
         self.net.to(self.device)
         loss, accuracy = test(self.net, self.test_loader, device=self.device)
         return loss, len(self.test_loader), {"accuracy": accuracy}


 def gen_client_fn(
     train_loaders: List[DataLoader],
     model: DictConfig,
     num_epochs: int,
     args: Dict,
-    test_loader: Optional[List[DataLoader]]=None,
-    state_path: Optional[str]=None,
+    test_loader: Optional[List[DataLoader]] = None,
+    state_path: Optional[str] = None,
 ) -> Callable[[str], fl.client.NumPyClient]:
     """Return a function which creates a new FlowerClient for a given cid."""

     def client_fn(cid: str) -> fl.client.NumPyClient:
         """Create a new FlowerClient for a given cid."""
         cid = int(cid)
-        if args['algorithm'].lower() == "pfedpara" or args['algorithm'] == "fedper":
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        if args["algorithm"].lower() == "pfedpara" or args["algorithm"] == "fedper":
             cl_path = f"{state_path}/client_{cid}.pth"
             return PFlowerClient(
                 cid=cid,
-                net=instantiate(model).to(args["device"]),
+                net=instantiate(model).to(device),
                 train_loader=train_loaders[cid],
                 test_loader=copy.deepcopy(test_loader),
-                device=args["device"],
                 num_epochs=num_epochs,
                 state_path=cl_path,
-                algorithm=args['algorithm'].lower(),
+                algorithm=args["algorithm"].lower(),
+                device=device,
             )
         else:
             return FlowerClient(
                 cid=cid,
-                net=instantiate(model).to(args["device"]),
+                net=instantiate(model).to(device),
                 train_loader=train_loaders[cid],
-                device=args["device"],
                 num_epochs=num_epochs,
+                device=device,
             )

     return client_fn
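The rewritten client imports `get_keys_state_dict` from `fedpara.utils`, a file not included among the changes shown here. A plausible sketch of that helper, reconstructed from the deleted `PFlowerClient.get_keys_state_dict` method (the `model`/`algorithm`/`mode` keyword signature is taken from the call sites in `set_parameters`; this is an assumption, not the committed file):

```python
# Sketch of get_keys_state_dict as it might live in fedpara/utils.py --
# reconstructed from the deleted client method, not the actual commit content.
from typing import List

from torch import nn


def get_keys_state_dict(model: nn.Module, algorithm: str, mode: str = "local") -> List[str]:
    """Return the state-dict keys treated as client-local or globally shared."""
    keys = model.state_dict().keys()
    match algorithm:
        case "fedper":
            # FedPer: fc1 is the shared feature extractor, fc2 the private head.
            pattern = "fc2" if mode == "local" else "fc1"
        case "pfedpara":
            # pFedPara: the w2 low-rank factors stay local, w1 factors are shared.
            pattern = "w2" if mode == "local" else "w1"
        case _:
            raise NotImplementedError(f"algorithm {algorithm} not implemented")
    return [k for k in keys if pattern in k]
```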
8 changes: 4 additions & 4 deletions baselines/fedpara/fedpara/conf/cifar10.yaml
@@ -6,6 +6,7 @@ num_rounds: 200
 clients_per_round: 16
 num_epochs: 5
 batch_size: 64
+algorithm: FedPara

 server_device: cuda

@@ -22,8 +23,7 @@ dataset_config:
 model:
   _target_: fedpara.models.VGG
   num_classes: ${dataset_config.num_classes}
-  conv_type: lowrank # lowrank or standard
-  activation: relu # relu or leaky_relu
+  param_type: lowrank # lowrank or standard
   ratio: 0.1 # lowrank ratio

 hyperparams:
@@ -32,12 +32,12 @@ hyperparams:


 strategy:
-  _target_: fedpara.strategy.FedPara
-  algorithm: FedPara
+  _target_: fedpara.strategy.FedAvg
   fraction_fit: 0.00001
   fraction_evaluate: 0.0
   min_evaluate_clients: 0
   min_fit_clients: ${clients_per_round}
   min_available_clients: ${clients_per_round}
   accept_failures: false

+exp_id: ${algorithm}_${dataset_config.name}_${seed}_${dataset_config.partition}_${dataset_config.alpha}_${model.param_type}_${model.ratio}
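A note on why `algorithm` moves out of the `strategy` node: Hydra's `instantiate` builds the class named by `_target_` and forwards every sibling key as a constructor argument, so a stray `algorithm` key would land in the strategy's `__init__`; hoisting it to the top level lets `exp_id` and the client factory read it instead. A minimal sketch of how this config is consumed (assuming it is composed from `fedpara/conf`; illustrative, not part of the commit):

```python
# Minimal sketch of consuming this config with Hydra (illustrative only).
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="fedpara/conf", version_base=None):
    cfg = compose(config_name="cifar10")

# _target_ selects the class; the remaining keys become keyword arguments.
net = instantiate(cfg.model)          # -> fedpara.models.VGG(num_classes=..., param_type="lowrank", ratio=0.1)
strategy = instantiate(cfg.strategy)  # -> fedpara.strategy.FedAvg(fraction_fit=1e-05, ...)
```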
8 changes: 4 additions & 4 deletions baselines/fedpara/fedpara/conf/cifar100.yaml
@@ -6,6 +6,7 @@ num_rounds: 400
 clients_per_round: 8
 num_epochs: 5
 batch_size: 64
+algorithm: FedPara

 server_device: cuda

@@ -22,21 +23,20 @@ dataset_config:
 model:
   _target_: fedpara.models.VGG
   num_classes: ${dataset_config.num_classes}
-  conv_type: lowrank # lowrank or standard
-  activation: relu # relu or leaky_relu
+  param_type: lowrank # lowrank or standard
   ratio: 0.4 # lowrank ratio

 hyperparams:
   eta_l: 0.1
   learning_decay: 0.992

 strategy:
-  _target_: fedpara.strategy.FedPara
-  algorithm: FedPara
+  _target_: fedpara.strategy.FedAvg
   fraction_fit: 0.00001
   fraction_evaluate: 0.0
   min_evaluate_clients: 0
   min_fit_clients: ${clients_per_round}
   min_available_clients: ${clients_per_round}
   accept_failures: false

+exp_id: ${algorithm}_${dataset_config.name}_${seed}_${dataset_config.partition}_${dataset_config.alpha}_${model.param_type}_${model.ratio}
38 changes: 38 additions & 0 deletions baselines/fedpara/fedpara/conf/mnist_fedavg.yaml
@@ -0,0 +1,38 @@
+---
+seed: 424
+
+num_clients: 100
+num_rounds: 100
+clients_per_round: 10
+num_epochs: 1
+batch_size: 10
+server_device: cuda
+algorithm: fedavg
+
+client_resources:
+  num_cpus: 2
+  num_gpus: 0.1
+
+dataset_config:
+  name: MNIST
+  num_classes: 10
+  shard_size: 300
+
+model:
+  _target_: fedpara.models.FC
+  num_classes: ${dataset_config.num_classes}
+  hidden_size: 200
+
+hyperparams:
+  eta_l: 0.05
+  learning_decay: 1
+
+strategy:
+  _target_: fedpara.strategy.FedAvg
+  fraction_fit: 0.00001
+  fraction_evaluate: 0
+  min_evaluate_clients: 0
+  min_fit_clients: ${clients_per_round}
+  min_available_clients: ${clients_per_round}
+
+exp_id: ${algorithm}_${dataset_config.name}_${seed}
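The MNIST configs instantiate `fedpara.models.FC` with `hidden_size: 200`; the model source is not among the files shown. A plausible sketch, assuming a plain two-layer network whose `fc1`/`fc2` parameter names line up with the key filtering that `fedper` relies on:

```python
# Plausible sketch (assumption, not shown in this diff) of the FC model the
# MNIST configs instantiate; fc1/fc2 names match the fedper key filters.
import torch
import torch.nn.functional as F
from torch import nn


class FC(nn.Module):
    """Fully connected network for 28x28 MNIST images."""

    def __init__(self, num_classes: int = 10, hidden_size: int = 200):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)      # shared body under fedper
        self.fc2 = nn.Linear(hidden_size, num_classes)  # personalized head under fedper

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.flatten(start_dim=1)
        return self.fc2(F.relu(self.fc1(x)))
```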
42 changes: 42 additions & 0 deletions baselines/fedpara/fedpara/conf/mnist_fedper.yaml
@@ -0,0 +1,42 @@
+---
+seed: 17
+
+num_clients: 100
+num_rounds: 100
+clients_per_round: 10
+num_epochs: 1
+batch_size: 10
+state_path: ./client_states/
+server_device: cuda
+client_device: cuda
+
+algorithm: fedper
+# fedavg in future
+
+client_resources:
+  num_cpus: 2
+  num_gpus: 0.1
+
+dataset_config:
+  name: MNIST
+  num_classes: 10
+  shard_size: 300
+
+data_seed: ${seed}
+
+model:
+  _target_: fedpara.models.FC
+  num_classes: ${dataset_config.num_classes}
+  hidden_size: 200
+
+hyperparams:
+  eta_l: 0.05
+  learning_decay: 0
+
+strategy:
+  _target_: fedpara.strategy.FedAvg
+  fraction_fit: 0.00001
+  fraction_evaluate: 0.00001
+  min_evaluate_clients: ${clients_per_round}
+  min_fit_clients: ${clients_per_round}
+  min_available_clients: ${clients_per_round}
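Under the `fedper` config, the practical effect of the key split is that `fc2` never leaves the client (it is checkpointed under `state_path` between rounds) while `fc1` is aggregated by the server. A toy check using the `FC` and `get_keys_state_dict` sketches above (both reconstructions, not code from this commit):

```python
# Toy check of the fedper key split, reusing the sketches above.
net = FC(num_classes=10, hidden_size=200)
print(get_keys_state_dict(model=net, algorithm="fedper", mode="local"))
# ['fc2.weight', 'fc2.bias']  -- personalized head, kept on the client
print(get_keys_state_dict(model=net, algorithm="fedper", mode="global"))
# ['fc1.weight', 'fc1.bias']  -- shared body, aggregated by the server
```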