From 7ea0c2e5556d6141b64ad9347b9466f02a80f10e Mon Sep 17 00:00:00 2001
From: Yahia Salaheldin Shaaban <62369984+yehias21@users.noreply.github.com>
Date: Sun, 7 Jan 2024 20:22:39 +0200
Subject: [PATCH] - Bug fixes

---
 baselines/fedpara/.gitignore                 |  2 +
 baselines/fedpara/fedpara/client.py          | 75 +++++++++++------
 baselines/fedpara/fedpara/conf/cifar10.yaml  |  3 +-
 baselines/fedpara/fedpara/conf/cifar100.yaml |  2 -
 baselines/fedpara/fedpara/conf/mnist.yaml    | 14 ++--
 baselines/fedpara/fedpara/dataset.py         | 28 +++++--
 baselines/fedpara/fedpara/main.py            | 51 ++++++------
 baselines/fedpara/fedpara/models.py          | 25 +++---
 baselines/fedpara/fedpara/server.py          |  3 -
 baselines/fedpara/fedpara/test.ipynb         | 85 --------------------
 baselines/fedpara/fedpara/utils.py           | 72 ++++++++++++++++-
 baselines/fedper/fedper/server.py            |  1 +
 12 files changed, 191 insertions(+), 170 deletions(-)
 delete mode 100644 baselines/fedpara/fedpara/test.ipynb

diff --git a/baselines/fedpara/.gitignore b/baselines/fedpara/.gitignore
index de1e160448e5..6244dfada6ee 100644
--- a/baselines/fedpara/.gitignore
+++ b/baselines/fedpara/.gitignore
@@ -1,2 +1,4 @@
 outputs/
 multirun/
+client_states/
+data/
\ No newline at end of file
diff --git a/baselines/fedpara/fedpara/client.py b/baselines/fedpara/fedpara/client.py
index 92c0b0484458..f043a2cf4572 100644
--- a/baselines/fedpara/fedpara/client.py
+++ b/baselines/fedpara/fedpara/client.py
@@ -2,15 +2,14 @@
 
 from collections import OrderedDict
 from typing import Callable, Dict, List, Tuple, Optional
-import copy
+import copy, os
 
 import flwr as fl
 import torch
 from flwr.common import NDArrays, Scalar
 from hydra.utils import instantiate
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
-
-from fedpara.models import train
+from fedpara.models import train, test
 
 
 class FlowerClient(fl.client.NumPyClient):
@@ -34,7 +33,7 @@ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
         """Return the parameters of the current net."""
         return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
 
-    def _set_parameters(self, parameters: NDArrays) -> None:
+    def set_parameters(self, parameters: NDArrays) -> None:
         params_dict = zip(self.net.state_dict().keys(), parameters)
         state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
         self.net.load_state_dict(state_dict, strict=True)
@@ -43,7 +42,7 @@ def fit(
         self, parameters: NDArrays, config: Dict[str, Scalar]
     ) -> Tuple[NDArrays, int, Dict]:
         """Train the network on the training set."""
-        self._set_parameters(parameters)
+        self.set_parameters(parameters)
 
         train(
             self.net,
@@ -67,7 +66,7 @@ def __init__(
         cid: int,
         net: torch.nn.Module,
         train_loader: DataLoader,
-        test_dataset: List[DataLoader],
+        test_loader: DataLoader,
         device: str,
         num_epochs: int,
         state_path: str,
@@ -76,25 +75,38 @@ def __init__(
         self.cid = cid
         self.net = net
         self.train_loader = train_loader
-        self.test_dataset = test_dataset
+        self.test_loader = test_loader
         self.device = torch.device(device)
         self.num_epochs = num_epochs
         self.state_path = state_path
 
     def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
         """Return the parameters of the current net."""
-        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]
-
-    def _set_parameters(self, parameters: NDArrays) -> None:
-        params_dict = zip(self.net.state_dict().keys(), parameters)
-        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
-        self.net.load_state_dict(state_dict, strict=True)
-
+        return [val.cpu().detach().numpy() for _, val in self.net.get_per_param().items()]
+
+    def _set_parameters(self, parameters: NDArrays, first_round: bool = False) -> None:
+        if first_round:
+            params_dict = zip(self.net.state_dict().keys(), parameters)
+            state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+            self.net.load_state_dict(state_dict, strict=True)
+        else:
+            params_dict = zip(self.net.state_dict().keys(), parameters)
+            state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+            self.net.set_per_param(state_dict)
 
     def fit(
         self, parameters: NDArrays, config: Dict[str, Scalar]
     ) -> Tuple[NDArrays, int, Dict]:
         """Train the network on the training set."""
-        self._set_parameters(parameters)
+        if not os.path.isfile(self.state_path):
+            self._set_parameters(parameters, first_round=True)
+        else:
+            try:
+                self.net.load_state_dict(torch.load(self.state_path), strict=False)
+            except Exception:
+                print(f"Loading state dict from {self.state_path} failed; using server parameters")
+                self._set_parameters(parameters)
 
         print(f"Client {self.cid} Training...")
         train(
@@ -103,8 +115,11 @@ def fit(
             self.net,
             self.train_loader,
             self.device,
             epochs=self.num_epochs,
             hyperparams=config,
-            round=config["curr_round"],
+            epoch=config["curr_round"],
         )
+        if self.state_path is not None:
+            with open(self.state_path, "wb") as f:
+                torch.save(self.net.get_per_param(), f)
 
         return (
             self.get_parameters({}),
@@ -113,14 +128,21 @@ def fit(
         )
 
-    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[int, float, Dict]:
+    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[float, int, Dict]:
         """Evaluate the network on the test set."""
-        self._set_parameters(parameters)
-        print(f"Client {self.cid} Evaluating...")
-        return (
-            len(self.test_dataset[self.cid]),
-            train.test(self.net, self.test_dataset[self.cid], self.device),
-            {},
-        )
+        if not os.path.isfile(self.state_path):
+            self._set_parameters(parameters, first_round=True)
+        else:
+            try:
+                self.net.load_state_dict(torch.load(self.state_path), strict=False)
+            except Exception:
+                print(f"Loading state dict from {self.state_path} failed; using server parameters")
+                self._set_parameters(parameters)
+
+        print(f"Client {self.cid} Evaluating...")
+        self.net.to(self.device)
+        loss, accuracy = test(self.net, self.test_loader, device=self.device)
+        return loss, len(self.test_loader.dataset), {"accuracy": accuracy}
 
 
 def gen_client_fn(
     train_loaders: List[DataLoader],
@@ -135,15 +157,16 @@ def client_fn(cid: str) -> fl.client.NumPyClient:
         """Create a new FlowerClient for a given cid."""
         cid = int(cid)
-        if args['algorithm'] == "pfedpara" or args['algorithm'] == "fedper":
+        if args['algorithm'].lower() in ("pfedpara", "fedper"):
+            cl_path = f"{state_path}/client_{cid}.pth"
             return PFedParaClient(
                 cid=cid,
                 net=instantiate(model).to(args["device"]),
                 train_loader=train_loaders[cid],
-                test_dataset=copy.deepcopy(test_loader),
+                test_loader=copy.deepcopy(test_loader),
                 device=args["device"],
                 num_epochs=num_epochs,
-                state_path=state_path,
+                state_path=cl_path,
             )
         else:
             return FlowerClient(
diff --git a/baselines/fedpara/fedpara/conf/cifar10.yaml b/baselines/fedpara/fedpara/conf/cifar10.yaml
index b8b0c25c6fdb..1c05e0ec1ca9 100644
--- a/baselines/fedpara/fedpara/conf/cifar10.yaml
+++ b/baselines/fedpara/fedpara/conf/cifar10.yaml
@@ -29,8 +29,7 @@ model:
 hyperparams:
   eta_l: 0.1
   learning_decay: 0.992
-  momentum: 0.0
-  weight_decay: 0
+
 
 strategy:
   _target_: fedpara.strategy.FedPara
diff --git a/baselines/fedpara/fedpara/conf/cifar100.yaml b/baselines/fedpara/fedpara/conf/cifar100.yaml
index cb7eb73283c4..ee8402324519 100644
--- a/baselines/fedpara/fedpara/conf/cifar100.yaml
+++ b/baselines/fedpara/fedpara/conf/cifar100.yaml
@@ -29,8 +29,6 @@ model:
 hyperparams:
   eta_l: 0.1
   learning_decay: 0.992
-  momentum: 0.0
-  weight_decay: 0
 
 strategy:
   _target_: fedpara.strategy.FedPara
diff --git a/baselines/fedpara/fedpara/conf/mnist.yaml b/baselines/fedpara/fedpara/conf/mnist.yaml
index 8115cc3e11ab..2cb8f856bb08 100644
--- a/baselines/fedpara/fedpara/conf/mnist.yaml
+++ b/baselines/fedpara/fedpara/conf/mnist.yaml
@@ -6,8 +6,10 @@ num_rounds: 100
 clients_per_round: 10
 num_epochs: 5
 batch_size: 10
-state_pah: ./state/
-server_device: cuda
+state_path: ./client_states/
+client_device: cuda
+algorithm: pFedPara
+
 
 client_resources:
   num_cpus: 2
@@ -17,21 +19,23 @@ dataset_config:
   name: MNIST
   num_classes: 10
   shard_size: 300
+  data_seed: ${seed}
+
 
 model:
   _target_: fedpara.models.FC
   num_classes: ${dataset_config.num_classes}
-  weights: lowrank # lowrank or standard
+  param_type: lowrank # lowrank or standard
   activation: relu # relu or leaky_relu
   ratio: 0.5 # lowrank ratio
+  algorithm: ${algorithm}
 
 hyperparams:
   eta_l: 0.01
   learning_decay: 0.999
 
 strategy:
-  _target_: fedpara.strategy.pFedPara
-  algorithm: pFedPara
+  _target_: fedpara.strategy.FedAvg
   fraction_fit: 0.00001
   fraction_evaluate: 0.0
   min_evaluate_clients: ${clients_per_round}
diff --git a/baselines/fedpara/fedpara/dataset.py b/baselines/fedpara/fedpara/dataset.py
index 8bc41f89608c..b4737d597ac2 100644
--- a/baselines/fedpara/fedpara/dataset.py
+++ b/baselines/fedpara/fedpara/dataset.py
@@ -14,7 +14,7 @@ def load_datasets(
 ) -> Tuple[List[DataLoader], DataLoader]:
     """Load the dataset and return the dataloaders for the clients and the server."""
     print("Loading data...")
-    match config.name:
+    match config['name']:
         case "CIFAR10":
             Dataset = datasets.CIFAR10
         case "CIFAR100":
@@ -23,8 +23,8 @@ def load_datasets(
             Dataset = datasets.MNIST
         case _:
             raise NotImplementedError
-    data_directory = f"./data/{config.name.lower()}/"
-    match config.name:
+    data_directory = f"./data/{config['name'].lower()}/"
+    match config['name']:
         case "CIFAR10" | "CIFAR100":
             ds_path = f"{data_directory}train_{num_clients}_{config.alpha:.2f}.pkl"
             transform_train = transforms.Compose(
@@ -44,6 +44,12 @@ def load_datasets(
             try:
                 with open(ds_path, "rb") as file:
                     train_datasets = pickle.load(file)
+                dataset_train = Dataset(
+                    data_directory, train=True, download=False, transform=transform_train
+                )
+                dataset_test = Dataset(
+                    data_directory, train=False, download=False, transform=transform_test
+                )
             except FileNotFoundError:
                 dataset_train = Dataset(
                     data_directory, train=True, download=True, transform=transform_train
@@ -54,6 +60,9 @@ def load_datasets(
                 train_datasets, _ = noniid(dataset_train, num_clients, config.alpha)
                 pickle.dump(train_datasets, open(ds_path, "wb"))
                 train_datasets = train_datasets.values()
+                dataset_test = Dataset(
+                    data_directory, train=False, download=True, transform=transform_test
+                )
 
         case "MNIST":
             ds_path = f"{data_directory}train_{num_clients}.pkl"
@@ -68,16 +77,23 @@ def load_datasets(
             )
             try:
                 train_datasets = pickle.load(open(ds_path, "rb"))
+                dataset_train = Dataset(
+                    data_directory, train=True, download=False, transform=transform_train
+                )
+                dataset_test = Dataset(
+                    data_directory, train=False, download=False, transform=transform_test
+                )
             except FileNotFoundError:
                 dataset_train = Dataset(
                     data_directory, train=True, download=True, transform=transform_train
                 )
                 train_datasets = mnist_niid(dataset_train, num_clients, config.shard_size, config.data_seed)
                 pickle.dump(train_datasets, open(ds_path, "wb"))
+                dataset_test = Dataset(
+                    data_directory, train=False, download=True, transform=transform_test
+                )
+
 
-    dataset_test = Dataset(
-        data_directory, train=False, download=True, transform=transform_test
-    )
 
     test_loader = DataLoader(dataset_test, batch_size=batch_size, num_workers=2)
     train_loaders = [
         DataLoader(
diff --git a/baselines/fedpara/fedpara/main.py b/baselines/fedpara/fedpara/main.py
index dca22b7ef2da..d94168129996 100644
--- a/baselines/fedpara/fedpara/main.py
+++ b/baselines/fedpara/fedpara/main.py
@@ -5,13 +5,12 @@
 from hydra.core.hydra_config import HydraConfig
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-
 from fedpara import client, server, utils
 from fedpara.dataset import load_datasets
-from fedpara.utils import get_parameters, seed_everything
+from fedpara.utils import get_parameters, save_results_as_pickle, seed_everything, set_client_state_save_path
 
 
-@hydra.main(config_path="conf", config_name="cifar10", version_base=None)
+@hydra.main(config_path="conf", config_name="mnist", version_base=None)
 def main(cfg: DictConfig) -> None:
     """Run the baseline.
 
@@ -24,6 +23,11 @@ def main(cfg: DictConfig) -> None:
     print(OmegaConf.to_yaml(cfg))
     seed_everything(cfg.seed)
     OmegaConf.to_container(cfg, resolve=True)
+    if "state_path" in cfg:
+        state_path = set_client_state_save_path(cfg.state_path)
+    else:
+        state_path = None
+
     # 2. Prepare dataset
     train_loaders, test_loader = load_datasets(
         config=cfg.dataset_config,
@@ -33,31 +35,15 @@ def main(cfg: DictConfig) -> None:
 
     # 3. Define clients
     # In this scheme the responsibility of choosing the client is on the client manager
-    if cfg.strategy.min_evaluate_clients:
-        client_fn = client.gen_client_fn(
-            train_loaders=train_loaders,
-            test_loader=test_loader,
-            model=cfg.model,
-            num_epochs=cfg.num_epochs,
-            args={"device": cfg.client_device, "algorithm": cfg.strategy.algorithm},
-            state_path=cfg.state_path,
-        )
-    else:
-        client_fn = client.gen_client_fn(
-            train_loaders=train_loaders,
-            model=cfg.model,
-            num_epochs=cfg.num_epochs,
-            args={"device": cfg.client_device, "algorithm": cfg.strategy.algorithm},
-        )
-    if not cfg.strategy.min_evaluate_clients:
-        evaluate_fn = server.gen_evaluate_fn(
-            num_clients=cfg.num_clients,
-            test_loader=test_loader,
-            model=cfg.model,
-            device=cfg.server_device,
-            state_path=cfg.state_path,
-        )
+    client_fn = client.gen_client_fn(
+        train_loaders=train_loaders,
+        test_loader=test_loader,
+        model=cfg.model,
+        num_epochs=cfg.num_epochs,
+        args={"device": cfg.client_device, "algorithm": cfg.algorithm},
+        state_path=state_path,
+    )
 
     def get_on_fit_config():
         def fit_config_fn(server_round: int):
@@ -76,7 +62,15 @@ def fit_config_fn(server_round: int):
             on_fit_config_fn=get_on_fit_config(),
             initial_parameters=fl.common.ndarrays_to_parameters(get_parameters(net_glob)),
         )
+    else:
+        evaluate_fn = server.gen_evaluate_fn(
+            num_clients=cfg.num_clients,
+            test_loader=test_loader,
+            model=cfg.model,
+            device=cfg.server_device,
+            state_path=cfg.state_path,
+        )
         strategy = instantiate(
             cfg.strategy,
             evaluate_fn=evaluate_fn,
@@ -98,7 +92,8 @@ def fit_config_fn(server_round: int):
             "_memory": 30 * 1024 * 1024 * 1024,
         },
     )
-
+    save_results_as_pickle(history, file_path=HydraConfig.get().runtime.output_dir)
+
     # 6. Save results
     save_path = HydraConfig.get().runtime.output_dir
     file_suffix = "_".join(
diff --git a/baselines/fedpara/fedpara/models.py b/baselines/fedpara/fedpara/models.py
index 233f76f935ec..a4861fc0e729 100644
--- a/baselines/fedpara/fedpara/models.py
+++ b/baselines/fedpara/fedpara/models.py
@@ -9,6 +9,7 @@
 from torch import nn
 from torch.nn import init
 from torch.utils.data import DataLoader
+
 
 class LowRankNN(nn.Module):
     def __init__(self,input, output, rank,activation: str = "relu",) -> None:
         super(LowRankNN, self).__init__()
@@ -26,8 +27,8 @@ def __init__(self,input, output, rank,activation: str = "relu",) -> None:
         init.kaiming_normal_(self.X, mode="fan_out", nonlinearity=activation)
         init.kaiming_normal_(self.Y, mode="fan_out", nonlinearity=activation)
 
-    def forward(self,x):
-        out = torch.einsum("xr,yr->xy", self.X, self.Y)
+    def forward(self):
+        out = torch.einsum("yr,xr->yx", self.Y, self.X)
         return out
 
 class Linear(nn.Module):
@@ -75,9 +76,10 @@ def forward(self,x):
 
 
 class FC(nn.Module):
-    def __init__(self, input_size=28**2, hidden_size=256, num_classes=10, ratio=0.1, param_type="standard",activation: str = "relu",):
+    def __init__(self, input_size=28**2, hidden_size=256, num_classes=10, ratio=0.5, param_type="lowrank", activation: str = "relu", algorithm="pfedpara"):
         super(FC, self).__init__()
-
+        self.input_size = input_size
+        self.method = algorithm.lower()
         if param_type == "standard":
             self.fc1 = nn.Linear(input_size, hidden_size)
             self.relu = nn.ReLU()
@@ -90,8 +92,8 @@ def __init__(self, input_size=28**2, hidden_size=256, num_classes=10, ratio=0.5,
             self.softmax = nn.Softmax(dim=1)
         else:
             raise ValueError("param_type must be either standard or lowrank")
-    @property
-    def per_param(self):
+
+    def get_per_param(self):
         """
         Return the personalized parameters of the model
         """
@@ -103,8 +105,8 @@ def get_per_param(self):
         else:
             raise ValueError("method must be either pfedpara, fedper")
         return params
-    @property
-    def load_per_param(self,state_dict):
+
+    def set_per_param(self, state_dict):
         """
         Load the personalized parameters of the model
         """
@@ -112,7 +114,7 @@
             self.fc1.w1.X = state_dict["fc1.X"]
             self.fc1.w1.Y = state_dict["fc1.Y"]
             self.fc2.w1.X = state_dict["fc2.X"]
-            self.fc2.w1.Y = state_dict["fcstate_path2.Y"]
+            self.fc2.w1.Y = state_dict["fc2.Y"]
         elif self.method == "fedper":
             self.fc2.w1 = state_dict["fc2.w1"]
             self.fc2.w2 = state_dict["fc2.w2"]
@@ -135,6 +137,7 @@ def model_size(self):
         return total_trainable_params, size_all_mb
 
     def forward(self,x):
+        x = x.view(-1, self.input_size)
         out = self.fc1(x)
         out = self.relu(out)
         out = self.fc2(out)
@@ -456,8 +459,8 @@ def train(  # pylint: disable=too-many-arguments
     optimizer = torch.optim.SGD(
         net.parameters(),
         lr=lr,
-        momentum=hyperparams["momentum"],
-        weight_decay=hyperparams["weight_decay"],
+        momentum=0,
+        weight_decay=0,
     )
     net.train()
     for _ in range(epochs):
diff --git a/baselines/fedpara/fedpara/server.py b/baselines/fedpara/fedpara/server.py
index 275076b900d3..f73618ab8569 100644
--- a/baselines/fedpara/fedpara/server.py
+++ b/baselines/fedpara/fedpara/server.py
@@ -2,13 +2,11 @@
 
 from collections import OrderedDict
 from typing import Callable, Dict, Optional, Tuple
-
 import torch
 from flwr.common import NDArrays, Scalar
 from hydra.utils import instantiate
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader
-
 from fedpara.models import test
 
 
@@ -18,7 +16,6 @@ def get_on_fit_config(hypearparams: Dict):
     def fit_config_fn(server_round: int):
         hypearparams["curr_round"] = server_round
         return hypearparams
-
     return fit_config_fn
diff --git a/baselines/fedpara/fedpara/test.ipynb b/baselines/fedpara/fedpara/test.ipynb
deleted file mode 100644
index a25ae10e7d6b..000000000000
--- a/baselines/fedpara/fedpara/test.ipynb
+++ /dev/null
@@ -1,85 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torchvision.datasets\n",
-    "import torchvision.transforms as transforms\n",
-    "import numpy as np\n",
-    "from collections import Counter\n",
-    "import logging"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data = torchvision.datasets.MNIST\n",
-    "transform= transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n",
-    "trainset = data(root='./data', train = True, download = True, transform = transform)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 89,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def mnist_niid(dataset: Dataset, num_clients: int, silo_size: int, seed: int) -> list:\n",
-    "    indices = trainset.targets[np.argsort(trainset.targets)].numpy()\n",
-    "    logging.debug(Counter(trainset.targets[indices].numpy()))\n",
-    "    silos = np.array_split(indices, len(trainset) // 300)# randomly assign silos to clients\n",
-    "    np.random.seed(seed+17)\n",
-    "    np.random.shuffle(silos)\n",
-    "    clients = np.array(np.array_split(silos, 100)).reshape(100, -1)\n",
-    "    logging.debug(clients.shape)\n",
-    "    logging.debug(Counter([len(Counter(trainset.targets[client].numpy())) for client in clients]))\n",
-    "    return clients"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 98,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "Counter({2: 82, 1: 11, 3: 7})"
-      ]
-     },
-     "execution_count": 98,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "flower",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.13"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/baselines/fedpara/fedpara/utils.py b/baselines/fedpara/fedpara/utils.py
index cbb6a1a75b66..47b24b2a50f6 100644
--- a/baselines/fedpara/fedpara/utils.py
+++ b/baselines/fedpara/fedpara/utils.py
@@ -2,7 +2,8 @@
 
 import random
 from pathlib import Path
-
+from secrets import token_hex
+from typing import Optional, Union
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
@@ -10,7 +10,7 @@
 from flwr.server import History
 from omegaconf import DictConfig
 from torch.nn import Module
-
+import time, os, pickle
 
 def plot_metric_from_history(
     hist: History,
@@ -70,3 +70,67 @@ def seed_everything(seed):
 def get_parameters(net: Module) -> NDArrays:
     """Get the parameters of the network."""
     return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+
+def save_results_as_pickle(
+    history: History,
+    file_path: Union[str, Path],
+    default_filename: Optional[str] = "results.pkl",
+) -> None:
+    """Save results from simulation to pickle.
+
+    Parameters
+    ----------
+    history: History
+        History returned by start_simulation.
+    file_path: Union[str, Path]
+        Path to file to create and store the history.
+        If path is a directory, the default_filename will be used.
+        If the path doesn't exist, it will be created. If the file
+        exists, a randomly generated suffix will be added to the file
+        name to avoid overwriting results.
+    default_filename: Optional[str]
+        File used by default if file_path points to a directory instead
+        of a file. Default: "results.pkl"
+    """
+    path = Path(file_path)
+
+    # ensure path exists
+    path.mkdir(exist_ok=True, parents=True)
+
+    def _add_random_suffix(path_: Path):
+        """Add a random suffix to the file name."""
+        print(f"File `{path_}` exists! ")
+        suffix = token_hex(4)
+        print(f"New results to be saved with suffix: {suffix}")
+        return path_.parent / (path_.stem + "_" + suffix + ".pkl")
+
+    def _complete_path_with_default_name(path_: Path):
+        """Append the default file name to the path."""
+        print("Using default filename")
+        if default_filename is None:
+            return path_
+        return path_ / default_filename
+
+    if path.is_dir():
+        path = _complete_path_with_default_name(path)
+
+    if path.is_file():
+        path = _add_random_suffix(path)
+
+    print(f"Results will be saved into: {path}")
+    data = {"history": history}
+    # save results to pickle
+    with open(str(path), "wb") as handle:
+        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+def set_client_state_save_path(path: str) -> str:
+    """Set the client state save path."""
+    client_state_save_path = time.strftime("%Y-%m-%d")
+    client_state_sub_path = time.strftime("%H-%M-%S")
+    client_state_save_path = (
+        f"{path}{client_state_save_path}/{client_state_sub_path}"
+    )
+    if not os.path.exists(client_state_save_path):
+        os.makedirs(client_state_save_path)
+    return client_state_save_path
diff --git a/baselines/fedper/fedper/server.py b/baselines/fedper/fedper/server.py
index 93616f50f45a..50a4f8c5d8a8 100644
--- a/baselines/fedper/fedper/server.py
+++ b/baselines/fedper/fedper/server.py
@@ -1,4 +1,5 @@
 """Server strategies pipelines for FedPer."""
+
 from flwr.server.strategy.fedavg import FedAvg
 from fedper.strategy import (
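
Note: the client-side state round trip this patch introduces in client.py (persist `get_per_param()` after fit, reload with `strict=False` on later rounds) can be exercised standalone. Below is a minimal sketch, not part of the patch, assuming the `FC` defaults from models.py above (`param_type="lowrank"`, `algorithm="pfedpara"`); the state-file path is hypothetical, mirroring client.py's f"{state_path}/client_{cid}.pth".

    import os
    import torch
    from fedpara.models import FC

    state_file = "client_states/client_0.pth"  # hypothetical path

    net = FC()  # lowrank / pfedpara defaults per this patch

    if os.path.isfile(state_file):
        # Later rounds: restore the persisted personalized tensors.
        # strict=False because get_per_param() saves only a subset of
        # the full state dict, exactly as PFedParaClient.fit does.
        net.load_state_dict(torch.load(state_file), strict=False)

    # ... local training would run here ...

    # Persist the personalized parameters for the next round.
    os.makedirs(os.path.dirname(state_file), exist_ok=True)
    with open(state_file, "wb") as f:
        torch.save(net.get_per_param(), f)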