From 0db9cad4c9a05405c93b978a27686644ac0872aa Mon Sep 17 00:00:00 2001 From: = <=> Date: Sat, 27 Jan 2024 12:24:18 +0000 Subject: [PATCH] Minor changes --- baselines/fedpara/README.md | 8 +- baselines/fedpara/fedpara/client.py | 109 ++++++----- baselines/fedpara/fedpara/conf/cifar10.yaml | 8 +- baselines/fedpara/fedpara/conf/cifar100.yaml | 8 +- .../fedpara/fedpara/conf/mnist_fedavg.yaml | 38 ++++ .../fedpara/fedpara/conf/mnist_fedper.yaml | 42 +++++ .../conf/{mnist.yaml => mnist_pfedpara.yaml} | 17 +- baselines/fedpara/fedpara/dataset.py | 56 ++++-- .../fedpara/fedpara/dataset_preparation.py | 73 ++++++-- baselines/fedpara/fedpara/main.py | 51 ++--- baselines/fedpara/fedpara/models.py | 174 ++++++++---------- baselines/fedpara/fedpara/server.py | 16 +- baselines/fedpara/fedpara/strategy.py | 21 +-- baselines/fedpara/fedpara/utils.py | 38 +++- 14 files changed, 408 insertions(+), 251 deletions(-) create mode 100644 baselines/fedpara/fedpara/conf/mnist_fedavg.yaml create mode 100644 baselines/fedpara/fedpara/conf/mnist_fedper.yaml rename baselines/fedpara/fedpara/conf/{mnist.yaml => mnist_pfedpara.yaml} (75%) diff --git a/baselines/fedpara/README.md b/baselines/fedpara/README.md index 576dabc7671d..8386de1ca5d7 100644 --- a/baselines/fedpara/README.md +++ b/baselines/fedpara/README.md @@ -2,7 +2,7 @@ title: "FedPara: Low-rank Hadamard Product for Communication-Efficient Federated Learning" url: https://openreview.net/forum?id=d71n4ftoCBy labels: [image classification, personalization, low-rank training, tensor decomposition] -dataset: [CIFAR-10, CIFAR-100, FEMNIST] +dataset: [CIFAR-10, CIFAR-100, MNIST] --- # FedPara: Low-rank Hadamard Product for Communication-Efficient Federated Learning @@ -29,10 +29,10 @@ page: https://github.com/South-hw/FedPara_ICLR22 **What’s implemented:** The code in this directory replicates the experiments in FedPara paper implementing the Low-rank scheme for Convolution module. -Specifically, it replicates the results for CIFAR-10 and CIFAR-100 in Figure 3 and the results for Feminist in Figure 5(a). +Specifically, it replicates the results for CIFAR-10 and CIFAR-100 in Figure 3 and the results for MNIST in Figure 5(c). 
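For quick reference while reading the patch, a minimal sketch of the low-rank Hadamard re-parameterization that FedPara applies to each layer; the class and variable names below are illustrative only, and the actual implementation lives in `fedpara/models.py` further down in this diff.

```python
import torch
import torch.nn as nn

class HadamardLowRank(nn.Module):
    """Toy FedPara-style weight: W = (X1 @ Y1.T) * (X2 @ Y2.T)."""

    def __init__(self, out_features: int, in_features: int, rank: int):
        super().__init__()
        self.X1 = nn.Parameter(torch.randn(out_features, rank) * 0.01)
        self.Y1 = nn.Parameter(torch.randn(in_features, rank) * 0.01)
        self.X2 = nn.Parameter(torch.randn(out_features, rank) * 0.01)
        self.Y2 = nn.Parameter(torch.randn(in_features, rank) * 0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each factor pair has rank <= `rank`, but their Hadamard (element-wise)
        # product can reach rank**2, so far fewer parameters are communicated
        # per round for a comparable effective rank.
        weight = (self.X1 @ self.Y1.T) * (self.X2 @ self.Y2.T)
        return x @ weight.T
```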
-**Datasets:** CIFAR-10, CIFAR-100, FEMNIST from PyTorch's Torchvision +**Datasets:** CIFAR-10, CIFAR-100, MNIST from PyTorch's Torchvision **Hardware Setup:** The experiments have been conducted on our server with the following specs: @@ -62,7 +62,7 @@ On a machine with RTX 3090Ti (24GB VRAM) it takes approximately 1h to run each C **Training Hyperparameters:** -| | Cifar10 IID | Cifar10 Non-IID | Cifar100 IID | Cifar100 Non-IID | FEMNIST | +| | Cifar10 IID | Cifar10 Non-IID | Cifar100 IID | Cifar100 Non-IID | MNIST | |---|-------|-------|------|-------|----------| | Fraction of client (K) | 16 | 16 | 8 | 8 | 10 | | Total rounds (T) | 200 | 200 | 400 | 400 | 100 | diff --git a/baselines/fedpara/fedpara/client.py b/baselines/fedpara/fedpara/client.py index f366c1e7b3b6..96a28b7b856c 100644 --- a/baselines/fedpara/fedpara/client.py +++ b/baselines/fedpara/fedpara/client.py @@ -1,16 +1,20 @@ """Client for FedPara.""" +import copy +import os from collections import OrderedDict -from typing import Callable, Dict, List, Tuple, Optional -import copy,os +from typing import Callable, Dict, List, Optional, Tuple + import flwr as fl import torch from flwr.common import NDArrays, Scalar from hydra.utils import instantiate from omegaconf import DictConfig from torch.utils.data import DataLoader -from fedpara.models import train,test -import logging + +from fedpara.models import test, train +from fedpara.utils import get_keys_state_dict + class FlowerClient(fl.client.NumPyClient): """Standard Flower client for CNN training.""" @@ -59,8 +63,10 @@ def fit( {}, ) + class PFlowerClient(fl.client.NumPyClient): - """personalized Flower Client""" + """Personalized Flower Client.""" + def __init__( self, cid: int, @@ -71,8 +77,8 @@ def __init__( num_epochs: int, state_path: str, algorithm: str, - ): - + ): + self.cid = cid self.net = net self.train_loader = train_loader @@ -82,49 +88,46 @@ def __init__( self.state_path = state_path self.algorithm = algorithm - def get_keys_state_dict(self, mode:str="local")->list[str]: - match self.algorithm: - case "fedper": - if mode == "local": - return list(filter(lambda x: 'fc2' in x,self.net.state_dict().keys())) - elif mode == "global": - return list(filter(lambda x: 'fc1' in x,self.net.state_dict().keys())) - case "pfedpara": - if mode == "local": - return list(filter(lambda x: 'w2' in x,self.net.state_dict().keys())) - elif mode == "global": - return list(filter(lambda x: 'w1' in x,self.net.state_dict().keys())) - case _: - raise NotImplementedError(f"algorithm {self.algorithm} not implemented") - - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: """Return the parameters of the current net.""" - model_dict = self.net.state_dict() - #TODO: overwrite the server private parameters + model_dict = self.net.state_dict().copy() + # overwrite the server private parameters for k in self.private_server_param.keys(): model_dict[k] = self.private_server_param[k] - return [val.cpu().numpy() for _, val in self.net.state_dict().items()] + return [val.cpu().numpy() for _, val in model_dict.items()] - def set_parameters(self, parameters: NDArrays) -> None: + def set_parameters(self, parameters: NDArrays, evaluate: False) -> None: self.private_server_param: Dict[str, torch.Tensor] = {} params_dict = zip(self.net.state_dict().keys(), parameters) - state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) - self.private_server_param = {k:state_dict[k] for k in self.get_keys_state_dict(mode="local")} - self.net.load_state_dict(state_dict, strict=True) + 
server_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + self.private_server_param = { + k: server_dict[k] + for k in get_keys_state_dict( + model=self.net, algorithm=self.algorithm, mode="local" + ) + } + + if evaluate: + client_dict = self.net.state_dict().copy() + else: + client_dict = copy.deepcopy(server_dict) + if os.path.isfile(self.state_path): - # only overwrite global parameters - with open(self.state_path, 'rb') as f: - model_dict = self.net.state_dict() - state_dict = torch.load(f) - for k in self.get_keys_state_dict(mode="global"): - model_dict[k] = state_dict[k] + with open(self.state_path, "rb") as f: + client_dict = torch.load(f) + + for k in get_keys_state_dict( + model=self.net, algorithm=self.algorithm, mode="global" + ): + client_dict[k] = server_dict[k] + + self.net.load_state_dict(client_dict, strict=False) def fit( self, parameters: NDArrays, config: Dict[str, Scalar] ) -> Tuple[NDArrays, int, Dict]: """Train the network on the training set.""" - self.set_parameters(parameters) + self.set_parameters(parameters, evaluate=False) print(f"Client {self.cid} Training...") train( @@ -136,55 +139,59 @@ def fit( epoch=config["curr_round"], ) if self.state_path is not None: - with open(self.state_path, 'wb') as f: + with open(self.state_path, "wb") as f: torch.save(self.net.state_dict(), f) return ( self.get_parameters({}), len(self.train_loader), - {}, + {}, ) - def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[float, int, Dict]: + + def evaluate( + self, parameters: NDArrays, config: Dict[str, Scalar] + ) -> Tuple[float, int, Dict]: """Evaluate the network on the test set.""" - self.set_parameters(parameters) + self.set_parameters(parameters, evaluate=True) print(f"Client {self.cid} Evaluating...") self.net.to(self.device) loss, accuracy = test(self.net, self.test_loader, device=self.device) return loss, len(self.test_loader), {"accuracy": accuracy} - + def gen_client_fn( train_loaders: List[DataLoader], model: DictConfig, num_epochs: int, args: Dict, - test_loader: Optional[List[DataLoader]]=None, - state_path: Optional[str]=None, + test_loader: Optional[List[DataLoader]] = None, + state_path: Optional[str] = None, ) -> Callable[[str], fl.client.NumPyClient]: """Return a function which creates a new FlowerClient for a given cid.""" def client_fn(cid: str) -> fl.client.NumPyClient: """Create a new FlowerClient for a given cid.""" cid = int(cid) - if args['algorithm'].lower() == "pfedpara" or args['algorithm'] == "fedper": + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + if args["algorithm"].lower() == "pfedpara" or args["algorithm"] == "fedper": cl_path = f"{state_path}/client_{cid}.pth" return PFlowerClient( cid=cid, - net=instantiate(model).to(args["device"]), + net=instantiate(model).to(device), train_loader=train_loaders[cid], test_loader=copy.deepcopy(test_loader), - device=args["device"], num_epochs=num_epochs, state_path=cl_path, - algorithm=args['algorithm'].lower(), + algorithm=args["algorithm"].lower(), + device=device, ) else: return FlowerClient( cid=cid, - net=instantiate(model).to(args["device"]), + net=instantiate(model).to(device), train_loader=train_loaders[cid], - device=args["device"], num_epochs=num_epochs, + device=device, ) - + return client_fn diff --git a/baselines/fedpara/fedpara/conf/cifar10.yaml b/baselines/fedpara/fedpara/conf/cifar10.yaml index 1c05e0ec1ca9..13f5a07b5f93 100644 --- a/baselines/fedpara/fedpara/conf/cifar10.yaml +++ b/baselines/fedpara/fedpara/conf/cifar10.yaml 
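Before moving on to the configs: the state handling in `PFlowerClient.set_parameters` above is the core of the personalized algorithms. A condensed, hypothetical restatement for reviewers skimming the diff — the personal layers come from the client's saved checkpoint, and only the shared keys are overwritten from the server (function and argument names here are illustrative, not part of the patch).

```python
import os
import torch

def restore_personalized_state(net, server_dict, global_keys, state_file):
    """Merge last round's personal state with the freshly received shared part."""
    client_dict = net.state_dict().copy()
    if os.path.isfile(state_file):
        client_dict = torch.load(state_file)  # personal + stale shared params from last round
    for key in global_keys:                   # e.g. keys from get_keys_state_dict(..., mode="global")
        client_dict[key] = server_dict[key]   # refresh only the shared part
    net.load_state_dict(client_dict, strict=False)
```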
@@ -6,6 +6,7 @@ num_rounds: 200 clients_per_round: 16 num_epochs: 5 batch_size: 64 +algorithm: FedPara server_device: cuda @@ -22,8 +23,7 @@ dataset_config: model: _target_: fedpara.models.VGG num_classes: ${dataset_config.num_classes} - conv_type: lowrank # lowrank or standard - activation: relu # relu or leaky_relu + param_type: lowrank # lowrank or standard ratio: 0.1 # lowrank ratio hyperparams: @@ -32,8 +32,7 @@ hyperparams: strategy: - _target_: fedpara.strategy.FedPara - algorithm: FedPara + _target_: fedpara.strategy.FedAvg fraction_fit: 0.00001 fraction_evaluate: 0.0 min_evaluate_clients: 0 @@ -41,3 +40,4 @@ strategy: min_available_clients: ${clients_per_round} accept_failures: false +exp_id: ${algorithm}_${dataset_config.name}_${seed}_${dataset_config.partition}_${dataset_config.alpha}_${model.param_type}_${model.ratio} diff --git a/baselines/fedpara/fedpara/conf/cifar100.yaml b/baselines/fedpara/fedpara/conf/cifar100.yaml index ee8402324519..4563ee949c31 100644 --- a/baselines/fedpara/fedpara/conf/cifar100.yaml +++ b/baselines/fedpara/fedpara/conf/cifar100.yaml @@ -6,6 +6,7 @@ num_rounds: 400 clients_per_round: 8 num_epochs: 5 batch_size: 64 +algorithm: FedPara server_device: cuda @@ -22,8 +23,7 @@ dataset_config: model: _target_: fedpara.models.VGG num_classes: ${dataset_config.num_classes} - conv_type: lowrank # lowrank or standard - activation: relu # relu or leaky_relu + param_type: lowrank # lowrank or standard ratio: 0.4 # lowrank ratio hyperparams: @@ -31,8 +31,7 @@ hyperparams: learning_decay: 0.992 strategy: - _target_: fedpara.strategy.FedPara - algorithm: FedPara + _target_: fedpara.strategy.FedAvg fraction_fit: 0.00001 fraction_evaluate: 0.0 min_evaluate_clients: 0 @@ -40,3 +39,4 @@ strategy: min_available_clients: ${clients_per_round} accept_failures: false +exp_id: ${algorithm}_${dataset_config.name}_${seed}_${dataset_config.partition}_${dataset_config.alpha}_${model.param_type}_${model.ratio} \ No newline at end of file diff --git a/baselines/fedpara/fedpara/conf/mnist_fedavg.yaml b/baselines/fedpara/fedpara/conf/mnist_fedavg.yaml new file mode 100644 index 000000000000..2429f9b20d6e --- /dev/null +++ b/baselines/fedpara/fedpara/conf/mnist_fedavg.yaml @@ -0,0 +1,38 @@ +--- +seed: 424 + +num_clients: 100 +num_rounds: 100 +clients_per_round: 10 +num_epochs: 1 +batch_size: 10 +server_device: cuda +algorithm: fedavg + +client_resources: + num_cpus: 2 + num_gpus: 0.1 + +dataset_config: + name: MNIST + num_classes: 10 + shard_size: 300 + +model: + _target_: fedpara.models.FC + num_classes: ${dataset_config.num_classes} + hidden_size: 200 + +hyperparams: + eta_l: 0.05 + learning_decay: 1 + +strategy: + _target_: fedpara.strategy.FedAvg + fraction_fit: 0.00001 + fraction_evaluate: 0 + min_evaluate_clients: 0 + min_fit_clients: ${clients_per_round} + min_available_clients: ${clients_per_round} + +exp_id: ${algorithm}_${dataset_config.name}_${seed} diff --git a/baselines/fedpara/fedpara/conf/mnist_fedper.yaml b/baselines/fedpara/fedpara/conf/mnist_fedper.yaml new file mode 100644 index 000000000000..13d2dedfc0ad --- /dev/null +++ b/baselines/fedpara/fedpara/conf/mnist_fedper.yaml @@ -0,0 +1,42 @@ +--- +seed: 17 + +num_clients: 100 +num_rounds: 100 +clients_per_round: 10 +num_epochs: 1 +batch_size: 10 +state_path: ./client_states/ +server_device: cuda +client_device: cuda + +algorithm: fedper +# fedavg in future + +client_resources: + num_cpus: 2 + num_gpus: 0.1 + +dataset_config: + name: MNIST + num_classes: 10 + shard_size: 300 + + data_seed: ${seed} + +model: + _target_: 
fedpara.models.FC + num_classes: ${dataset_config.num_classes} + hidden_size: 200 + +hyperparams: + eta_l: 0.05 + learning_decay: 0 + +strategy: + _target_: fedpara.strategy.FedAvg + fraction_fit: 0.00001 + fraction_evaluate: 0.00001 + min_evaluate_clients: ${clients_per_round} + min_fit_clients: ${clients_per_round} + min_available_clients: ${clients_per_round} \ No newline at end of file diff --git a/baselines/fedpara/fedpara/conf/mnist.yaml b/baselines/fedpara/fedpara/conf/mnist_pfedpara.yaml similarity index 75% rename from baselines/fedpara/fedpara/conf/mnist.yaml rename to baselines/fedpara/fedpara/conf/mnist_pfedpara.yaml index 05155542fb6c..7bf8b3979e06 100644 --- a/baselines/fedpara/fedpara/conf/mnist.yaml +++ b/baselines/fedpara/fedpara/conf/mnist_pfedpara.yaml @@ -7,28 +7,25 @@ clients_per_round: 10 num_epochs: 5 batch_size: 10 state_path: ./client_states/ -client_device: cuda -algorithm: pFedPara - +server_device: cuda +algorithm: pFedpara client_resources: num_cpus: 2 - num_gpus: 0.0625 + num_gpus: 0.1 dataset_config: name: MNIST num_classes: 10 shard_size: 300 - - data_seed: ${seed} model: _target_: fedpara.models.FC num_classes: ${dataset_config.num_classes} param_type: lowrank # lowrank or standard - activation: relu # relu or leaky_relu ratio: 0.5 # lowrank ratio - algorithm: ${algorithm} + hidden_size: 200 + hyperparams: eta_l: 0.01 learning_decay: 0.999 @@ -39,4 +36,6 @@ strategy: fraction_evaluate: 0.00001 min_evaluate_clients: ${clients_per_round} min_fit_clients: ${clients_per_round} - min_available_clients: ${clients_per_round} \ No newline at end of file + min_available_clients: ${clients_per_round} + +exp_id: ${algorithm}_${dataset_config.name}_${seed}_${model.param_type}_${model.ratio} diff --git a/baselines/fedpara/fedpara/dataset.py b/baselines/fedpara/fedpara/dataset.py index b4737d597ac2..f1e13b84e47a 100644 --- a/baselines/fedpara/fedpara/dataset.py +++ b/baselines/fedpara/fedpara/dataset.py @@ -6,7 +6,12 @@ from torch.utils.data import DataLoader from torchvision import datasets, transforms -from fedpara.dataset_preparation import DatasetSplit, iid, noniid, mnist_niid +from fedpara.dataset_preparation import ( + DatasetSplit, + iid, + noniid, + noniid_partition_loader, +) def load_datasets( @@ -14,7 +19,7 @@ def load_datasets( ) -> Tuple[List[DataLoader], DataLoader]: """Load the dataset and return the dataloaders for the clients and the server.""" print("Loading data...") - match config['name']: + match config["name"]: case "CIFAR10": Dataset = datasets.CIFAR10 case "CIFAR100": @@ -24,7 +29,7 @@ def load_datasets( case _: raise NotImplementedError data_directory = f"./data/{config['name'].lower()}/" - match config['name']: + match config["name"]: case "CIFAR10" | "CIFAR100": ds_path = f"{data_directory}train_{num_clients}_{config.alpha:.2f}.pkl" transform_train = transforms.Compose( @@ -32,23 +37,33 @@ def load_datasets( transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) + ), ] ) transform_test = transforms.Compose( [ transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) + ), ] ) try: with open(ds_path, "rb") as file: train_datasets = pickle.load(file) dataset_train = Dataset( - data_directory, train=True, download=False, 
transform=transform_train + data_directory, + train=True, + download=False, + transform=transform_train, ) dataset_test = Dataset( - data_directory, train=False, download=False, transform=transform_test + data_directory, + train=False, + download=False, + transform=transform_test, ) except FileNotFoundError: dataset_train = Dataset( @@ -63,7 +78,7 @@ def load_datasets( dataset_test = Dataset( data_directory, train=False, download=True, transform=transform_test ) - + case "MNIST": ds_path = f"{data_directory}train_{num_clients}.pkl" transform_train = transforms.Compose( @@ -78,21 +93,37 @@ def load_datasets( try: train_datasets = pickle.load(open(ds_path, "rb")) dataset_train = Dataset( - data_directory, train=True, download=False, transform=transform_train + data_directory, + train=True, + download=False, + transform=transform_train, ) dataset_test = Dataset( - data_directory, train=False, download=False, transform=transform_test + data_directory, + train=False, + download=False, + transform=transform_test, ) + except FileNotFoundError: dataset_train = Dataset( data_directory, train=True, download=True, transform=transform_train ) - train_datasets = mnist_niid(dataset_train, num_clients, config.shard_size, config.data_seed) + train_datasets = noniid_partition_loader( + dataset_train, + m_per_shard=config.shard_size, + n_shards_per_client=len(dataset_train) // (config.shard_size * 100), + ) pickle.dump(train_datasets, open(ds_path, "wb")) dataset_test = Dataset( data_directory, train=False, download=True, transform=transform_test ) - + train_loaders = [ + DataLoader(x, batch_size=batch_size, shuffle=True) + for x in train_datasets + ] + test_loader = DataLoader(dataset_test, batch_size=batch_size, num_workers=2) + return train_loaders, test_loader test_loader = DataLoader(dataset_test, batch_size=batch_size, num_workers=2) train_loaders = [ @@ -106,4 +137,3 @@ def load_datasets( ] return train_loaders, test_loader - diff --git a/baselines/fedpara/fedpara/dataset_preparation.py b/baselines/fedpara/fedpara/dataset_preparation.py index b76dcf576f22..2de43be318ce 100644 --- a/baselines/fedpara/fedpara/dataset_preparation.py +++ b/baselines/fedpara/fedpara/dataset_preparation.py @@ -6,13 +6,14 @@ uncomment the lines below and tell us in the README.md (see the "Running the Experiment" block) that this file should be executed first. 
""" + import random from collections import defaultdict import numpy as np +import torch from torch.utils.data import Dataset -import logging -from collections import Counter + class DatasetSplit(Dataset): """An abstract Dataset class wrapped around Pytorch Dataset class.""" @@ -99,14 +100,60 @@ def noniid(dataset, no_participants, alpha=0.5): clas_weight[i, j] = float(datasize[i, j]) / float((train_img_size[i])) return per_participant_list, clas_weight -def mnist_niid(dataset: Dataset, num_clients: int, shard_size: int, seed: int) -> np.ndarray: - """ Partitioning technique as mentioned in https://arxiv.org/pdf/1602.05629.pdf""" - indices = dataset.targets[np.argsort(dataset.targets)].numpy() - logging.debug(Counter(dataset.targets[indices].numpy())) - silos = np.array_split(indices, len(dataset) // shard_size)# randomly assign silos to clients - np.random.seed(seed+17) - np.random.shuffle(silos) - clients = np.array(np.array_split(silos, num_clients)).reshape(num_clients, -1) - logging.debug(clients.shape) - logging.debug(Counter([len(Counter(dataset.targets[client].numpy())) for client in clients])) - return clients + +def data_to_tensor(data): + """Loads dataset to memory, applies transform.""" + loader = torch.utils.data.DataLoader(data, batch_size=len(data)) + img, label = next(iter(loader)) + return img, label + + +def noniid_partition_loader(data, m_per_shard=300, n_shards_per_client=2): + """semi-pathological client sample partition + 1. sort examples by label, form shards of size 300 by grouping points + successively + 2. each client is 2 random shards + most clients will have 2 digits, at most 4 + """ + # load data into memory + img, label = data_to_tensor(data) + + # sort + idx = torch.argsort(label) + img = img[idx] + label = label[idx] + + # split into n_shards of size m_per_shard + m = len(data) + assert m % m_per_shard == 0 + n_shards = m // m_per_shard + shards_idx = [ + torch.arange(m_per_shard * i, m_per_shard * (i + 1)) for i in range(n_shards) + ] + random.shuffle(shards_idx) # shuffle shards + + # pick shards to create a dataset for each client + assert n_shards % n_shards_per_client == 0 + n_clients = n_shards // n_shards_per_client + client_data = [ + torch.utils.data.TensorDataset( + torch.cat( + [ + img[shards_idx[j]] + for j in range( + i * n_shards_per_client, (i + 1) * n_shards_per_client + ) + ] + ), + torch.cat( + [ + label[shards_idx[j]] + for j in range( + i * n_shards_per_client, (i + 1) * n_shards_per_client + ) + ] + ), + ) + for i in range(n_clients) + ] + return client_data diff --git a/baselines/fedpara/fedpara/main.py b/baselines/fedpara/fedpara/main.py index 4da028738f5e..47509a3d7811 100644 --- a/baselines/fedpara/fedpara/main.py +++ b/baselines/fedpara/fedpara/main.py @@ -5,12 +5,19 @@ from hydra.core.hydra_config import HydraConfig from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf + from fedpara import client, server, utils from fedpara.dataset import load_datasets -from fedpara.utils import get_parameters, save_results_as_pickle, seed_everything, set_client_state_save_path +from fedpara.server import weighted_average +from fedpara.utils import ( + get_parameters, + save_results_as_pickle, + seed_everything, + set_client_state_save_path, +) -@hydra.main(config_path="conf", config_name="mnist", version_base=None) +@hydra.main(config_path="conf", config_name="cifar10", version_base=None) def main(cfg: DictConfig) -> None: """Run the baseline. 
@@ -23,8 +30,10 @@ def main(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) seed_everything(cfg.seed) OmegaConf.to_container(cfg, resolve=True) - if 'state_path' in cfg: state_path=set_client_state_save_path(cfg.state_path) - else: state_path = None + if "state_path" in cfg: + state_path = set_client_state_save_path(cfg.state_path) + else: + state_path = None # 2. Prepare dataset train_loaders, test_loader = load_datasets( @@ -41,7 +50,7 @@ def main(cfg: DictConfig) -> None: test_loader=test_loader, model=cfg.model, num_epochs=cfg.num_epochs, - args={"device": cfg.client_device, "algorithm": cfg.algorithm}, + args={"algorithm": cfg.algorithm}, state_path=state_path, ) @@ -60,25 +69,28 @@ def fit_config_fn(server_round: int): strategy = instantiate( cfg.strategy, on_fit_config_fn=get_on_fit_config(), - initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(net_glob)), + initial_parameters=fl.common.ndarrays_to_parameters( + get_parameters(net_glob) + ), + evaluate_metrics_aggregation_fn=weighted_average, ) - else : + else: evaluate_fn = server.gen_evaluate_fn( num_clients=cfg.num_clients, test_loader=test_loader, model=cfg.model, device=cfg.server_device, - state_path=cfg.state_path, ) strategy = instantiate( cfg.strategy, evaluate_fn=evaluate_fn, on_fit_config_fn=get_on_fit_config(), - initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(net_glob)), + initial_parameters=fl.common.ndarrays_to_parameters( + get_parameters(net_glob) + ), ) - # 5. Start Simulation history = fl.simulation.start_simulation( client_fn=client_fn, @@ -92,21 +104,12 @@ def fit_config_fn(server_round: int): "_memory": 30 * 1024 * 1024 * 1024, }, ) - save_results_as_pickle(history) - - # 6. Save results save_path = HydraConfig.get().runtime.output_dir - file_suffix = "_".join( - [ - repr(strategy), - cfg.dataset_config.name, - f"{cfg.seed}", - f"{cfg.dataset_config.alpha}", - f"{cfg.num_clients}", - f"{cfg.num_rounds}", - f"{cfg.clients_per_round}", - ] - ) + + save_results_as_pickle(history, file_path=save_path) + + # 6. 
Save results + file_suffix = "_".join([repr(net_glob), f"{cfg.exp_id}"]) utils.plot_metric_from_history( hist=history, diff --git a/baselines/fedpara/fedpara/models.py b/baselines/fedpara/fedpara/models.py index bb785e5b893c..be3bbe915866 100644 --- a/baselines/fedpara/fedpara/models.py +++ b/baselines/fedpara/fedpara/models.py @@ -2,6 +2,7 @@ import math from typing import Dict, Tuple + import numpy as np import torch import torch.nn.functional as F @@ -10,100 +11,94 @@ from torch.nn import init from torch.utils.data import DataLoader + class LowRankNN(nn.Module): - def __init__(self,input, output, rank,activation: str = "relu",) -> None: + """Fedpara Low-rank weight systhesis for fully connected layer.""" + + def __init__(self, input, output, rank) -> None: super(LowRankNN, self).__init__() - + self.X = nn.Parameter( torch.empty(size=(input, rank)), requires_grad=True, ) - self.Y = nn.Parameter( - torch.empty(size=(output,rank)), requires_grad=True - ) + self.Y = nn.Parameter(torch.empty(size=(output, rank)), requires_grad=True) + + init.kaiming_normal_(self.X, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(self.Y, mode="fan_out", nonlinearity="relu") - if activation == "leakyrelu": - activation = "leaky_relu" - init.kaiming_normal_(self.X, mode="fan_out", nonlinearity=activation) - init.kaiming_normal_(self.Y, mode="fan_out", nonlinearity=activation) - def forward(self): - out = torch.einsum("yr,xr->yx",self.Y, self.X) + out = torch.einsum("yr,xr->yx", self.Y, self.X) return out - + + class Linear(nn.Module): - def __init__(self, input, output, ratio, activation: str = "relu",bias= True, pfedpara=True) -> None: + """Low-rank fully connected layer module for personalized scheme.""" + + def __init__(self, input, output, ratio, bias=True) -> None: super(Linear, self).__init__() - rank = self._calc_from_ratio(ratio,input, output) - self.w1 = LowRankNN(input, output, rank, activation) - self.w2 = LowRankNN(input, output, rank, activation) + rank = self._calc_from_ratio(ratio, input, output) + self.w1 = LowRankNN(input, output, rank) + self.w2 = LowRankNN(input, output, rank) # make the bias for each layer if bias: self.bias = nn.Parameter(torch.zeros(output)) - self.pfedpara = pfedpara - def _calc_from_ratio(self, ratio,input, output): + def _calc_from_ratio(self, ratio, input, output): # Return the low-rank of sub-matrices given the compression ratio # minimum possible parameter r1 = int(np.ceil(np.sqrt(output))) r2 = int(np.ceil(np.sqrt(input))) r = np.min((r1, r2)) - # maximum possible rank, + # maximum possible rank, """ - To solve it we need to know the roots of quadratic equation: ax^2+bx+c=0 - a = kernel**2 - b = out channel+ in channel - c = - num_target_params/2 - r3 is floored because we cannot take the ceil as it results a bigger number of parameters than the original problem + To solve it we need to know the roots of quadratic equation: 2*r*(m+n)=m*n """ - num_target_params = ( - output * input - ) - a, b, c = input, output,- num_target_params/2 - discriminant = b**2 - 4 * a * c - r3 = math.floor((-b+math.sqrt(discriminant))/(2*a)) - rank=math.ceil((1-ratio)*r+ ratio*r3) + r3 = math.floor((output * input) / (2 * (output + input))) + rank = math.ceil((1 - ratio) * r + ratio * r3) return rank - - def forward(self,x): + + def forward(self, x): # personalized - if self.pfedpara: - w = self.w1() * self.w2() + self.w1() - else: - w = self.w1() * self.w2() - out = F.linear(x, w,self.bias) + w = self.w1() * self.w2() + self.w1() + out = F.linear(x, w, self.bias) return out - 
+ class FC(nn.Module): - def __init__(self, input_size=28**2, hidden_size=256, num_classes=10, ratio=0.5, param_type="lowrank",activation: str = "relu",algorithm="pfedpara"): + """2NN Fully connected layer as in the paper: https://arxiv.org/abs/1602.05629""" + + def __init__( + self, + input_size=28**2, + hidden_size=200, + num_classes=10, + ratio=0.5, + param_type="standard", + ): super(FC, self).__init__() self.input_size = input_size - self.method = algorithm.lower() if param_type == "standard": - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - self.softmax = nn.Softmax(dim=1) + + self.fc1 = nn.Linear(input_size, hidden_size) + self.fc2 = nn.Linear(hidden_size, 256) + self.out = nn.Linear(256, num_classes) + elif param_type == "lowrank": - pfedpara = False - if self.method == "pfedpara": - pfedpara = True - - self.fc1 = Linear(input_size, hidden_size, ratio, activation, pfedpara=pfedpara) - self.relu = nn.ReLU() - self.fc2 = Linear(hidden_size, num_classes, ratio, activation, pfedpara=pfedpara) - self.softmax = nn.Softmax(dim=1) + self.fc1 = Linear(input_size, hidden_size, ratio) + self.fc2 = Linear(hidden_size, 256, ratio) + self.out = Linear(256, num_classes, ratio) + else: raise ValueError("param_type must be either standard or lowrank") @property def model_size(self): - """ - Return the total number of trainable parameters (in million paramaters) and the size of the model in MB. - """ - total_trainable_params = sum( - p.numel() for p in self.parameters() if p.requires_grad)/1e6 + """Return the total number of trainable parameters (in million paramaters) and + the size of the model in MB.""" + total_trainable_params = ( + sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 + ) param_size = 0 for param in self.parameters(): param_size += param.nelement() * param.element_size() @@ -112,17 +107,17 @@ def model_size(self): buffer_size += buffer.nelement() * buffer.element_size() size_all_mb = (param_size + buffer_size) / 1024**2 return total_trainable_params, size_all_mb - - def forward(self,x): + + def forward(self, x): x = x.view(-1, self.input_size) - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - out = self.softmax(out) - return out + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.out(x) + return x + class LowRank(nn.Module): - """Low-rank convolutional layer.""" + """Fedpara Low-rank weight systhesis for Convolution layer.""" def __init__( # pylint: disable=too-many-arguments self, @@ -130,7 +125,6 @@ def __init__( # pylint: disable=too-many-arguments out_channels: int, low_rank: int, kernel_size: int, - activation: str = "relu", ): super().__init__() self.T = nn.Parameter( @@ -143,11 +137,9 @@ def __init__( # pylint: disable=too-many-arguments self.Y = nn.Parameter( torch.empty(size=(low_rank, in_channels)), requires_grad=True ) - if activation == "leakyrelu": - activation = "leaky_relu" - init.kaiming_normal_(self.T, mode="fan_out", nonlinearity=activation) - init.kaiming_normal_(self.X, mode="fan_out", nonlinearity=activation) - init.kaiming_normal_(self.Y, mode="fan_out", nonlinearity=activation) + init.kaiming_normal_(self.T, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(self.X, mode="fan_out", nonlinearity="relu") + init.kaiming_normal_(self.Y, mode="fan_out", nonlinearity="relu") def forward(self): """Forward pass.""" @@ -169,7 +161,6 @@ def __init__( # pylint: disable=too-many-arguments bias: bool = False, ratio: float = 0.1, add_nonlinear: 
bool = False, - activation: str = "relu", ): super().__init__() self.in_channels = in_channels @@ -181,13 +172,8 @@ def __init__( # pylint: disable=too-many-arguments self.ratio = ratio self.low_rank = self._calc_from_ratio() self.add_nonlinear = add_nonlinear - self.activation = activation - self.W1 = LowRank( - in_channels, out_channels, self.low_rank, kernel_size, activation - ) - self.W2 = LowRank( - in_channels, out_channels, self.low_rank, kernel_size, activation - ) + self.W1 = LowRank(in_channels, out_channels, self.low_rank, kernel_size) + self.W2 = LowRank(in_channels, out_channels, self.low_rank, kernel_size) self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None self.tanh = nn.Tanh() @@ -207,9 +193,7 @@ def _calc_from_ratio(self): # r3 is floored because we cannot take the ceil as it results a bigger number # of parameters than the original problem - num_target_params = ( - self.out_channels * self.in_channels * (self.kernel_size**2) - ) + num_target_params = self.out_channels * self.in_channels * (self.kernel_size**2) a, b, c = ( self.kernel_size**2, self.out_channels + self.in_channels, @@ -242,16 +226,11 @@ def __init__( # pylint: disable=too-many-arguments num_classes, num_groups=2, ratio=0.1, - activation="relu", - conv_type="lowrank", + param_type="lowrank", add_nonlinear=False, ): super().__init__() - if activation == "relu": - self.activation = nn.ReLU(inplace=True) - elif activation == "leaky_relu": - self.activation = nn.LeakyReLU(inplace=True) - self.conv_type = conv_type + self.param_type = param_type self.num_groups = num_groups self.num_classes = num_classes self.ratio = ratio @@ -281,10 +260,10 @@ def __init__( # pylint: disable=too-many-arguments self.classifier = nn.Sequential( nn.Dropout(), nn.Linear(512, 512), - self.activation, + nn.ReLU(inplace=True), nn.Dropout(), nn.Linear(512, 512), - self.activation, + nn.ReLU(inplace=True), nn.Linear(512, num_classes), ) self._init_weights() @@ -294,7 +273,7 @@ def _init_weights(self): for name, module in self.features.named_children(): module = getattr(self.features, name) if isinstance(module, nn.Conv2d): - if self.conv_type == "lowrank": + if self.param_type == "lowrank": num_channels = module.in_channels setattr( self.features, @@ -309,10 +288,9 @@ def _init_weights(self): ratio=self.ratio, add_nonlinear=self.add_nonlinear, # send the activation function to the Conv2d class - activation=self.activation.__class__.__name__.lower(), ), ) - elif self.conv_type == "standard": + elif self.param_type == "standard": n = ( module.kernel_size[0] * module.kernel_size[1] @@ -333,10 +311,10 @@ def _make_layers(self, cfg, group_norm=True): layers += [ conv2d, nn.GroupNorm(self.num_groups, v), - self.activation, + nn.ReLU(inplace=True), ] else: - layers += [conv2d, self.activation] + layers += [conv2d, nn.ReLU(inplace=True)] in_channels = v return nn.Sequential(*layers) @@ -490,6 +468,6 @@ def _train_one_epoch( # pylint: disable=too-many-arguments if __name__ == "__main__": - model = VGG(num_classes=10, num_groups=2, conv_type="standard", ratio=0.4) + model = VGG(num_classes=10, num_groups=2, param_type="standard", ratio=0.4) # Print the modified VGG16GN model architecture print(model.model_size) diff --git a/baselines/fedpara/fedpara/server.py b/baselines/fedpara/fedpara/server.py index f73618ab8569..cd370822e1f7 100644 --- a/baselines/fedpara/fedpara/server.py +++ b/baselines/fedpara/fedpara/server.py @@ -1,12 +1,14 @@ """Global evaluation function.""" from collections import OrderedDict -from typing import 
Callable, Dict, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple + import torch -from flwr.common import NDArrays, Scalar +from flwr.common import Metrics, NDArrays, Scalar from hydra.utils import instantiate from omegaconf import DictConfig from torch.utils.data import DataLoader + from fedpara.models import test @@ -16,6 +18,7 @@ def get_on_fit_config(hypearparams: Dict): def fit_config_fn(server_round: int): hypearparams["curr_round"] = server_round return hypearparams + return fit_config_fn @@ -62,3 +65,12 @@ def evaluate( return loss, {"accuracy": accuracy} return evaluate + + +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + print(f"accuracies: {sum(accuracies) / sum(examples)}") + # Aggregate and return custom metric (weighted average) + return {"accuracy": sum(accuracies) / sum(examples)} diff --git a/baselines/fedpara/fedpara/strategy.py b/baselines/fedpara/fedpara/strategy.py index 0a7cd788d189..7fc32a19b90d 100644 --- a/baselines/fedpara/fedpara/strategy.py +++ b/baselines/fedpara/fedpara/strategy.py @@ -1,20 +1 @@ -"""FedPara strategy.""" - - -from flwr.server.strategy import FedAvg - - -class FedPara(FedAvg): - """FedPara strategy.""" - - def __init__( - self, - algorithm: str, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.algorithm = algorithm - - def __repr__(self) -> str: - """Return the name of the strategy.""" - return self.algorithm +"""FedPara uses FedAvg as the default strategy.""" diff --git a/baselines/fedpara/fedpara/utils.py b/baselines/fedpara/fedpara/utils.py index 77d70cb6ce38..af68c5f4765b 100644 --- a/baselines/fedpara/fedpara/utils.py +++ b/baselines/fedpara/fedpara/utils.py @@ -1,8 +1,13 @@ """Utility functions for FedPara.""" + +import os +import pickle import random +import time from pathlib import Path from secrets import token_hex -from typing import Optional, Union +from typing import Optional + import matplotlib.pyplot as plt import numpy as np import torch @@ -10,7 +15,7 @@ from flwr.server import History from omegaconf import DictConfig from torch.nn import Module -import time, os, pickle + def plot_metric_from_history( hist: History, @@ -38,14 +43,14 @@ def plot_metric_from_history( metric_dict = ( hist.metrics_centralized if metric_type == "centralized" - else hist.metrics_distributed + else hist["history"].metrics_distributed ) _, axs = plt.subplots() rounds, values_accuracy = zip(*metric_dict["accuracy"]) r_cc = (i * 2 * model_size * int(cfg.clients_per_round) / 1024 for i in rounds) # Set the title - title = f"{cfg.strategy.algorithm} | parameters: {cfg.model.conv_type} | " + title = f"{cfg.algorithm} | parameters: {cfg.model.conv_type} | " title += ( f"{cfg.dataset_config.name} {cfg.dataset_config.partition} | Seed {cfg.seed}" ) @@ -71,9 +76,11 @@ def get_parameters(net: Module) -> NDArrays: """Get the parameters of the network.""" return [val.cpu().numpy() for _, val in net.state_dict().items()] + def save_results_as_pickle( history: History, - default_filename: Optional[str] = "results.pkl", + file_path: str, + default_filename: Optional[str] = "history.pkl", ) -> None: """Save results from simulation to pickle. @@ -94,7 +101,6 @@ def save_results_as_pickle( File used by default if file_path points to a directory instead to a file. 
Default: "results.pkl" """ - file_path = set_client_state_save_path("./outputs/") path = Path(file_path) # ensure path exists @@ -132,9 +138,23 @@ def set_client_state_save_path(path: str) -> str: """Set the client state save path.""" client_state_save_path = time.strftime("%Y-%m-%d") client_state_sub_path = time.strftime("%H-%M-%S") - client_state_save_path = ( - f"{path}{client_state_save_path}/{client_state_sub_path}" - ) + client_state_save_path = f"{path}{client_state_save_path}/{client_state_sub_path}" if not os.path.exists(client_state_save_path): os.makedirs(client_state_save_path) return client_state_save_path + + +def get_keys_state_dict(model, algorithm, mode: str = "local") -> list[str]: + match algorithm: + case "fedper": + if mode == "local": + return list(filter(lambda x: "out" in x, model.state_dict().keys())) + elif mode == "global": + return list(filter(lambda x: "out" not in x, model.state_dict().keys())) + case "pfedpara": + if mode == "local": + return list(filter(lambda x: "w2" in x, model.state_dict().keys())) + elif mode == "global": + return list(filter(lambda x: "w1" in x, model.state_dict().keys())) + case _: + raise NotImplementedError(f"algorithm {algorithm} not implemented")
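
For clarity, a small sketch (not part of the patch) showing how `get_keys_state_dict` splits an FC model's parameters for the two personalized algorithms. Because the low-rank layers in `models.py` compose their weight as `w1 * w2 + w1`, keeping the `w2` factors local personalizes the model while the `w1` factors carry the shared part; for FedPer only the output head stays local. The printed key lists are abbreviated.

```python
from fedpara.models import FC
from fedpara.utils import get_keys_state_dict

# pFedPara: 'w1' factors travel to the server, 'w2' factors stay on the client.
net = FC(param_type="lowrank", ratio=0.5)
print(get_keys_state_dict(net, "pfedpara", mode="global"))  # ['fc1.w1.X', 'fc1.w1.Y', 'fc2.w1.X', ...]
print(get_keys_state_dict(net, "pfedpara", mode="local"))   # ['fc1.w2.X', 'fc1.w2.Y', 'fc2.w2.X', ...]

# FedPer: everything except the output head ('out') is shared.
net = FC(param_type="standard")
print(get_keys_state_dict(net, "fedper", mode="local"))     # ['out.weight', 'out.bias']
```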