From 5dbb153872f277746df0196fe1d268e9e30d523b Mon Sep 17 00:00:00 2001 From: edogab33 Date: Wed, 20 Mar 2024 19:28:14 +0100 Subject: [PATCH] Fix lint --- baselines/flanders/flanders/dataset.py | 1 + baselines/flanders/flanders/main.py | 22 ++++++++++++---------- baselines/flanders/flanders/models.py | 2 +- baselines/flanders/flanders/server.py | 2 +- baselines/flanders/flanders/strategy.py | 4 ++-- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/baselines/flanders/flanders/dataset.py b/baselines/flanders/flanders/dataset.py index 28b43a4021f2..c12859147ce4 100644 --- a/baselines/flanders/flanders/dataset.py +++ b/baselines/flanders/flanders/dataset.py @@ -42,6 +42,7 @@ def get_dataset(path_to_data: Path, cid: str, partition: str, transform=None): return TorchVisionFL(path_to_data, transform=transform) +# pylint: disable=too-many-arguments, too-many-locals def get_dataloader( path_to_data: str, cid: str, diff --git a/baselines/flanders/flanders/main.py b/baselines/flanders/flanders/main.py index 90f130b6ac0b..278f847ffbdd 100644 --- a/baselines/flanders/flanders/main.py +++ b/baselines/flanders/flanders/main.py @@ -35,7 +35,9 @@ def main(cfg: DictConfig) -> None: # 0. Set random seed seed = cfg.seed np.random.seed(seed) - np.random.set_state(np.random.RandomState(seed).get_state()) + np.random.set_state( + np.random.RandomState(seed).get_state() + ) # pylint: disable=no-member random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -104,7 +106,7 @@ def main(cfg: DictConfig) -> None: # 3. Define your clients # pylint: disable=no-else-return - def client_fn(cid: str, pool_size: int = 10, dataset_name: str = dataset_name): + def client_fn(cid: str, dataset_name: str = dataset_name): client = clients[dataset_name][0] if dataset_name in ["mnist", "fmnist", "cifar", "cifar100"]: return client(cid, fed_dir) @@ -242,10 +244,10 @@ def client_fn(cid: str, pool_size: int = 10, dataset_name: str = dataset_name): rounds, test_loss = zip(*history.losses_centralized) _, test_accuracy = zip(*history.metrics_centralized["accuracy"]) _, test_auc = zip(*history.metrics_centralized["auc"]) - _, tp = zip(*history.metrics_centralized["TP"]) - _, tn = zip(*history.metrics_centralized["TN"]) - _, fp = zip(*history.metrics_centralized["FP"]) - _, fn = zip(*history.metrics_centralized["FN"]) + _, truep = zip(*history.metrics_centralized["TP"]) + _, truen = zip(*history.metrics_centralized["TN"]) + _, falsep = zip(*history.metrics_centralized["FP"]) + _, falsen = zip(*history.metrics_centralized["FN"]) path_to_save = [os.path.join(save_path, "results.csv"), "outputs/all_results.csv"] @@ -256,10 +258,10 @@ def client_fn(cid: str, pool_size: int = 10, dataset_name: str = dataset_name): "loss": test_loss, "accuracy": test_accuracy, "auc": test_auc, - "TP": tp, - "TN": tn, - "FP": fp, - "FN": fn, + "TP": truep, + "TN": truen, + "FP": falsep, + "FN": falsen, "attack_fn": [attack_fn for _ in range(len(rounds))], "dataset_name": [dataset_name for _ in range(len(rounds))], "num_malicious": [num_malicious for _ in range(len(rounds))], diff --git a/baselines/flanders/flanders/models.py b/baselines/flanders/flanders/models.py index a129442a1954..d8d28da52861 100644 --- a/baselines/flanders/flanders/models.py +++ b/baselines/flanders/flanders/models.py @@ -21,7 +21,7 @@ class MnistNet(nn.Module): """Neural network for MNIST classification.""" def __init__(self): - super(MnistNet, self).__init__() + super().__init__() self.fc1 = nn.Linear(28 * 28, 128) self.fc2 = nn.Linear(128, 10) diff --git a/baselines/flanders/flanders/server.py b/baselines/flanders/flanders/server.py index 14eb082faa4b..df2e53b867ac 100644 --- a/baselines/flanders/flanders/server.py +++ b/baselines/flanders/flanders/server.py @@ -286,7 +286,7 @@ def fit_round( # the server simulates an attacker that controls a fraction of the clients if self.attack_fn is not None and server_round > self.warmup_rounds: log(INFO, "Applying attack function") - results, others = self.attack_fn( + results, _ = self.attack_fn( ordered_results, clients_state, omniscent=self.omniscent, diff --git a/baselines/flanders/flanders/strategy.py b/baselines/flanders/flanders/strategy.py index 17908e60185d..389cbba41793 100644 --- a/baselines/flanders/flanders/strategy.py +++ b/baselines/flanders/flanders/strategy.py @@ -295,8 +295,8 @@ def aggregate_fit( parameters_aggregated = ndarrays_to_parameters( self.aggregate_fn(weights_results, **aggregate_parameters) ) - except ValueError as e: - log(WARNING, f"Error in aggregate_fn: {e}") + except ValueError as err: + log(WARNING, "Error in aggregate_fn: %s", err) parameters_aggregated = ndarrays_to_parameters( aggregate(weights_results) )