Commit f9ed22e: formatting and typing fixes

jafermarq committed Jan 23, 2024 (1 parent: 4bc609c)
Showing 8 changed files with 199 additions and 105 deletions.
88 changes: 53 additions & 35 deletions baselines/fedpara/fedpara/client.py
@@ -1,16 +1,19 @@
"""Client for FedPara."""

import copy
import os
from collections import OrderedDict
from typing import Callable, Dict, List, Tuple, Optional
import copy,os
from typing import Callable, Dict, List, Optional, Tuple

import flwr as fl
import torch
from flwr.common import NDArrays, Scalar
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from fedpara.models import train,test
import logging

from fedpara.models import test, train


class FlowerClient(fl.client.NumPyClient):
"""Standard Flower client for CNN training."""
@@ -34,6 +37,7 @@ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
         return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

     def set_parameters(self, parameters: NDArrays) -> None:
+        """Set the parameters of the client model."""
         params_dict = zip(self.net.state_dict().keys(), parameters)
         state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
         self.net.load_state_dict(state_dict, strict=True)
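A note on the convention these two methods implement: Flower ships model weights as a plain list of NumPy arrays (NDArrays), and set_parameters rebuilds the state_dict purely by key order, so server and client must enumerate parameters identically. A minimal round-trip sketch with a toy model (not part of this commit):

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Toy stand-in for the FedPara client model; any nn.Module works the same way.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# get_parameters: state_dict tensors -> list of NumPy arrays (the wire format).
ndarrays = [val.cpu().numpy() for _, val in net.state_dict().items()]

# set_parameters: zip pairs arrays with keys by position, so the arrays must
# arrive in exactly the order state_dict() produced them.
params_dict = zip(net.state_dict().keys(), ndarrays)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
```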
@@ -59,8 +63,10 @@ def fit(
             {},
         )

+
 class PFlowerClient(fl.client.NumPyClient):
-    """personalized Flower Client"""
+    """Personalized Flower Client."""
+
     def __init__(
         self,
         cid: int,
@@ -71,8 +77,7 @@ def __init__(
         num_epochs: int,
         state_path: str,
         algorithm: str,
-    ):
-
+    ):
         self.cid = cid
         self.net = net
         self.train_loader = train_loader
@@ -82,42 +87,52 @@ def __init__(
         self.state_path = state_path
         self.algorithm = algorithm

-    def get_keys_state_dict(self, mode:str="local")->list[str]:
+    def _get_keys_state_dict(self, mode: str = "local") -> list[str]:
         match self.algorithm:
             case "fedper":
                 if mode == "local":
-                    return list(filter(lambda x: 'fc2' in x,self.net.state_dict().keys()))
+                    return list(
+                        filter(lambda x: "fc2" in x, self.net.state_dict().keys())
+                    )
                 elif mode == "global":
-                    return list(filter(lambda x: 'fc1' in x,self.net.state_dict().keys()))
+                    return list(
+                        filter(lambda x: "fc1" in x, self.net.state_dict().keys())
+                    )
             case "pfedpara":
                 if mode == "local":
-                    return list(filter(lambda x: 'w2' in x,self.net.state_dict().keys()))
+                    return list(
+                        filter(lambda x: "w2" in x, self.net.state_dict().keys())
+                    )
                 elif mode == "global":
-                    return list(filter(lambda x: 'w1' in x,self.net.state_dict().keys()))
+                    return list(
+                        filter(lambda x: "w1" in x, self.net.state_dict().keys())
+                    )
             case _:
                 raise NotImplementedError(f"algorithm {self.algorithm} not implemented")
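The renamed _get_keys_state_dict drives the personalization split: FedPer keeps the classifier head (fc2) local and shares fc1, while pFedPara keeps the w2 low-rank factors local and shares w1. The same selection logic in isolation, on hypothetical parameter names (the repo's real layer names come from its model definitions, which are not part of this diff):

```python
def get_keys(keys: list[str], algorithm: str, mode: str = "local") -> list[str]:
    """Select personalized ("local") or shared ("global") parameter names."""
    marker = {
        ("fedper", "local"): "fc2",
        ("fedper", "global"): "fc1",
        ("pfedpara", "local"): "w2",
        ("pfedpara", "global"): "w1",
    }[(algorithm, mode)]
    return [k for k in keys if marker in k]


# Hypothetical parameter names for a pFedPara-style model.
keys = ["conv1.w1", "conv1.w2", "fc1.weight", "fc2.weight"]
assert get_keys(keys, "pfedpara", "local") == ["conv1.w2"]
assert get_keys(keys, "pfedpara", "global") == ["conv1.w1"]
```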



     def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
         """Return the parameters of the current net."""
         model_dict = self.net.state_dict()
-        #TODO: overwrite the server private parameters
+        # TODO: overwrite the server private parameters
         for k in self.private_server_param.keys():
             model_dict[k] = self.private_server_param[k]
         return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

     def set_parameters(self, parameters: NDArrays) -> None:
+        """Apply parameters to model."""
         self.private_server_param: Dict[str, torch.Tensor] = {}
         params_dict = zip(self.net.state_dict().keys(), parameters)
-        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
-        self.private_server_param = {k:state_dict[k] for k in self.get_keys_state_dict(mode="local")}
+        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+        self.private_server_param = {
+            k: state_dict[k] for k in self._get_keys_state_dict(mode="local")
+        }
         self.net.load_state_dict(state_dict, strict=True)
         if os.path.isfile(self.state_path):
             # only overwrite global parameters
-            with open(self.state_path, 'rb') as f:
+            with open(self.state_path, "rb") as f:
                 model_dict = self.net.state_dict()
                 state_dict = torch.load(f)
-                for k in self.get_keys_state_dict(mode="global"):
+                for k in self._get_keys_state_dict(mode="global"):
                     model_dict[k] = state_dict[k]
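Two personalization steps meet in this pair of methods: set_parameters stashes the server's copy of the keys marked "local" (private_server_param) and reloads this client's persisted weights from state_path, while get_parameters is meant to splice the stash back into the outgoing payload so the server aggregate never absorbs personal layers. As displayed, the splice looks unfinished (hence the TODO: get_parameters writes into model_dict but returns values from self.net.state_dict(), and the reload above never calls load_state_dict on the patched model_dict). A sketch of the intended echo pattern, with hypothetical layer names:

```python
import torch.nn as nn

net = nn.Linear(2, 2)  # stand-in; state_dict keys are "weight" and "bias"
incoming = {k: v.clone() for k, v in net.state_dict().items()}

# Stash the server's copy of the "local" keys on arrival...
private = {"bias": incoming["bias"]}

# ...train locally, then splice the stash back over the outgoing weights so
# the server receives its own values for the personal parameters.
outgoing = net.state_dict()
outgoing.update(private)
ndarrays = [v.cpu().numpy() for v in outgoing.values()]
```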

def fit(
@@ -133,58 +148,61 @@ def fit(
             self.device,
             epochs=self.num_epochs,
             hyperparams=config,
-            epoch=config["curr_round"],
+            epoch=int(config["curr_round"]),
         )
         if self.state_path is not None:
-            with open(self.state_path, 'wb') as f:
+            with open(self.state_path, "wb") as f:
                 torch.save(self.net.state_dict(), f)

         return (
             self.get_parameters({}),
             len(self.train_loader),
             {},
         )
-    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[float, int, Dict]:
+
+    def evaluate(
+        self, parameters: NDArrays, config: Dict[str, Scalar]
+    ) -> Tuple[float, int, Dict]:
         """Evaluate the network on the test set."""
         self.set_parameters(parameters)
         print(f"Client {self.cid} Evaluating...")
         self.net.to(self.device)
         loss, accuracy = test(self.net, self.test_loader, device=self.device)
         return loss, len(self.test_loader), {"accuracy": accuracy}


 def gen_client_fn(
     train_loaders: List[DataLoader],
     model: DictConfig,
     num_epochs: int,
     args: Dict,
-    test_loader: Optional[List[DataLoader]]=None,
-    state_path: Optional[str]=None,
+    test_loader: Optional[List[DataLoader]] = None,
+    state_path: Optional[str] = None,
 ) -> Callable[[str], fl.client.NumPyClient]:
     """Return a function which creates a new FlowerClient for a given cid."""

     def client_fn(cid: str) -> fl.client.NumPyClient:
         """Create a new FlowerClient for a given cid."""
-        cid = int(cid)
-        if args['algorithm'].lower() == "pfedpara" or args['algorithm'] == "fedper":
-            cl_path = f"{state_path}/client_{cid}.pth"
+        cid_ = int(cid)
+        if args["algorithm"].lower() == "pfedpara" or args["algorithm"] == "fedper":
+            cl_path = f"{state_path}/client_{cid_}.pth"
             return PFlowerClient(
-                cid=cid,
+                cid=cid_,
                 net=instantiate(model).to(args["device"]),
-                train_loader=train_loaders[cid],
+                train_loader=train_loaders[cid_],
                 test_loader=copy.deepcopy(test_loader),
                 device=args["device"],
                 num_epochs=num_epochs,
                 state_path=cl_path,
-                algorithm=args['algorithm'].lower(),
+                algorithm=args["algorithm"].lower(),
             )
         else:
             return FlowerClient(
-                cid=cid,
+                cid=cid_,
                 net=instantiate(model).to(args["device"]),
-                train_loader=train_loaders[cid],
+                train_loader=train_loaders[cid_],
                 device=args["device"],
                 num_epochs=num_epochs,
             )

     return client_fn
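For orientation, a factory like gen_client_fn is normally handed to Flower's simulation runner; the actual wiring lives in this baseline's main.py, which is outside this diff. A sketch with placeholder values (the strategy, client count, and resources below are assumptions, not taken from the repo):

```python
import flwr as fl

# client_fn as returned by gen_client_fn(train_loaders, model_cfg, num_epochs, args)
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=100,  # placeholder
    config=fl.server.ServerConfig(num_rounds=200),  # placeholder
    strategy=fl.server.strategy.FedAvg(),  # the baseline defines its own strategy
    client_resources={"num_cpus": 2, "num_gpus": 0.0},
)
```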
42 changes: 29 additions & 13 deletions baselines/fedpara/fedpara/dataset.py
@@ -6,15 +6,15 @@
 from torch.utils.data import DataLoader
 from torchvision import datasets, transforms

-from fedpara.dataset_preparation import DatasetSplit, iid, noniid, mnist_niid
+from fedpara.dataset_preparation import DatasetSplit, iid, mnist_niid, noniid


 def load_datasets(
     config, num_clients, batch_size
 ) -> Tuple[List[DataLoader], DataLoader]:
     """Load the dataset and return the dataloaders for the clients and the server."""
     print("Loading data...")
-    match config['name']:
+    match config["name"]:
         case "CIFAR10":
             Dataset = datasets.CIFAR10
         case "CIFAR100":
@@ -24,31 +24,41 @@ def load_datasets(
         case _:
             raise NotImplementedError
     data_directory = f"./data/{config['name'].lower()}/"
-    match config['name']:
+    match config["name"]:
         case "CIFAR10" | "CIFAR100":
             ds_path = f"{data_directory}train_{num_clients}_{config.alpha:.2f}.pkl"
             transform_train = transforms.Compose(
                 [
                     transforms.RandomCrop(32, padding=4),
                     transforms.RandomHorizontalFlip(),
                     transforms.ToTensor(),
-                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+                    transforms.Normalize(
+                        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
+                    ),
                 ]
             )
             transform_test = transforms.Compose(
                 [
                     transforms.ToTensor(),
-                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+                    transforms.Normalize(
+                        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
+                    ),
                 ]
             )
             try:
                 with open(ds_path, "rb") as file:
                     train_datasets = pickle.load(file)
                 dataset_train = Dataset(
-                    data_directory, train=True, download=False, transform=transform_train
+                    data_directory,
+                    train=True,
+                    download=False,
+                    transform=transform_train,
                 )
                 dataset_test = Dataset(
-                    data_directory, train=False, download=False, transform=transform_test
+                    data_directory,
+                    train=False,
+                    download=False,
+                    transform=transform_test,
                 )
             except FileNotFoundError:
                 dataset_train = Dataset(
@@ -63,7 +73,7 @@ def load_datasets(
             dataset_test = Dataset(
                 data_directory, train=False, download=True, transform=transform_test
             )
-
+
         case "MNIST":
             ds_path = f"{data_directory}train_{num_clients}.pkl"
             transform_train = transforms.Compose(
@@ -78,22 +88,29 @@
             try:
                 train_datasets = pickle.load(open(ds_path, "rb"))
                 dataset_train = Dataset(
-                    data_directory, train=True, download=False, transform=transform_train
+                    data_directory,
+                    train=True,
+                    download=False,
+                    transform=transform_train,
                 )
                 dataset_test = Dataset(
-                    data_directory, train=False, download=False, transform=transform_test
+                    data_directory,
+                    train=False,
+                    download=False,
+                    transform=transform_test,
                 )
             except FileNotFoundError:
                 dataset_train = Dataset(
                     data_directory, train=True, download=True, transform=transform_train
                 )
-                train_datasets = mnist_niid(dataset_train, num_clients, config.shard_size, config.data_seed)
+                train_datasets = mnist_niid(
+                    dataset_train, num_clients, config.shard_size, config.data_seed
+                )
                 pickle.dump(train_datasets, open(ds_path, "wb"))
                 dataset_test = Dataset(
                     data_directory, train=False, download=True, transform=transform_test
                 )
-

     test_loader = DataLoader(dataset_test, batch_size=batch_size, num_workers=2)
     train_loaders = [
         DataLoader(
@@ -106,4 +123,3 @@
]

return train_loaders, test_loader
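A usage sketch for load_datasets, assuming an OmegaConf config that carries the fields the function reads (config["name"], plus config.alpha for CIFAR or config.shard_size and config.data_seed for MNIST); the values are illustrative:

```python
from omegaconf import OmegaConf

# Illustrative config; only fields referenced by load_datasets are shown.
cfg = OmegaConf.create({"name": "MNIST", "shard_size": 300, "data_seed": 42})

train_loaders, test_loader = load_datasets(config=cfg, num_clients=100, batch_size=10)
assert len(train_loaders) == 100  # one DataLoader per client
```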

23 changes: 15 additions & 8 deletions baselines/fedpara/fedpara/dataset_preparation.py
@@ -6,13 +6,13 @@
 uncomment the lines below and tell us in the README.md (see the "Running the Experiment"
 block) that this file should be executed first.
 """
+import logging
 import random
-from collections import defaultdict
+from collections import Counter, defaultdict

 import numpy as np
 from torch.utils.data import Dataset
-import logging
-from collections import Counter


 class DatasetSplit(Dataset):
     """An abstract Dataset class wrapped around Pytorch Dataset class."""
@@ -99,14 +99,21 @@ def noniid(dataset, no_participants, alpha=0.5):
             clas_weight[i, j] = float(datasize[i, j]) / float((train_img_size[i]))
     return per_participant_list, clas_weight

-def mnist_niid(dataset: Dataset, num_clients: int, shard_size: int, seed: int) -> np.ndarray:
-    """ Partitioning technique as mentioned in https://arxiv.org/pdf/1602.05629.pdf"""
+
+def mnist_niid(
+    dataset: Dataset, num_clients: int, shard_size: int, seed: int
+) -> np.ndarray:
+    """Partitioning technique as mentioned in https://arxiv.org/pdf/1602.05629.pdf."""
     indices = dataset.targets[np.argsort(dataset.targets)].numpy()
     logging.debug(Counter(dataset.targets[indices].numpy()))
-    silos = np.array_split(indices, len(dataset) // shard_size)# randomly assign silos to clients
-    np.random.seed(seed+17)
+    silos = np.array_split(
+        indices, len(dataset) // shard_size
+    )  # randomly assign silos to clients
+    np.random.seed(seed + 17)
     np.random.shuffle(silos)
     clients = np.array(np.array_split(silos, num_clients)).reshape(num_clients, -1)
     logging.debug(clients.shape)
-    logging.debug(Counter([len(Counter(dataset.targets[client].numpy())) for client in clients]))
+    logging.debug(
+        Counter([len(Counter(dataset.targets[client].numpy())) for client in clients])
+    )
     return clients
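mnist_niid is the shard trick from the linked FedAvg paper: sort examples by label, cut the sorted order into fixed-size shards ("silos"), shuffle the shards, and deal the same number to every client, leaving each client with only a few classes. A self-contained sketch of the idea on synthetic labels (independent of the torchvision objects above):

```python
import numpy as np

labels = np.repeat(np.arange(10), 600)  # 6000 synthetic examples, 10 classes
shard_size, num_clients = 300, 10

order = np.argsort(labels)  # example indices, sorted by label
silos = np.array_split(order, len(labels) // shard_size)  # 20 shards of 300

rng = np.random.default_rng(42)
rng.shuffle(silos)  # randomize which shards go to which client

# Deal an equal number of shards to each client (2 shards of 300 here).
clients = np.array(np.array_split(silos, num_clients)).reshape(num_clients, -1)

for idx in clients[:2]:
    print(sorted(set(labels[idx])))  # each client covers at most a few classes
```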