diff --git a/baselines/flanders/.gitignore b/baselines/flanders/.gitignore new file mode 100644 index 000000000000..4187d73689f0 --- /dev/null +++ b/baselines/flanders/.gitignore @@ -0,0 +1,9 @@ +outputs/* +clients_params/* +flanders/datasets_files/* +*.log +flanders/__pycache__ +MNIST +.DS_Store +*/__pycache__ +multirun \ No newline at end of file diff --git a/baselines/flanders/LICENSE b/baselines/flanders/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/baselines/flanders/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/baselines/flanders/README.md b/baselines/flanders/README.md new file mode 100644 index 000000000000..f5ab6a02d6f3 --- /dev/null +++ b/baselines/flanders/README.md @@ -0,0 +1,157 @@ +--- +title: Protecting Federated Learning from Extreme Model Poisoning Attacks via Multidimensional Time Series Anomaly Detection +url: https://arxiv.org/abs/2303.16668 +labels: [robustness, model poisoning, anomaly detection, autoregressive model, regression, classification] +dataset: [MNIST, FashionMNIST] +--- + +**Paper:** [arxiv.org/abs/2303.16668](https://arxiv.org/abs/2303.16668) + +**Authors:** Edoardo Gabrielli, Gabriele Tolomei, Dimitri Belli, Vittorio Miori + +**Abstract:** Current defense mechanisms against model poisoning attacks in federated learning (FL) systems have proven effective up to a certain threshold of malicious clients. In this work, we introduce FLANDERS, a novel pre-aggregation filter for FL resilient to large-scale model poisoning attacks, i.e., when malicious clients far exceed legitimate participants. FLANDERS treats the sequence of local models sent by clients in each FL round as a matrix-valued time series. Then, it identifies malicious client updates as outliers in this time series by comparing actual observations with estimates generated by a matrix autoregressive forecasting model maintained by the server. 
Experiments conducted in several non-iid FL setups show that FLANDERS significantly improves robustness across a wide spectrum of attacks when paired with standard and robust existing aggregation methods. + +## About this baseline + +**What’s implemented:** The code in this directory replicates the results of FLANDERS+\[baseline\] on MNIST and Fashion-MNIST under all attack settings: Gaussian, LIE, OPT, and AGR-MM; with $r=[0.2,0.6,0.8]$ (i.e., the fraction of malicious clients), specifically Tables 1, 3, 10, 11, 15, 17, 19, 20 and Figure 3. + +**Datasets:** MNIST, FMNIST + +**Hardware Setup:** AMD Ryzen 9, 64 GB RAM, and an NVIDIA 4090 GPU with 24 GB VRAM. + +**Estimated time to run:** You can expect to run experiments on the given setup in 2m with *MNIST* and 3m with *Fashion-MNIST*, without attacks. With an Apple M2 Pro, 16 GB RAM, each experiment with 10 clients for MNIST runs in about 24 minutes. Note that experiments with OPT (fang) and AGR-MM (minmax) can be up to 5x slower. + +**Contributors:** Edoardo Gabrielli, Sapienza University of Rome ([GitHub](https://github.com/edogab33), [Scholar](https://scholar.google.com/citations?user=b3bePdYAAAAJ)) + + +## Experimental Setup + +Please check out Appendices F and G of the paper for a comprehensive overview of the hyperparameter setup; a summary is given below. + +**Task:** Image classification + +**Models:** + +MNIST (multiclass classification, fully connected, feed forward NN): +- Multilayer Perceptron (MLP) +- minimizing multiclass cross-entropy loss using Adam optimizer +- input: 784 +- hidden layer 1: 128 +- hidden layer 2: 256 + +Fashion-MNIST (multiclass classification, fully connected, feed forward NN): +- Multilayer Perceptron (MLP) +- minimizing multiclass cross-entropy loss using Adam optimizer +- input: 784 +- hidden layer 1: 256 +- hidden layer 2: 128 +- hidden layer 3: 64 + +**Dataset:** Every dataset is partitioned into two disjoint sets: 80% for training and 20% for testing. 
The training set is distributed across all clients (100) by using the Dirichlet distribution with $\alpha=0.5$, simulating a highly non-i.i.d. scenario, while the testing set is uniform and held by the server to evaluate the global model. + +| Description | Default Value | +| ----------- | ----- | +| Partitions | 100 | +| Evaluation | centralized | +| Training set | 80% | +| Testing set | 20% | +| Distribution | Dirichlet | +| $\alpha$ | 0.5 | + +**Training Hyperparameters:** + +| Dataset | # of clients | Clients per round | # of rounds | Batch size | Learning rate | Optimizer | Dropout | Alpha | Beta | # of clients to keep | Sampling | +| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | +| MNIST | 100 | 100 | 50 | 32 | $10^{-3}$ | Adam | 0.2 | 0.0 | 0.0 | $m - b$ | 500 | +| FMNIST | 100 | 100 | 50 | 32 | $10^{-3}$ | Adam | 0.2 | 0.0 | 0.0 | $m - b$ | 500 | + +Where $m$ is the number of clients participating in the $n$-th round and $b$ is the number of malicious clients. The variable $sampling$ identifies how many parameters MAR analyzes. + + +## Environment Setup + +```bash +# Use a version of Python >=3.9 and <3.12.0. +pyenv local 3.10.12 +poetry env use 3.10.12 + +# Install everything from the toml +poetry install + +# Activate the env +poetry shell +``` + + +## Running the Experiments +Ensure that the environment is properly set up, then run: + +```bash +python -m flanders.main +``` + +This executes a single experiment with the default values in `conf/base.yaml`. 
+ +To run custom experiments, you can override the default values like this: + +```bash +python -m flanders.main dataset=mnist server.attack_fn=lie server.num_malicious=1 +``` + +To run multiple custom experiments: + +```bash +python -m flanders.main --multirun dataset=mnist,fmnist server.attack_fn=gaussian,lie,fang,minmax server.num_malicious=0,1,2,3,4,5 +``` + +## Expected Results + +To run all the experiments of the paper (for MNIST and Fashion-MNIST), I've set up a script: + +```bash +sh run.sh +``` + +This code will produce the output in the file `outputs/all_results.csv`. To generate the plots and tables displayed below, you can use the notebook in the `plotting/` directory. + + +### Accuracy over multiple rounds +**(left) MNIST, FLANDERS+FedAvg with 80% of malicious clients (b = 80); (right) Vanilla FedAvg in the same setting:** + +![acc_over_rounds](_static/screenshot-8.png) + +### Precision and Recall of FLANDERS + +**b = 20:** + +![alt text](_static/screenshot-4.png) +--- + +**b = 60:** + +![alt text](_static/screenshot-5.png) +--- +**b = 80:** + +![alt text](_static/screenshot-6.png) + + +### Accuracy w.r.t. 
number of attackers: +**b = 0:** + +![alt text](_static/screenshot.png) + +--- +**b = 20:** + +![alt text](_static/screenshot-1.png) + +--- +**b = 60:** + +![alt text](_static/screenshot-2.png) + +--- +**b = 80:** + +![alt text](_static/screenshot-3.png) diff --git a/baselines/flanders/_static/screenshot-1.png b/baselines/flanders/_static/screenshot-1.png new file mode 100644 index 000000000000..f9c14a7e72f2 Binary files /dev/null and b/baselines/flanders/_static/screenshot-1.png differ diff --git a/baselines/flanders/_static/screenshot-2.png b/baselines/flanders/_static/screenshot-2.png new file mode 100644 index 000000000000..7aacd2ba5778 Binary files /dev/null and b/baselines/flanders/_static/screenshot-2.png differ diff --git a/baselines/flanders/_static/screenshot-3.png b/baselines/flanders/_static/screenshot-3.png new file mode 100644 index 000000000000..978ed4902bf5 Binary files /dev/null and b/baselines/flanders/_static/screenshot-3.png differ diff --git a/baselines/flanders/_static/screenshot-4.png b/baselines/flanders/_static/screenshot-4.png new file mode 100644 index 000000000000..5a24c47ff513 Binary files /dev/null and b/baselines/flanders/_static/screenshot-4.png differ diff --git a/baselines/flanders/_static/screenshot-5.png b/baselines/flanders/_static/screenshot-5.png new file mode 100644 index 000000000000..e0defab01d22 Binary files /dev/null and b/baselines/flanders/_static/screenshot-5.png differ diff --git a/baselines/flanders/_static/screenshot-6.png b/baselines/flanders/_static/screenshot-6.png new file mode 100644 index 000000000000..bfb3120fef7b Binary files /dev/null and b/baselines/flanders/_static/screenshot-6.png differ diff --git a/baselines/flanders/_static/screenshot-8.png b/baselines/flanders/_static/screenshot-8.png new file mode 100644 index 000000000000..cda98c21d034 Binary files /dev/null and b/baselines/flanders/_static/screenshot-8.png differ diff --git a/baselines/flanders/_static/screenshot.png 
b/baselines/flanders/_static/screenshot.png new file mode 100644 index 000000000000..537ebb66c123 Binary files /dev/null and b/baselines/flanders/_static/screenshot.png differ diff --git a/baselines/flanders/flanders/__init__.py b/baselines/flanders/flanders/__init__.py new file mode 100644 index 000000000000..eb3edd489459 --- /dev/null +++ b/baselines/flanders/flanders/__init__.py @@ -0,0 +1 @@ +"""FLANDERS package.""" diff --git a/baselines/flanders/flanders/attacks.py b/baselines/flanders/flanders/attacks.py new file mode 100644 index 000000000000..9b1acd9ad639 --- /dev/null +++ b/baselines/flanders/flanders/attacks.py @@ -0,0 +1,493 @@ +"""Implementation of attacks used in the paper.""" + +import math +from typing import Dict, List, Tuple + +import numpy as np +from flwr.common import FitRes, ndarrays_to_parameters, parameters_to_ndarrays +from flwr.server.client_proxy import ClientProxy +from scipy.stats import norm + + +# pylint: disable=unused-argument +def no_attack( + ordered_results: List[Tuple[ClientProxy, FitRes]], states: Dict[str, bool], **kwargs +): + """No attack.""" + return ordered_results, {} + + +def gaussian_attack(ordered_results, states, **kwargs): + """Apply Gaussian attack on parameters. + + Parameters + ---------- + ordered_results + List of tuples (client_proxy, fit_result) ordered by client id. + states + Dictionary of client ids and their states (True if malicious, False otherwise). + magnitude + Magnitude of the attack. + dataset_name + Name of the dataset. + + Returns + ------- + results + List of tuples (client_proxy, fit_result) ordered by client id. 
+ """ + magnitude = kwargs.get("magnitude", 0.0) + dataset_name = kwargs.get("dataset_name", "no name") + results = ordered_results.copy() + + def perturbate(vect): + return vect + np.random.normal(loc=0, scale=magnitude, size=vect.size) + + for proxy, fitres in ordered_results: + if states[fitres.metrics["cid"]]: + params = parameters_to_ndarrays(fitres.parameters) + if dataset_name == "income": + new_params = [perturbate(layer) for layer in params] + else: + new_params = [] + for par in params: + # if par is an array of one element, it is a scalar + if par.size == 1: + new_params.append(perturbate(par)) + else: + new_params.append(np.apply_along_axis(perturbate, 0, par)) + fitres.parameters = ndarrays_to_parameters(new_params) + results[int(fitres.metrics["cid"])] = (proxy, fitres) + return results, {} + + +# pylint: disable=too-many-locals, unused-argument +def lie_attack( + ordered_results, + states, + omniscent=True, + **kwargs, +): + """Apply Omniscent LIE attack, Baruch et al. (2019) on parameters. + + Parameters + ---------- + ordered_results + List of tuples (client_proxy, fit_result) ordered by client id. + states + Dictionary of client ids and their states (True if malicious, False otherwise). + omniscent + Whether the attacker knows the local models of all clients or not. + + Returns + ------- + results + List of tuples (client_proxy, fit_result) ordered by client id. 
+ """ + results = ordered_results.copy() + params = [parameters_to_ndarrays(fitres.parameters) for _, fitres in results] + grads_mean = [np.mean(layer, axis=0) for layer in zip(*params)] + grads_stdev = [np.std(layer, axis=0) ** 0.5 for layer in zip(*params)] + + if not omniscent: + # if not omniscent, the attacker doesn't know the + # local models of all clients, but only of the corrupted ones + params = [ + params[i] + for i in range(len(params)) + if states[results[i][1].metrics["cid"]] + ] + + num_clients = len(ordered_results) + num_malicious = sum(val is True for val in states.values()) + + # pylint: disable=c-extension-no-member + num_supporters = math.floor((num_clients / 2) + 1) - num_malicious + + z_max = norm.cdf( + (num_clients - num_malicious - num_supporters) / (num_clients - num_malicious) + ) + + for proxy, fitres in ordered_results: + if states[fitres.metrics["cid"]]: + mul_std = [layer * z_max for layer in grads_stdev] + new_params = [grads_mean[i] - mul_std[i] for i in range(len(grads_mean))] + fitres.parameters = ndarrays_to_parameters(new_params) + results[int(fitres.metrics["cid"])] = (proxy, fitres) + return results, {} + + +def fang_attack( + ordered_results, + states, + omniscent=True, + **kwargs, +): + """Apply Local Model Poisoning Attacks. + + (Fang et al. (2020)) + Specifically designed for Krum, but they claim it works for other + aggregation functions as well. + Full-knowledge version (attackers knows the local models of all clients). + + Parameters + ---------- + ordered_results + List of tuples (client_proxy, fit_result) ordered by client id. + states + Dictionary of client ids and their states (True if malicious, False + otherwise). + omniscent + Whether the attacker knows the local models of all clients or not. + num_layers + Number of layers. + w_re + The received global model. + old_lambda + The lambda from the previous round. + threshold + The threshold for lambda. 
+ malicious_selected + Whether the attacker was selected as malicious in the previous round. + + Returns + ------- + results + List of tuples (client_proxy, fit_result) ordered by client id. + """ + num_layers = kwargs.get("num_layers", 2) + w_re = kwargs.get("w_re", None) # the received global model + threshold = kwargs.get("threshold", 1e-5) + + num_clients = len(ordered_results) + num_corrupted = sum(val is True for val in states.values()) + # there can't be an attack with less than 2 malicious clients + # to avoid division by 0 + num_corrupted = max(num_corrupted, 2) + + if not omniscent: + # if not omniscent, the attacker doesn't know the + # local models of all clients, but only of the corrupted ones + ordered_results = [ + ordered_results[i] + for i in range(len(ordered_results)) + if states[ordered_results[i][1].metrics["cid"]] + ] + + # Initialize lambda + benign = [ + (parameters_to_ndarrays(fitres.parameters), fitres.num_examples) + for _, fitres in ordered_results + if states[fitres.metrics["cid"]] is False + ] + all_params = [ + (parameters_to_ndarrays(fitres.parameters), fitres.num_examples) + for _, fitres in ordered_results + ] + # Compute the smallest distance that Krum would choose + _, _, _, distances = _krum(all_params, num_corrupted, 1) + + idx_benign = [int(cid) for cid in states.keys() if states[cid] is False] + + min_dist = np.min(np.array(distances)[idx_benign]) / ( + ((num_clients - 2) * (num_corrupted - 1)) * np.sqrt(num_layers) + ) + + # Compute max distance from w_re + dist_wre = np.zeros((len(benign))) + for i in range(len(benign)): + dist = [benign[i][0][j] - w_re[j] for j in range(num_layers)] + norm_sums = 0 + for k in dist: + norm_sums += np.linalg.norm(k) + dist_wre[i] = norm_sums**2 + max_dist = np.max(dist_wre) / np.sqrt(num_layers) + lamda = min( + min_dist + max_dist, 999 + ) # lambda (capped to 999 to avoid numerical problems in specific settings) + + malicious_selected, corrupted_params = _fang_corrupt_and_select( + 
all_params, w_re, states, num_corrupted, lamda + ) + while lamda > threshold and malicious_selected is False: + lamda = lamda * 0.5 + malicious_selected, corrupted_params = _fang_corrupt_and_select( + all_params, w_re, states, num_corrupted, lamda + ) + + # Set corrupted clients' updates to w_1 + results = [ + ( + ( + proxy, + FitRes( + fitres.status, + parameters=ndarrays_to_parameters(corrupted_params), + num_examples=fitres.num_examples, + metrics=fitres.metrics, + ), + ) + if states[fitres.metrics["cid"]] + else (proxy, fitres) + ) + for proxy, fitres in ordered_results + ] + + return results, {} + + +def minmax_attack( + ordered_results, + states, + omniscent=True, + **kwargs, +): + """Apply Min-Max agnostic attack. + + Full-knowledge, perturbation function chosen according to our experimental + results. + From: + "Manipulating the Byzantine: Optimizing Model Poisoning Attacks and + Defenses for Federated Learning" (Shejwalkar et al., 2021) + + Parameters + ---------- + ordered_results + List of tuples (client_proxy, fit_result) ordered by client id. + states + Dictionary of client ids and their states (True if malicious, False + otherwise). + omniscent + Whether the attacker knows the local models of all clients or not. + threshold + Threshold for lambda. + lambda_init + Initial value for lambda. + + Returns + ------- + results + List of tuples (client_proxy, fit_result) ordered by client id. 
+ """ + dataset_name = kwargs.get("dataset_name", None) + threshold = kwargs.get("threshold", 1e-5) + lambda_init = kwargs.get("lambda", 5.0) + malicious_num = kwargs.get("malicious_num", 0) + + results = ordered_results.copy() + params = [parameters_to_ndarrays(fitres.parameters) for _, fitres in results] + params_avg = [np.mean(param, axis=0) for param in zip(*params)] + + if not omniscent: + # if not omniscent, the attacker doesn't know the + # local models of all clients, but only of the corrupted ones + results = [ + results[i] + for i in range(len(results)) + if states[results[i][1].metrics["cid"]] + ] + + # Decide what perturbation to use according to the + # results presented in the paper. + if dataset_name == "mnist": + # Apply std perturbation + # In the paper authors state that sign function is the best + # but in my experience std perturbation works better + perturbation_vect = [-np.std(layer, axis=0) for layer in zip(*params)] + elif dataset_name == "cifar": + # Apply std perturbation + perturbation_vect = [-np.std(layer, axis=0) for layer in zip(*params)] + else: + # Apply std perturbation + perturbation_vect = [-np.std(layer, axis=0) for layer in zip(*params)] + + # Compute lambda (referred as gamma in the paper) + lambda_succ = lambda_init + 1 + curr_lambda = lambda_init + step = lambda_init * 0.5 + while ( + abs(lambda_succ - curr_lambda) > threshold + and step > threshold + and malicious_num > 0 + ): + # Compute malicious gradients + perturbed_params = [ + curr_lambda * perturbation_vect[i] for i in range(len(perturbation_vect)) + ] + corrupted_params = [ + params_avg[i] + perturbed_params[i] for i in range(len(params_avg)) + ] + + # Set corrupted clients' updates to corrupted_params + params_c = [ + corrupted_params if states[str(i)] else params[i] + for i in range(len(params)) + ] + distance_matrix = _compute_distances(params_c) + + # Remove from matrix distance_matrix all malicious clients in both + # rows and columns + distance_matrix_b = 
np.delete( + distance_matrix, + [ + i + for i in range(len(distance_matrix)) + if states[results[i][1].metrics["cid"]] + ], + axis=0, + ) + distance_matrix_b = np.delete( + distance_matrix_b, + [ + i + for i in range(len(distance_matrix)) + if states[results[i][1].metrics["cid"]] + ], + axis=1, + ) + + # Remove from distance_matrix all benign clients on + # rows and all malicious on columns + distance_matrix_m = np.delete( + distance_matrix, + [ + i + for i in range(len(distance_matrix)) + if not states[results[i][1].metrics["cid"]] + ], + axis=0, + ) + distance_matrix_m = np.delete( + distance_matrix_m, + [ + i + for i in range(len(distance_matrix)) + if states[results[i][1].metrics["cid"]] + ], + axis=1, + ) + + # Take the maximum distance between any benign client and any malicious one + max_dist_m = np.max(distance_matrix_m) + + # Take the maximum distance between any two benign clients + max_dist_b = np.max(distance_matrix_b) + + # Compute lambda (best scaling coefficient) + if max_dist_m < max_dist_b: + # Lambda (gamma in the paper) is good. Save and try to increase it + lambda_succ = curr_lambda + curr_lambda = curr_lambda + step * 0.5 + else: + # Lambda is to big, must be reduced to increse the chances of being selected + curr_lambda = curr_lambda - step * 0.5 + step *= 0.5 + + # Compute the final malicious update + perturbation_vect = [ + lambda_succ * perturbation_vect[i] for i in range(len(perturbation_vect)) + ] + corrupted_params = [ + params_avg[i] + perturbation_vect[i] for i in range(len(params_avg)) + ] + corrupted_params = ndarrays_to_parameters(corrupted_params) + for proxy, fitres in ordered_results: + if states[fitres.metrics["cid"]]: + fitres.parameters = corrupted_params + results[int(fitres.metrics["cid"])] = (proxy, fitres) + return results, {} + + +def _krum(results, num_malicious, to_keep, num_closest=None): + """Get the best parameters vector according to the Krum function. + + Output: the best parameters vector. 
+ """ + weights = [w for w, _ in results] # list of weights + distance_matrix = _compute_distances(weights) # matrix of distances + + if not num_closest: + num_closest = ( + len(weights) - num_malicious - 2 + ) # number of closest points to use + if num_closest <= 0: + num_closest = 1 + elif num_closest > len(weights): + num_closest = len(weights) + + closest_indices = _get_closest_indices( + distance_matrix, num_closest + ) # indices of closest points + + scores = [ + np.sum(distance_matrix[i, closest_indices[i]]) + for i in range(len(distance_matrix)) + ] # scores i->j for each i + + best_index = np.argmin(scores) # index of the best score + best_indices = np.argsort(scores)[::-1][ + len(scores) - to_keep : + ] # indices of best scores (multikrum) + return weights[best_index], best_index, best_indices, scores + + +def _compute_distances(weights): + """Compute distances between vectors. + + Input: weights - list of weights vectors + Output: distances - matrix distance_matrix of squared distances between the vectors + """ + flat_w = np.array([np.concatenate(par, axis=None).ravel() for par in weights]) + distance_matrix = np.zeros((len(weights), len(weights))) + for i, _ in enumerate(flat_w): + for j, _ in enumerate(flat_w): + delta = flat_w[i] - flat_w[j] + dist = np.linalg.norm(delta) + distance_matrix[i, j] = dist**2 + return distance_matrix + + +def _get_closest_indices(distance_matrix, num_closest): + """Get the indices of the closest points. + + Args: + distance_matrix + matrix of distances + num_closest + number of closest points to get for each parameter vector + Output: + closest_indices + list of lists of indices of the closest points for each vector. 
+ """ + closest_indices = [] + for idx, _ in enumerate(distance_matrix): + closest_indices.append( + np.argsort(distance_matrix[idx])[1 : num_closest + 1].tolist() + ) + return closest_indices + + +def _fang_corrupt_params(global_model, lamda): + # Compute sign vector num_supporters + magnitude = [] + for i, _ in enumerate(global_model): + magnitude.append(np.sign(global_model[i]) * lamda) + + corrupted_params = [ + global_model[i] - magnitude[i] for i in range(len(global_model)) + ] # corrupted model + return corrupted_params + + +def _fang_corrupt_and_select(all_models, global_model, states, num_corrupted, lamda): + # Check that krum selects a malicious client + corrupted_params = _fang_corrupt_params(global_model, lamda) + all_models_m = [ + (corrupted_params, num_examples) if states[str(i)] else (model, num_examples) + for i, (model, num_examples) in enumerate(all_models) + ] + _, idx_best_model, _, _ = _krum(all_models_m, num_corrupted, 1) + + # Check if the best model is malicious + malicious_selected = states[str(idx_best_model)] + return malicious_selected, corrupted_params diff --git a/baselines/flanders/flanders/client.py b/baselines/flanders/flanders/client.py new file mode 100644 index 000000000000..57513ccf7291 --- /dev/null +++ b/baselines/flanders/flanders/client.py @@ -0,0 +1,174 @@ +"""Clients implementation for Flanders.""" + +from collections import OrderedDict +from pathlib import Path +from typing import Tuple + +import flwr as fl +import numpy as np +import ray +import torch + +from .dataset import get_dataloader, mnist_transformation +from .models import ( + FMnistNet, + MnistNet, + test_fmnist, + test_mnist, + train_fmnist, + train_mnist, +) + +XY = Tuple[np.ndarray, np.ndarray] + + +def get_params(model): + """Get model weights as a list of NumPy ndarrays.""" + return [val.cpu().numpy() for _, val in model.state_dict().items()] + + +def set_params(model, params): + """Set model weights from a list of NumPy ndarrays.""" + params_dict = 
zip(model.state_dict().keys(), params) + state_dict = OrderedDict({k: torch.from_numpy(np.copy(v)) for k, v in params_dict}) + model.load_state_dict(state_dict, strict=True) + + +class MnistClient(fl.client.NumPyClient): + """Implementation of MNIST image classification using PyTorch.""" + + def __init__(self, cid, fed_dir_data): + """Instantiate a client for the MNIST dataset.""" + self.cid = cid + self.fed_dir = Path(fed_dir_data) + self.properties = {"tensor_type": "numpy.ndarray"} + + # Instantiate model + self.net = MnistNet() + + # Determine device + # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): + self.device = torch.device("cuda") + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + + def get_parameters(self, config): + """Get model parameters as a list of NumPy ndarrays.""" + return get_params(self.net) + + def fit(self, parameters, config): + """Set model parameters from a list of NumPy ndarrays.""" + set_params(self.net, parameters) + + # Load data for this client and get trainloader + num_workers = 1 + trainloader = get_dataloader( + self.fed_dir, + self.cid, + is_train=True, + batch_size=config["batch_size"], + workers=num_workers, + transform=mnist_transformation, + ) + + self.net.to(self.device) + train_mnist(self.net, trainloader, epochs=config["epochs"], device=self.device) + + return ( + get_params(self.net), + len(trainloader.dataset), + {"cid": self.cid, "malicious": config["malicious"]}, + ) + + def evaluate(self, parameters, config): + """Evaluate using local test dataset.""" + set_params(self.net, parameters) + + # Load data for this client and get trainloader + num_workers = len(ray.worker.get_resource_ids()["CPU"]) + valloader = get_dataloader( + self.fed_dir, + self.cid, + is_train=False, + batch_size=50, + workers=num_workers, + transform=mnist_transformation, + ) + + self.net.to(self.device) + loss, 
class FMnistClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates an `FMnistNet` using PyTorch.

    NOTE(review): the model/helper names (`FMnistNet`, `train_fmnist`,
    `test_fmnist`) suggest Fashion-MNIST -- confirm against models.py.
    """

    def __init__(self, cid, fed_dir_data):
        """Store the client id and data directory, build the model, pick a device.

        Parameters
        ----------
        cid
            Client identifier; passed to `get_dataloader` to select this client's
            data and echoed back in the metrics returned by `fit`.
        fed_dir_data
            Directory containing the per-client data partitions.
        """
        self.cid = cid
        self.fed_dir = Path(fed_dir_data)
        self.properties = {"tensor_type": "numpy.ndarray"}

        # Instantiate model
        self.net = FMnistNet()

        # Determine device: prefer CUDA, then Apple MPS, otherwise fall back to CPU
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")

    def get_parameters(self, config):
        """Get model parameters as a list of NumPy ndarrays."""
        return get_params(self.net)

    def fit(self, parameters, config):
        """Load the given parameters, train locally and return the updated ones.

        `config` must provide "batch_size", "epochs" and "malicious".

        Returns a tuple of (updated parameters, number of training examples,
        metrics), where the metrics carry this client's "cid" and the
        "malicious" value received in `config`.
        """
        set_params(self.net, parameters)

        # Load data for this client and get trainloader
        num_workers = 1
        trainloader = get_dataloader(
            self.fed_dir,
            self.cid,
            is_train=True,
            batch_size=config["batch_size"],
            workers=num_workers,
            transform=mnist_transformation,
        )

        self.net.to(self.device)
        train_fmnist(self.net, trainloader, epochs=config["epochs"], device=self.device)

        return (
            get_params(self.net),
            len(trainloader.dataset),
            {"cid": self.cid, "malicious": config["malicious"]},
        )

    def evaluate(self, parameters, config):
        """Load the given parameters and evaluate them on the local validation split.

        Returns a tuple of (loss, number of validation examples,
        {"accuracy": accuracy}).
        """
        set_params(self.net, parameters)

        # Load the validation data for this client; one DataLoader worker per
        # CPU resource id reported by Ray for the current worker.
        num_workers = len(ray.worker.get_resource_ids()["CPU"])
        valloader = get_dataloader(
            self.fed_dir,
            self.cid,
            is_train=False,
            batch_size=50,
            workers=num_workers,
            transform=mnist_transformation,
        )

        self.net.to(self.device)
        loss, accuracy = test_fmnist(self.net, valloader, device=self.device)

        return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)}
---
name: bulyan

strategy:
  _target_: flwr.server.strategy.Bulyan
  _recursive_: true
  # OmegaConf interpolation is written `${...}` (as in strategy/krum.yaml);
  # `$(...)` would be handed to the strategy as a literal string.
  num_malicious_clients: ${server.num_malicious}
  to_keep: 0  # Normal Krum is applied
--git a/baselines/flanders/flanders/conf/strategy/fedmedian.yaml b/baselines/flanders/flanders/conf/strategy/fedmedian.yaml new file mode 100644 index 000000000000..d79293f4ca23 --- /dev/null +++ b/baselines/flanders/flanders/conf/strategy/fedmedian.yaml @@ -0,0 +1,5 @@ +--- +name: fedmedian + +strategy: + _target_: flwr.server.strategy.FedMedian \ No newline at end of file diff --git a/baselines/flanders/flanders/conf/strategy/flanders.yaml b/baselines/flanders/flanders/conf/strategy/flanders.yaml new file mode 100644 index 000000000000..0222708dd836 --- /dev/null +++ b/baselines/flanders/flanders/conf/strategy/flanders.yaml @@ -0,0 +1,10 @@ +--- +name: flanders + +strategy: + _target_: flanders.strategy.Flanders + _recursive_: true + num_clients_to_keep: 3 # number of benign local models to filter-out before the aggregation (atm it's set to be pool_size - num_malicious, hard coded in main.py) + maxiter: 100 # number of iterations done by MAR + alpha: 1 + beta: 1 \ No newline at end of file diff --git a/baselines/flanders/flanders/conf/strategy/krum.yaml b/baselines/flanders/flanders/conf/strategy/krum.yaml new file mode 100644 index 000000000000..bc36d37755fa --- /dev/null +++ b/baselines/flanders/flanders/conf/strategy/krum.yaml @@ -0,0 +1,7 @@ +--- +name: krum + +strategy: + _target_: flwr.server.strategy.Krum + num_clients_to_keep: 3 + num_malicious_clients: ${server.num_malicious} \ No newline at end of file diff --git a/baselines/flanders/flanders/conf/strategy/trimmedmean.yaml b/baselines/flanders/flanders/conf/strategy/trimmedmean.yaml new file mode 100644 index 000000000000..561755f82d35 --- /dev/null +++ b/baselines/flanders/flanders/conf/strategy/trimmedmean.yaml @@ -0,0 +1,6 @@ +--- +name: trimmedmean + +strategy: + _target_: flwr.server.strategy.FedTrimmedAvg + beta: 0.2 \ No newline at end of file diff --git a/baselines/flanders/flanders/dataset.py b/baselines/flanders/flanders/dataset.py new file mode 100644 index 000000000000..2c13e80d75c5 --- 
def get_random_id_splits(total: int, val_ratio: float, shuffle: bool = True):
    """Random split.

    Split a set of indices into two groups following a
    (1-val_ratio):val_ratio partitioning.

    Parameters
    ----------
    total
        Either the number of items (indices ``0 .. total-1`` are generated) or
        an explicit collection of indices. A collection supplied by the caller
        is copied first, so it is never reordered in place.
    val_ratio
        Fraction of the indices assigned to the second (validation) group. The
        cut point is ``floor(val_ratio * len(indices))``.
    shuffle
        Shuffle the indices with NumPy's global RNG before splitting.
        Defaults to True.

    Returns
    -------
    (train_indices, val_indices)
        The indices after the cut point and the indices before it.
    """
    if isinstance(total, (int, np.integer)):
        indices = list(range(total))
    else:
        # Work on a copy: np.random.shuffle operates in place and must not
        # scramble data owned by the caller.
        indices = total.copy() if hasattr(total, "copy") else list(total)

    split = int(np.floor(val_ratio * len(indices)))
    if shuffle:
        np.random.shuffle(indices)
    return indices[split:], indices[:split]
training + imgs = imgs[train_idx] + labels = labels[train_idx] + + with open(splits_dir / str(idx) / "train.pt", "wb") as fil: + torch.save([imgs, labels], fil) + + return splits_dir + + +def mnist_transformation(img): + """Return TorchVision transformation for MNIST.""" + return transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=(0.5,), std=(0.5,)), + ] + )(img) + + +class TorchVisionFL(VisionDataset): + """TorchVision FL class. + + Use this class by either passing a path to a torch file (.pt) containing (data, + targets) or pass the data, targets directly instead. + + This is just a trimmed down version of torchvision.datasets.MNIST. + """ + + def __init__( + self, + path_to_data=None, + data=None, + targets=None, + transform: Optional[Callable] = None, + ) -> None: + """Initialize dataset.""" + path = path_to_data.parent if path_to_data else None + super().__init__(path, transform=transform) + self.transform = transform + + if path_to_data: + # load data and targets (path_to_data points to an specific .pt file) + self.data, self.targets = torch.load(path_to_data) + else: + self.data = data + self.targets = targets + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """Return a tuple (data, target).""" + img, target = self.data[index], int(self.targets[index]) + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + if not isinstance(img, Image.Image): # if not PIL image + if not isinstance(img, np.ndarray): # if torch tensor + img = img.numpy() + + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self) -> int: + """Return length of dataset.""" + return len(self.data) + + +def get_mnist(path_to_data="flanders/datasets_files/mnist/data"): + """Download MNIST dataset.""" + # download dataset and load train set + train_set = 
def dataset_partitioner(
    dataset: torch.utils.data.Dataset,
    batch_size: int,
    client_id: int,
    number_of_clients: int,
    workers: int = 1,
) -> torch.utils.data.DataLoader:
    """Build the DataLoader serving one client's slice of `dataset`.

    Parameters
    ----------
    dataset: torch.utils.data.Dataset
        Dataset to be divided into *number_of_clients* equally sized subsets.
        Samples left over by the integer division are not assigned to anyone.
    batch_size: int
        Size of the mini-batches produced by the returned DataLoader.
    client_id: int
        Index of the subset to select.
    number_of_clients: int
        Total number of clients, i.e. the number of subsets.
    workers: int
        Number of DataLoader worker processes. Defaults to 1.

    Returns
    -------
    data_loader: torch.utils.data.DataLoader
        DataLoader that samples (in random order) only from the subset that
        belongs to `client_id`.
    """
    # Re-seed NumPy's global RNG with a fixed value so that every client derives
    # the same shuffled index list and therefore the same set of subsets.
    np.random.seed(123)

    total = len(dataset)
    per_client = total // number_of_clients
    shuffled_indices = list(range(total))
    np.random.shuffle(shuffled_indices)

    # Contiguous window of the shuffled indices owned by this client
    first = int(client_id) * per_client
    client_indices = shuffled_indices[first : first + per_client]

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        sampler=SubsetRandomSampler(client_indices),
        num_workers=workers,
    )
def float_to_int(i: float) -> int:
    """Return `i` as an int, refusing to silently drop a fractional part.

    Whole numbers given as plain ints are accepted as well.

    Raises
    ------
    ValueError
        If `i` is not a whole number (this includes inf and nan).
    """
    # Go through float() so the check also works for values that do not expose
    # `is_integer` themselves.
    if not float(i).is_integer():
        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        raise ValueError("Cast would drop decimals")

    return int(i)
def combine_partitions(xy_list_0: XYList, xy_list_1: XYList) -> XYList:
    """Merge two lists of (x, y) partitions pairwise.

    The i-th result is the i-th partition of `xy_list_0` followed by the i-th
    partition of `xy_list_1`, concatenated along the first axis for both the
    samples and the labels. Pairing stops at the end of the shorter list.
    """
    merged = []
    for (x_first, y_first), (x_second, y_second) in zip(xy_list_0, xy_list_1):
        samples = np.concatenate([x_first, x_second], axis=0)
        labels = np.concatenate([y_first, y_second], axis=0)
        merged.append((samples, labels))
    return merged
def adjust_x_shape(nda: np.ndarray) -> np.ndarray:
    """Append a trailing singleton axis: shape (x, y, z) becomes (x, y, z, 1)."""
    dims = nda.shape
    target_shape = (dims[0], dims[1], dims[2], 1)
    return np.reshape(nda, target_shape)
def split_array_at_indices(
    x: np.ndarray, split_idx: np.ndarray
) -> List[List[np.ndarray]]:
    """Split an array `x` into list of elements using starting indices from `split_idx`.

    This function should be used with `unique_indices` from `np.unique()` after
    sorting by label.

    Args:
        x (np.ndarray): Original array of dimension (N,a,b,c,...)
        split_idx (np.ndarray): 1-D integer array containing the non-decreasing
            start index of every group. First value must be zero. Last value
            must be less than N.

    Returns
    -------
    List[List[np.ndarray]]: One list per entry of `split_idx`; the j-th list
        holds the samples `x[split_idx[j]:split_idx[j + 1]]` (the last group
        runs to the end of `x`).

    Raises
    ------
    ValueError: If `split_idx` violates any of the constraints above.
    """
    if split_idx.ndim != 1:
        raise ValueError("Variable `split_idx` must be a 1-D numpy array.")
    # Any integer dtype is a valid index array (e.g. int32 / intp), not only
    # np.int64.
    if not np.issubdtype(split_idx.dtype, np.integer):
        raise ValueError("Variable `split_idx` must have an integer dtype.")
    if split_idx[0] != 0:
        raise ValueError("First value of `split_idx` must be 0.")
    if split_idx[-1] >= x.shape[0]:
        raise ValueError(
            """Last value in `split_idx` must be less than
               the number of samples in `x`."""
        )
    if not np.all(split_idx[:-1] <= split_idx[1:]):
        raise ValueError("Items in `split_idx` must be in increasing order.")

    # Append N as the closing boundary of the last group.
    boundaries = np.append(split_idx, x.shape[0])
    return [
        list(x[start:stop])  # noqa: E203
        for start, stop in zip(boundaries[:-1], boundaries[1:])
    ]
def sample_without_replacement(
    distribution: np.ndarray,
    list_samples: List[List[np.ndarray]],
    num_samples: int,
    empty_classes: List[bool],
) -> Tuple[XY, List[bool]]:
    """Draw `num_samples` items, class by class, consuming `list_samples`.

    For every draw a class is chosen from `distribution` with NumPy's global
    RNG and one sample is popped from that class's list, so `list_samples` is
    modified in place (and so is `empty_classes` when it is non-empty).

    Args:
        distribution (np.ndarray): Class distribution used for sampling.
        list_samples(List[List[np.ndarray]]): Available samples, one list per class.
        num_samples (int): Total number of items to be sampled.
        empty_classes (List[bool]): Flags marking the classes that are already
            exhausted and must not be sampled. An empty value means "none yet".

    Returns
    -------
    XY: Arrays with the drawn samples and their class indices
    List[bool]: The updated `empty_classes` flags.
    """
    available = np.sum([len(samples) for samples in list_samples])
    if available < num_samples:
        raise ValueError(
            """Number of samples in `list_samples` is less than `num_samples`"""
        )

    # Start from "no class exhausted" when the caller gave no flags.
    if not empty_classes:
        empty_classes = len(distribution) * [False]

    distribution = exclude_classes_and_normalize(
        distribution=distribution, exclude_dims=empty_classes
    )

    drawn_samples: List[np.ndarray] = []
    drawn_labels: List[np.ndarray] = []

    for _ in range(num_samples):
        one_hot = np.random.multinomial(1, distribution)
        label = np.where(one_hot == 1)[0][0]
        drawn_samples.append(list_samples[label].pop())
        drawn_labels.append(label)

        if list_samples[label]:
            continue

        # The class just ran out of samples: flag it and re-normalize so that
        # it cannot be drawn again. This keeps exhausted classes apart from
        # classes that merely started with zero probability.
        empty_classes[label] = True
        distribution = exclude_classes_and_normalize(
            distribution=distribution, exclude_dims=empty_classes
        )

    samples_array: np.ndarray = np.concatenate([drawn_samples], axis=0)
    labels_array: np.ndarray = np.array(drawn_labels, dtype=np.int64)

    return (samples_array, labels_array), empty_classes
+ + Args: + partitions (XYList): Input partitions + + Returns + ------- + np.ndarray: Distributions of size (num_partitions, num_classes) + """ + # Get largest available label + labels = set() + for _, y in partitions: + labels.update(set(y)) + list_labels = sorted(labels) + bin_edges = np.arange(len(list_labels) + 1) + + # Pre-allocate distributions + distributions = np.zeros((len(partitions), len(list_labels)), dtype=np.float32) + for idx, (_, _y) in enumerate(partitions): + hist, _ = np.histogram(_y, bin_edges) + distributions[idx] = hist / hist.sum() + + return distributions, list_labels + + +def create_lda_partitions( + dataset: XY, + dirichlet_dist: Optional[np.ndarray] = None, + num_partitions: int = 100, + concentration: Union[float, np.ndarray, List[float]] = 0.5, + accept_imbalanced: bool = False, + seed: Optional[Union[int, SeedSequence, BitGenerator, Generator]] = None, +) -> Tuple[XYList, np.ndarray]: + r"""Create imbalanced non-iid partitions. + + Create imbalanced non-iid partitions using Latent Dirichlet Allocation (LDA) + without resampling. + + Args: + dataset (XY): Dataset containing samples X and labels Y. + dirichlet_dist (numpy.ndarray, optional): previously generated distribution to + be used. This is useful when applying the same distribution for train and + validation sets. + num_partitions (int, optional): Number of partitions to be created. + Defaults to 100. + concentration (float, np.ndarray, List[float]): Dirichlet Concentration + (:math:`\\alpha`) parameter. Set to float('inf') to get uniform partitions. + An :math:`\\alpha \\to \\Inf` generates uniform distributions over classes. + An :math:`\\alpha \\to 0.0` generates one class per client. Defaults to 0.5. + accept_imbalanced (bool): Whether or not to accept imbalanced output classes. + Default False. + seed (None, int, SeedSequence, BitGenerator, Generator): + A seed to initialize the BitGenerator for generating the Dirichlet + distribution. 
This is defined in Numpy's official documentation as follows: + If None, then fresh, unpredictable entropy will be pulled from the OS. + One may also pass in a SeedSequence instance. + Additionally, when passed a BitGenerator, it will be wrapped by Generator. + If passed a Generator, it will be returned unaltered. + See official Numpy Documentation for further details. + + Returns + ------- + Tuple[XYList, numpy.ndarray]: List of XYList containing partitions + for each dataset and the dirichlet probability density functions. + """ + # pylint: disable=too-many-arguments,too-many-locals + + x, y = dataset + x, y = shuffle(x, y) + x, y = sort_by_label(x, y) + + if (x.shape[0] % num_partitions) and (not accept_imbalanced): + raise ValueError( + """Total number of samples must be a multiple of `num_partitions`. + If imbalanced classes are allowed, set + `accept_imbalanced=True`.""" + ) + + num_samples = num_partitions * [0] + for j in range(x.shape[0]): + num_samples[j % num_partitions] += 1 + + # Get number of classes and verify if they matching with + classes, start_indices = np.unique(y, return_index=True) + + # Make sure that concentration is np.array and + # check if concentration is appropriate + concentration = np.asarray(concentration) + + # Check if concentration is Inf, if so create uniform partitions + partitions: List[XY] = [(_, _) for _ in range(num_partitions)] + if float("inf") in concentration: + partitions = create_partitions( + unpartitioned_dataset=(x, y), + iid_fraction=1.0, + num_partitions=num_partitions, + ) + dirichlet_dist = get_partitions_distributions(partitions)[0] + + return partitions, dirichlet_dist + + if concentration.size == 1: + concentration = np.repeat(concentration, classes.size) + elif concentration.size != classes.size: # Sequence + raise ValueError( + f"The size of the provided concentration ({concentration.size}) ", + f"must be either 1 or equal number of classes {classes.size})", + ) + + # Split into list of list of samples per 
def _shift(x: np.ndarray, y: np.ndarray) -> XY:
    """Reorder a dataset into two label-sorted halves that are shuffled internally.

    The samples are first sorted by label, then cut in the middle, and each half
    is shuffled on its own before the halves are glued back together. Samples
    therefore never cross the midpoint.

    NOTE(review): for a balanced ten-class dataset this presumably leaves labels
    0 to 4 in the first half and labels 5 to 9 in the second -- confirm against
    the datasets actually passed in.

    Parameters
    ----------
    x : np.ndarray
        Samples, concatenated along axis 0.
    y : np.ndarray
        Labels aligned with ``x``.

    Returns
    -------
    XY
        The reordered ``(x, y)`` pair.
    """
    ordered_x, ordered_y = sort_by_label(x, y)

    # Cut the label-sorted data at 50%.
    (head_x, head_y), (tail_x, tail_y) = split_at_fraction(
        ordered_x, ordered_y, fraction=0.5
    )

    # Shuffle each half independently (first half first, as before).
    head_x, head_y = shuffle(head_x, head_y)
    tail_x, tail_y = shuffle(tail_x, tail_y)

    shifted_x = np.concatenate([head_x, tail_x], axis=0)
    shifted_y = np.concatenate([head_y, tail_y], axis=0)
    return shifted_x, shifted_y
Print parsed config + print(OmegaConf.to_yaml(cfg)) + + # Skip if: + # - strategy = bulyan and num_malicious > 20 + # - attack_fn != gaussian and num_malicious = 0 + if cfg.strategy.name == "bulyan" and cfg.server.num_malicious > 20: + print( + "Skipping experiment because strategy is bulyan and num_malicious is > 20" + ) + return + # skip if attack_fn is not gaussian and num_malicious is 0, but continue if + # attack_fn is na + if ( + cfg.server.attack_fn != "gaussian" + and cfg.server.num_malicious == 0 + and cfg.server.attack_fn != "na" + ): + print( + "Skipping experiment because attack_fn is not gaussian and " + "num_malicious is 0" + ) + return + + attacks = { + "na": no_attack, + "gaussian": gaussian_attack, + "lie": lie_attack, + "fang": fang_attack, # OPT + "minmax": minmax_attack, # AGR-MM + } + + clients = { + "mnist": (MnistClient, mnist_evaluate), + "fmnist": (FMnistClient, fmnist_evaluate), + } + + # Delete old client_params + if os.path.exists(cfg.server.history_dir): + shutil.rmtree(cfg.server.history_dir) + + dataset_name = cfg.dataset + attack_fn = cfg.server.attack_fn + num_malicious = cfg.server.num_malicious + + # 2. Prepare your dataset + if dataset_name in ["mnist", "fmnist"]: + if dataset_name == "mnist": + train_path, _ = get_mnist() + elif dataset_name == "fmnist": + train_path, _ = get_fmnist() + fed_dir = do_fl_partitioning( + train_path, + pool_size=cfg.server.pool_size, + alpha=cfg.server.noniidness, + num_classes=10, + val_ratio=0.2, + seed=seed, + ) + else: + raise ValueError("Dataset not supported") + + # 3. Define your clients + # pylint: disable=no-else-return + def client_fn(cid: str, dataset_name: str = dataset_name): + client = clients[dataset_name][0] + if dataset_name in ["mnist", "fmnist"]: + return client(cid, fed_dir) + else: + raise ValueError("Dataset not supported") + + # 4. 
Define your strategy + strategy = None + if cfg.strategy.name == "flanders": + function_path = cfg.aggregate_fn.aggregate_fn.function + module_name, function_name = function_path.rsplit(".", 1) + module = importlib.import_module(module_name, package=__package__) + aggregation_fn = getattr(module, function_name) + + strategy = instantiate( + cfg.strategy.strategy, + evaluate_fn=clients[dataset_name][1], + on_fit_config_fn=fit_config, + fraction_fit=1, + fraction_evaluate=0, + min_fit_clients=cfg.server.pool_size, + min_evaluate_clients=0, + num_clients_to_keep=cfg.server.pool_size - num_malicious, + aggregate_fn=aggregation_fn, + aggregate_parameters=cfg.aggregate_fn.aggregate_fn.parameters, + min_available_clients=cfg.server.pool_size, + window=cfg.server.warmup_rounds, + distance_function=l2_norm, + maxiter=cfg.strategy.strategy.maxiter, + alpha=cfg.strategy.strategy.alpha, + beta=int(cfg.strategy.strategy.beta), + ) + elif cfg.strategy.name == "krum": + strategy = instantiate( + cfg.strategy.strategy, + evaluate_fn=clients[dataset_name][1], + on_fit_config_fn=fit_config, + fraction_fit=1, + fraction_evaluate=0, + min_fit_clients=cfg.server.pool_size, + min_evaluate_clients=0, + num_clients_to_keep=cfg.strategy.strategy.num_clients_to_keep, + min_available_clients=cfg.server.pool_size, + num_malicious_clients=num_malicious, + ) + elif cfg.strategy.name == "fedavg": + strategy = instantiate( + cfg.strategy.strategy, + evaluate_fn=clients[dataset_name][1], + on_fit_config_fn=fit_config, + fraction_fit=1, + fraction_evaluate=0, + min_fit_clients=cfg.server.pool_size, + min_evaluate_clients=0, + min_available_clients=cfg.server.pool_size, + ) + elif cfg.strategy.name == "bulyan": + # Get aggregation rule function + strategy = instantiate( + cfg.strategy.strategy, + evaluate_fn=clients[dataset_name][1], + on_fit_config_fn=fit_config, + fraction_fit=1, + fraction_evaluate=0, + min_fit_clients=cfg.server.pool_size, + min_evaluate_clients=0, + 
min_available_clients=cfg.server.pool_size, + num_malicious_clients=num_malicious, + to_keep=cfg.strategy.strategy.to_keep, + ) + elif cfg.strategy.name == "trimmedmean": + strategy = instantiate( + cfg.strategy.strategy, + evaluate_fn=clients[dataset_name][1], + on_fit_config_fn=fit_config, + fraction_fit=1, + fraction_evaluate=0, + min_fit_clients=cfg.server.pool_size, + min_evaluate_clients=0, + min_available_clients=cfg.server.pool_size, + beta=cfg.strategy.strategy.beta, + ) + elif cfg.strategy.name == "fedmedian": + strategy = instantiate( + cfg.strategy.strategy, + evaluate_fn=clients[dataset_name][1], + on_fit_config_fn=fit_config, + fraction_fit=1, + fraction_evaluate=0, + min_fit_clients=cfg.server.pool_size, + min_evaluate_clients=0, + min_available_clients=cfg.server.pool_size, + ) + else: + raise ValueError("Strategy not supported") + + # 5. Start Simulation + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=cfg.server.pool_size, + client_resources=cfg.client_resources, + server=EnhancedServer( + warmup_rounds=cfg.server.warmup_rounds, + num_malicious=num_malicious, + attack_fn=attacks[attack_fn], # type: ignore + magnitude=cfg.server.magnitude, + client_manager=SimpleClientManager(), + strategy=strategy, + sampling=cfg.server.sampling, + history_dir=cfg.server.history_dir, + dataset_name=dataset_name, + threshold=cfg.server.threshold, + omniscent=cfg.server.omniscent, + ), + config=fl.server.ServerConfig(num_rounds=cfg.server.num_rounds), + strategy=strategy, + ) + + save_path = HydraConfig.get().runtime.output_dir + + rounds, test_loss = zip(*history.losses_centralized) + _, test_accuracy = zip(*history.metrics_centralized["accuracy"]) + _, test_auc = zip(*history.metrics_centralized["auc"]) + _, truep = zip(*history.metrics_centralized["TP"]) + _, truen = zip(*history.metrics_centralized["TN"]) + _, falsep = zip(*history.metrics_centralized["FP"]) + _, falsen = zip(*history.metrics_centralized["FN"]) + + if not 
# pylint: disable=unused-argument
def fit_config(server_round):
    """Build the fit configuration sent to clients.

    The configuration is static: ``server_round`` is accepted only to satisfy the
    callback signature and is not used.

    Returns
    -------
    dict
        A fresh dictionary with ``"epochs"`` (local epochs, 1) and
        ``"batch_size"`` (32).
    """
    return {"epochs": 1, "batch_size": 32}
def train_mnist(model, dataloader, epochs, device):
    """Train the network on the training set.

    Uses cross-entropy loss and Adam with a learning rate of 0.001. The model is
    updated in place; nothing is returned.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise. It is called on batches flattened to
        ``(N, 28 * 28)``. The caller is responsible for placing it on ``device``.
    dataloader : iterable
        Yields ``(images, labels)`` batches. ``len(dataloader)`` is only needed
        for the progress message printed every 100 steps.
    epochs : int
        Number of passes over ``dataloader``.
    device : torch.device or str
        Device the batches are moved to.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # test_mnist() switches the model to eval mode and nothing switches it back,
    # so make sure training-only layers (dropout, batch norm) are active here.
    model.train()

    for epoch in range(epochs):
        for i, (images, labels) in enumerate(dataloader):
            # reshape (unlike view) also accepts non-contiguous batches and
            # matches what test_mnist() does.
            images = images.reshape(-1, 28 * 28).to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print(
                    f"Epoch [{epoch+1}/{epochs}], "
                    f"Step [{i+1}/{len(dataloader)}], "
                    f"Loss: {loss.item():.4f}"
                )
def train_fmnist(model, dataloader, epochs, device):
    """Train the network on the training set.

    Uses a summed negative log-likelihood loss (the model is expected to output
    log-probabilities) and Adam with a learning rate of 0.003. The model is
    updated in place; nothing is returned.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise. It is called on batches flattened to
        ``(N, 28 * 28)``. The caller is responsible for placing it on ``device``.
    dataloader : iterable
        Yields ``(images, labels)`` batches. ``len(dataloader)`` is only needed
        for the progress message printed every 100 steps.
    epochs : int
        Number of passes over ``dataloader``.
    device : torch.device or str
        Device the batches are moved to.
    """
    criterion = nn.NLLLoss(reduction="sum")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

    # test_fmnist() switches the model to eval mode and nothing switches it back.
    # FMnistNet uses dropout, which would otherwise stay disabled while training.
    model.train()

    for epoch in range(epochs):
        for i, (images, labels) in enumerate(dataloader):
            # reshape (unlike view) also accepts non-contiguous batches and
            # matches what test_fmnist() does.
            images = images.reshape(-1, 28 * 28).to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print(
                    f"Epoch [{epoch+1}/{epochs}], "
                    f"Step [{i+1}/{len(dataloader)}], "
                    f"Loss: {loss.item():.4f}"
                )
class EnhancedServer(Server):
    """Flower ``Server`` that simulates an attacker and records client parameters.

    In every fit round the server picks a random subset of the sampled clients
    and flags them as malicious, optionally rewrites the collected results
    through ``attack_fn``, stores each client's flattened parameters with
    ``save_params`` and keeps a running TP/TN/FP/FN confusion matrix that is
    added to the centralized metrics.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes
    def __init__(
        self,
        num_malicious: int,
        warmup_rounds: int,
        attack_fn: Callable,
        dataset_name: str,
        *args: Any,
        threshold: float = 0.0,
        to_keep: int = 1,
        magnitude: float = 0.0,
        sampling: int = 0,
        history_dir: str = "clients_params",
        omniscent: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new EnhancedServer instance.

        Parameters
        ----------
        num_malicious : int
            Number of clients flagged as malicious in each round after warm-up.
        warmup_rounds : int
            Number of initial rounds in which no client is flagged as malicious
            and the attack function is not applied.
        attack_fn : Callable
            Attack function applied to the ordered client results.
        dataset_name : str
            Name of the dataset; forwarded to ``attack_fn``.
        *args : Any
            Positional arguments forwarded to ``Server.__init__``.
        threshold : float, optional
            Threshold forwarded to ``attack_fn``, by default 0.0
        to_keep : int, optional
            Number of clients to keep (i.e., to classify as "good"); forwarded
            to ``attack_fn``, by default 1
        magnitude : float, optional
            Magnitude forwarded to ``attack_fn`` (used by the Gaussian attack),
            by default 0.0
        sampling : int, optional
            Number of flattened parameters to sample before saving; values
            that are not positive disable sampling, by default 0
        history_dir : str, optional
            Directory where to save the parameters, by default "clients_params"
        omniscent : bool, optional
            Whether to use the omniscent attack; forwarded to ``attack_fn``, by
            default True
        **kwargs : Any
            Keyword arguments forwarded to ``Server.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.num_malicious = num_malicious
        self.warmup_rounds = warmup_rounds
        self.attack_fn = attack_fn
        self.sampling = sampling
        # Parameters of one non-malicious client from the first round that
        # produced results; passed to attack_fn as `w_re`.
        self.aggregated_parameters: List = []
        # Indexes of the sampled parameters; drawn once, then reused.
        self.params_indexes: List = []
        self.history_dir = history_dir
        self.dataset_name = dataset_name
        self.magnitude = magnitude
        self.threshold = threshold
        self.to_keep = to_keep
        self.omniscent = omniscent
        # cids flagged as malicious in the most recent round.
        self.malicious_lst: List = []
        self.confusion_matrix = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}
        # State of the most recent round, consumed by update_confusion_matrix.
        self.clients_state: Dict[str, bool] = {}
        self.good_clients_idx: List[int] = []
        self.malicious_clients_idx: List[int] = []

    # pylint: disable=too-many-locals
    def fit(self, num_rounds, timeout):
        """Run federated learning for ``num_rounds`` rounds.

        Each round calls ``fit_round``, then the strategy's centralized
        evaluation, then ``evaluate_round``. The confusion-matrix counters are
        copied into the centralized metrics of every round (all zeros for
        round 0).

        Parameters
        ----------
        num_rounds : int
            Number of rounds to run.
        timeout
            Timeout forwarded to the parameter initialisation, ``fit_round``
            and ``evaluate_round``.

        Returns
        -------
        History
            The losses and metrics recorded for every round.
        """
        history = History()

        # Initialize parameters
        log(INFO, "Initializing global parameters")
        self.parameters = self._get_initial_parameters(timeout=timeout)
        log(INFO, "Evaluating initial parameters")
        res = self.strategy.evaluate(0, parameters=self.parameters)

        if res is not None:
            log(
                INFO,
                "initial parameters (loss, other metrics): %s, %s",
                res[0],
                res[1],
            )
            # No detection has happened yet: record an all-zero confusion
            # matrix so round 0 has the same metric keys as later rounds.
            res[1]["TP"] = 0
            res[1]["TN"] = 0
            res[1]["FP"] = 0
            res[1]["FN"] = 0
            history.add_loss_centralized(server_round=0, loss=res[0])
            history.add_metrics_centralized(server_round=0, metrics=res[1])

        # Run federated learning for num_rounds
        log(INFO, "FL starting")
        start_time = timeit.default_timer()

        for current_round in range(1, num_rounds + 1):
            # Train model and replace previous global model
            res_fit = self.fit_round(
                server_round=current_round,
                timeout=timeout,
            )
            if res_fit is not None:
                parameters_prime, fit_metrics, _ = res_fit  # fit_metrics_aggregated
                if parameters_prime:
                    self.parameters = parameters_prime
                history.add_metrics_distributed_fit(
                    server_round=current_round, metrics=fit_metrics
                )

            # Evaluate model using strategy implementation
            res_cen = self.strategy.evaluate(current_round, parameters=self.parameters)
            if res_cen is not None:
                loss_cen, metrics_cen = res_cen
                # Update confusion matrix (only once the warm-up is over, i.e.
                # once clients can actually be flagged as malicious)
                if current_round > self.warmup_rounds:
                    self.confusion_matrix = update_confusion_matrix(
                        self.confusion_matrix,
                        self.clients_state,
                        self.malicious_clients_idx,
                        self.good_clients_idx,
                    )

                # Expose the current counters alongside the other metrics.
                for key, val in self.confusion_matrix.items():
                    metrics_cen[key] = val

                log(
                    INFO,
                    "fit progress: (%s, %s, %s, %s)",
                    current_round,
                    loss_cen,
                    metrics_cen,
                    timeit.default_timer() - start_time,
                )
                history.add_loss_centralized(server_round=current_round, loss=loss_cen)
                history.add_metrics_centralized(
                    server_round=current_round, metrics=metrics_cen
                )

            # Evaluate model on a sample of available clients
            res_fed = self.evaluate_round(server_round=current_round, timeout=timeout)
            if res_fed is not None:
                loss_fed, evaluate_metrics_fed, _ = res_fed
                if loss_fed is not None:
                    history.add_loss_distributed(
                        server_round=current_round, loss=loss_fed
                    )
                    history.add_metrics_distributed(
                        server_round=current_round, metrics=evaluate_metrics_fed
                    )

        # Bookkeeping
        end_time = timeit.default_timer()
        elapsed = end_time - start_time
        log(INFO, "FL finished in %s", elapsed)
        return history

    # pylint: disable-msg=R0915
    def fit_round(
        self,
        server_round,
        timeout,
    ):
        # pylint: disable-msg=R0912
        """Perform a single round of federated learning.

        Steps: get the fit instructions from the strategy, flag a random subset
        of the sampled clients as malicious (none during warm-up), collect the
        fit results, save every client's flattened parameters, apply
        ``attack_fn`` after the warm-up, aggregate through the strategy and,
        for the ``Flanders`` strategy, overwrite the last saved parameters of
        the clients it reported as malicious with the aggregated ones.

        Parameters
        ----------
        server_round : int
            Index of the current round.
        timeout
            Timeout forwarded to ``fit_clients``.

        Returns
        -------
        Optional[tuple]
            ``None`` when the strategy selected no clients, otherwise
            ``(parameters_aggregated, metrics_aggregated, (results, failures))``.
        """
        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.configure_fit(
            server_round=server_round,
            parameters=self.parameters,
            client_manager=self._client_manager,
        )

        if not client_instructions:
            log(INFO, "fit_round %s: no clients selected, cancel", server_round)
            return None
        log(
            DEBUG,
            "fit_round %s: strategy sampled %s clients (out of %s)",
            server_round,
            len(client_instructions),
            self._client_manager.num_available(),
        )

        # Randomly decide which client is malicious
        # (nobody is malicious while the warm-up is running)
        size = self.num_malicious
        if server_round <= self.warmup_rounds:
            size = 0
        log(INFO, "Selecting %s malicious clients", size)
        self.malicious_lst = np.random.choice(
            [proxy.cid for proxy, _ in client_instructions], size=size, replace=False
        )

        # Create dict clients_state to keep track of malicious clients
        # and send the information to the clients
        clients_state = {}
        for _, (proxy, ins) in enumerate(client_instructions):
            clients_state[proxy.cid] = False
            ins.config["malicious"] = False
            if proxy.cid in self.malicious_lst:
                clients_state[proxy.cid] = True
                ins.config["malicious"] = True

        # Sort clients states
        clients_state = {k: clients_state[k] for k in sorted(clients_state)}
        log(
            DEBUG,
            "fit_round %s: malicious clients selected %s, clients_state %s",
            server_round,
            self.malicious_lst,
            clients_state,
        )

        # Collect `fit` results from all clients participating in this round
        results, failures = fit_clients(
            client_instructions=client_instructions,
            max_workers=self.max_workers,
            timeout=timeout,
        )
        log(
            DEBUG,
            "fit_round %s received %s results and %s failures",
            server_round,
            len(results),
            len(failures),
        )

        # Save parameters of each client as time series
        # NOTE(review): ordered_results is sized by len(results) but indexed by
        # cid, so this assumes every sampled client returned a result and that
        # cids are 0..n-1 -- confirm.
        ordered_results = [0 for _ in range(len(results))]
        for proxy, fitres in results:
            params = flatten_params(parameters_to_ndarrays(fitres.parameters))
            if self.sampling > 0:
                # if the sampling number is greater than the number of
                # parameters, just sample all of them
                self.sampling = min(self.sampling, len(params))
                if len(self.params_indexes) == 0:
                    # Sample a random subset of parameters; the indexes are
                    # drawn once and reused for every client and round
                    self.params_indexes = np.random.randint(
                        0, len(params), size=self.sampling
                    )

                params = params[self.params_indexes]

            save_params(params, fitres.metrics["cid"], params_dir=self.history_dir)

            # Re-arrange results in the same order as clients' cids impose
            ordered_results[int(fitres.metrics["cid"])] = (proxy, fitres)

        log(INFO, "Clients state: %s", clients_state)

        # Initialize aggregated_parameters if it is the first round, using the
        # parameters of the first client that is not flagged as malicious
        if self.aggregated_parameters == []:
            for key, val in clients_state.items():
                if val is False:
                    self.aggregated_parameters = parameters_to_ndarrays(
                        ordered_results[int(key)][1].parameters
                    )
                    break

        # Apply attack function
        # the server simulates an attacker that controls a fraction of the clients
        if self.attack_fn is not None and server_round > self.warmup_rounds:
            log(INFO, "Applying attack function")
            results, _ = self.attack_fn(
                ordered_results,
                clients_state,
                omniscent=self.omniscent,
                magnitude=self.magnitude,
                w_re=self.aggregated_parameters,
                threshold=self.threshold,
                d=len(self.aggregated_parameters),
                dataset_name=self.dataset_name,
                to_keep=self.to_keep,
                malicious_num=self.num_malicious,
                num_layers=len(self.aggregated_parameters),
            )

            # Update saved parameters time series after the attack: the entry
            # saved above for each malicious client is replaced
            # NOTE(review): assumes fitres.metrics["cid"] matches the proxy.cid
            # keys of clients_state -- confirm.
            for _, fitres in results:
                if clients_state[fitres.metrics["cid"]]:
                    if self.sampling > 0:
                        params = flatten_params(
                            parameters_to_ndarrays(fitres.parameters)
                        )[self.params_indexes]
                    else:
                        params = flatten_params(
                            parameters_to_ndarrays(fitres.parameters)
                        )
                    log(
                        INFO,
                        "Saving parameters of client %s with shape %s after the attack",
                        fitres.metrics["cid"],
                        params.shape,
                    )
                    save_params(
                        params,
                        fitres.metrics["cid"],
                        params_dir=self.history_dir,
                        remove_last=True,
                    )
        else:
            results = ordered_results

        # Aggregate training results
        log(INFO, "fit_round - Aggregating training results")
        good_clients_idx = []
        malicious_clients_idx = []
        aggregated_result = self.strategy.aggregate_fit(server_round, results, failures)
        if isinstance(self.strategy, Flanders):
            parameters_aggregated, metrics_aggregated = aggregated_result
            # The Flanders strategy reports its classification in the metrics.
            malicious_clients_idx = metrics_aggregated["malicious_clients_idx"]
            good_clients_idx = metrics_aggregated["good_clients_idx"]

            log(INFO, "Malicious clients: %s", malicious_clients_idx)

            log(INFO, "clients_state: %s", clients_state)

            # For clients detected as malicious, replace the last params in
            # their history with the current global model, otherwise the
            # forecasting in next round won't be reliable (see the paper for
            # more details)
            if server_round > self.warmup_rounds:
                log(INFO, "Saving parameters of clients")
                for idx in malicious_clients_idx:
                    if self.sampling > 0:
                        new_params = flatten_params(
                            parameters_to_ndarrays(parameters_aggregated)
                        )[self.params_indexes]
                    else:
                        new_params = flatten_params(
                            parameters_to_ndarrays(parameters_aggregated)
                        )

                    log(
                        INFO,
                        "Saving parameters of client %s with shape %s",
                        idx,
                        new_params.shape,
                    )
                    save_params(
                        new_params,
                        idx,
                        params_dir=self.history_dir,
                        remove_last=True,
                        rrl=False,
                    )
        else:
            # Aggregate training results
            log(INFO, "fit_round - Aggregating training results")
            parameters_aggregated, metrics_aggregated = aggregated_result

        # Keep this round's state so fit() can update the confusion matrix.
        self.clients_state = clients_state
        self.good_clients_idx = good_clients_idx
        self.malicious_clients_idx = malicious_clients_idx
        return parameters_aggregated, metrics_aggregated, (results, failures)
b/baselines/flanders/flanders/strategy.py new file mode 100644 index 000000000000..36dbc1182653 --- /dev/null +++ b/baselines/flanders/flanders/strategy.py @@ -0,0 +1,375 @@ +"""FLANDERS strategy.""" + +import importlib +import typing +from logging import INFO, WARNING +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +from flwr.common import ( + FitIns, + FitRes, + MetricsAggregationFn, + NDArrays, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.logger import log +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy.aggregate import aggregate +from flwr.server.strategy.fedavg import FedAvg + +from .utils import load_all_time_series + +WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """ +Setting `min_available_clients` lower than `min_fit_clients` or +`min_evaluate_clients` can cause the server to fail when there are too few clients +connected to the server. `min_available_clients` must be set to a value larger +than or equal to the values of `min_fit_clients` and `min_evaluate_clients`. +""" + + +class Flanders(FedAvg): + """Aggregation function based on MAR. + + Take a look at the paper for more details about the parameters. 
+ """ + + # pylint: disable=too-many-arguments,too-many-instance-attributes, too-many-locals + def __init__( + self, + fraction_fit: float = 1.0, + fraction_evaluate: float = 1.0, + min_fit_clients: int = 2, + min_evaluate_clients: int = 2, + min_available_clients: int = 2, + evaluate_fn: Optional[ + Callable[ + [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]], + ] + ] = None, + on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + accept_failures: bool = True, + initial_parameters: Optional[Parameters] = None, + fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + num_clients_to_keep: int = 1, + aggregate_fn: Callable = aggregate, + aggregate_parameters: Optional[Dict[str, Scalar]] = None, + window: int = 0, + maxiter: int = 100, + alpha: float = 1, + beta: float = 1, + distance_function=None, + ) -> None: + """Initialize FLANDERS. 
+ + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during the fit phase, by default 1.0 + fraction_evaluate : float, optional + Fraction of clients used during the evaluate phase, by default 1.0 + min_fit_clients : int, optional + Minimum number of clients used during the fit phase, by default 2 + min_evaluate_clients : int, optional + Minimum number of clients used during the evaluate phase, by + default 2 + min_available_clients : int, optional + Minimum number of clients available for training and evaluation, by + default 2 + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]]]], optional + Evaluation function, by default None + on_fit_config_fn : Optional[Callable[[int], Dict[str, Scalar]]], + optional + Function to generate the config fed to the clients during the fit + phase, by default None + on_evaluate_config_fn : Optional[Callable[[int], Dict[str, Scalar]]], + optional + Function to generate the config fed to the clients during the + evaluate phase, by default None + accept_failures : bool, optional + Whether to accept failures from clients, by default True + initial_parameters : Optional[Parameters], optional + Initial model parameters, by default None + fit_metrics_aggregation_fn : Optional[MetricsAggregationFn], optional + Function to aggregate metrics during the fit phase, by default None + evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn], + optional + Function to aggregate metrics during the evaluate phase, by default + None + num_clients_to_keep : int, optional + Number of clients to keep (i.e., to classify as "good"), by default + 1 + aggregate_fn : Callable[[List[Tuple[NDArrays, int]]], NDArrays], + optional + Function to aggregate the parameters, by default FedAvg + window : int, optional + Sliding window size used as a "training set" of MAR, by default 0 + maxiter : int, optional + Maximum number of iterations of MAR, by default 100 + 
@typing.no_type_check
def configure_fit(
    self, server_round: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
    """Pair every sampled client with its own fit instructions.

    A separate ``FitIns`` (and therefore a separate config dictionary when
    ``on_fit_config_fn`` is set) is built per client before the clients are
    sampled, so a config can later be modified for one client without affecting
    the others.

    Parameters
    ----------
    server_round : int
        Current round, passed to ``on_fit_config_fn``.
    parameters : Parameters
        Global parameters sent to every client.
    client_manager : ClientManager
        Manager the clients are sampled from.

    Returns
    -------
    List[Tuple[ClientProxy, FitIns]]
        One ``(client, instructions)`` pair per sampled client.
    """
    # How many clients to sample and how many must be available.
    num_to_sample, min_required = self.num_fit_clients(client_manager.num_available())

    # One independent FitIns per client slot.
    per_client_ins = []
    for _ in range(num_to_sample):
        if self.on_fit_config_fn:
            round_config = self.on_fit_config_fn(server_round)
        else:
            round_config = {}
        per_client_ins.append(FitIns(parameters, round_config))

    sampled_clients = client_manager.sample(
        num_clients=num_to_sample, min_num_clients=min_required
    )

    # zip stops at the shorter sequence, exactly like the explicit loop did.
    return list(zip(sampled_clients, per_client_ins))
disable=too-many-locals,too-many-statements + @typing.no_type_check + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Apply MAR forecasting to exclude malicious clients from FedAvg. + + Parameters + ---------- + server_round : int + Current server round. + results : List[Tuple[ClientProxy, FitRes]] + List of results from the clients. + failures : List[Union[Tuple[ClientProxy, FitRes], BaseException]] + List of failures from the clients. + + Returns + ------- + parameters_aggregated: Optional[Parameters] + Aggregated parameters. + metrics_aggregated: Dict[str, Scalar] + Aggregated metrics. + malicious_clients_idx: List[int] + List of malicious clients' cids (indexes). + """ + good_clients_idx = [] + malicious_clients_idx = [] + if server_round > 1: + if server_round < self.window: + self.window = server_round + params_tensor = load_all_time_series( + params_dir="clients_params", window=self.window + ) + params_tensor = np.transpose( + params_tensor, (0, 2, 1) + ) # (clients, params, time) + ground_truth = params_tensor[:, :, -1].copy() + pred_step = 1 + log(INFO, "Computing MAR on params_tensor %s", params_tensor.shape) + predicted_matrix = mar( + params_tensor[:, :, :-1], + pred_step, + maxiter=self.maxiter, + alpha=self.alpha, + beta=self.beta, + ) + + log(INFO, "Computing anomaly scores") + anomaly_scores = self.distance_function( + ground_truth, predicted_matrix[:, :, 0] + ) + log(INFO, "Anomaly scores: %s", anomaly_scores) + + log(INFO, "Selecting good clients") + good_clients_idx = sorted( + np.argsort(anomaly_scores)[: self.num_clients_to_keep] + ) # noqa + malicious_clients_idx = sorted( + np.argsort(anomaly_scores)[self.num_clients_to_keep :] + ) # noqa + + avg_anomaly_score_gc = np.mean(anomaly_scores[good_clients_idx]) + log( + INFO, "Average anomaly score for good clients: %s", 
avg_anomaly_score_gc + ) + + avg_anomaly_score_m = np.mean(anomaly_scores[malicious_clients_idx]) + log( + INFO, + "Average anomaly score for malicious clients: %s", + avg_anomaly_score_m, + ) + + results = np.array(results)[good_clients_idx].tolist() + log(INFO, "Good clients: %s", good_clients_idx) + + log(INFO, "Applying aggregate_fn") + # Convert results + weights_results = [ + (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for _, fit_res in results + ] + + # Check that self.aggregate_fn has num_malicious parameter + if "num_malicious" in self.aggregate_fn.__code__.co_varnames: + # Count the number of malicious clients in + # good_clients_idx by checking FitRes + clients_state = { + str(fit_res.metrics["cid"]): fit_res.metrics["malicious"] + for _, fit_res in results + } + num_malicious = sum([clients_state[str(cid)] for cid in good_clients_idx]) + log( + INFO, + "Number of malicious clients in good_clients_idx after filtering: %s", + num_malicious, + ) + self.aggregate_parameters["num_malicious"] = num_malicious + + if "aggregation_rule" in self.aggregate_fn.__code__.co_varnames: + module = importlib.import_module( + self.aggregate_parameters["aggregation_module_name"] + ) + function_name = self.aggregate_parameters["aggregation_name"] + self.aggregate_parameters["aggregation_rule"] = getattr( + module, function_name + ) + # Remove aggregation_module_name and aggregation_name + # from self.aggregate_parameters + aggregate_parameters = self.aggregate_parameters.copy() + del aggregate_parameters["aggregation_module_name"] + del aggregate_parameters["aggregation_name"] + try: + parameters_aggregated = ndarrays_to_parameters( + self.aggregate_fn(weights_results, **aggregate_parameters) + ) + except ValueError as err: + log(WARNING, "Error in aggregate_fn: %s", err) + parameters_aggregated = ndarrays_to_parameters( + aggregate(weights_results) + ) + else: + parameters_aggregated = ndarrays_to_parameters( + self.aggregate_fn(weights_results, 
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
def mar(X, pred_step, alpha=1, beta=1, maxiter=100):
    """Forecast the next tensor of params.

    Forecast the next tensor of params by using MAR algorithm: the model
    ``X[:, :, t] ~ A @ X[:, :, t - 1] @ B.T`` is fitted by alternating
    ridge-regularised least-squares updates of ``A`` and ``B``.

    Code provided by Xinyu Chen at:
    https://towardsdatascience.com/matrix-autoregressive-model-for-multidimensional-time-series-forecasting-6a4d7dce5143

    With some modifications.

    Parameters
    ----------
    X : ndarray
        Tensor of shape (m, n, T); the last axis is time.
    pred_step : int
        Number of future slices to forecast.
    alpha : float, optional
        Ridge term added before inverting in the update of ``A``, by default 1
    beta : float, optional
        Ridge term added before inverting in the update of ``B``, by default 1
    maxiter : int, optional
        Number of alternating update rounds, by default 100

    Returns
    -------
    ndarray
        Tensor of shape (m, n, pred_step) with the forecasted slices.
    """
    m, n, T = X.shape
    # Index of the first "current" slice. It must be 1 so that slice t is only
    # ever paired with its true predecessor t - 1; starting at 0 would index
    # slice -1 and treat the *last* observation as the predecessor of the first.
    # With a single slice there is no transition at all, so the slice is paired
    # with itself, which is what the previous implementation did in that case.
    start = 1 if T > 1 else 0

    # A is overwritten before its first use; it is still drawn so that the
    # global NumPy random stream is consumed exactly as before.
    A = np.random.randn(m, m)
    B = np.random.randn(n, n)

    # The factors are fitted on a rescaled copy of X, the forecast below is
    # produced from the unscaled X. Guard against an all-zero maximum.
    # NOTE(review): the shift by min(X) is not undone before forecasting -
    # confirm this is intended.
    scale = np.max(X)
    X_norm = (X - np.min(X)) / (scale if scale != 0 else 1.0)

    identity_m = np.identity(m)
    identity_n = np.identity(n)

    for _ in range(maxiter):
        # Update A with B fixed.
        temp0 = B.T @ B
        temp1 = np.zeros((m, m))
        temp2 = np.zeros((m, m))
        for t in range(start, T):
            temp1 += X_norm[:, :, t] @ B @ X_norm[:, :, t - 1].T
            temp2 += X_norm[:, :, t - 1] @ temp0 @ X_norm[:, :, t - 1].T
        temp2 += alpha * identity_m
        A = temp1 @ np.linalg.inv(temp2)

        # Update B with A fixed.
        temp0 = A.T @ A
        temp1 = np.zeros((n, n))
        temp2 = np.zeros((n, n))
        for t in range(start, T):
            temp1 += X_norm[:, :, t].T @ A @ X_norm[:, :, t - 1]
            temp2 += X_norm[:, :, t - 1].T @ temp0 @ X_norm[:, :, t - 1]
        temp2 += beta * identity_n
        B = temp1 @ np.linalg.inv(temp2)

    # Roll the fitted model forward from the last observed slice.
    tensor = np.append(X, np.zeros((m, n, pred_step)), axis=2)
    for s in range(pred_step):
        tensor[:, :, T + s] = A @ tensor[:, :, T + s - 1] @ B.T
    return tensor[:, :, -pred_step:]
def save_params(
    parameters, cid, params_dir="clients_params", remove_last=False, rrl=False
):
    """Save parameters in a file.

    The history of client ``cid`` is kept in ``<params_dir>/<cid>_params.npy``;
    each call stacks one more row under what is already stored there.

    Args:
    - parameters (ndarray): decoded parameters to append at the end of the file
    - cid (int): identifier of the client
    - params_dir (str): directory holding the per-client files; created
        (including missing parents) when it does not exist
    - remove_last (bool):
        if True, remove the last saved parameters and replace with "parameters"
    - rrl (bool):
        if True, remove the last saved parameters and replace with the ones
        saved before this round.
    """
    new_params = parameters
    # Save parameters in clients_params/cid_params
    path_file = os.path.join(params_dir, f"{cid}_params.npy")
    # exist_ok makes this safe when several callers create the directory at
    # the same time, and makedirs also handles nested directories.
    os.makedirs(params_dir, exist_ok=True)
    if os.path.exists(path_file):
        # load old parameters
        # NOTE(review): allow_pickle=True executes pickled payloads on load;
        # acceptable only while these files are produced by this code.
        # A file written by a single call holds a 1-d vector: promote it to one
        # row so that "remove last" drops a whole round, not a single weight.
        old_params = np.atleast_2d(np.load(path_file, allow_pickle=True))
        if remove_last:
            old_params = old_params[:-1]
            # NOTE(review): rrl is only honoured together with remove_last -
            # confirm against the callers. When no earlier round is left the
            # provided parameters are kept.
            if rrl and len(old_params) > 0:
                new_params = old_params[-1]
        # add new parameters
        new_params = np.vstack((old_params, new_params))

    # save parameters
    np.save(path_file, new_params)
def flatten_params(params):
    """Collapse a list of per-layer arrays into one 1-d vector of shape (n)."""
    # Flatten every layer on its own, then join the pieces end to end.
    flat_layers = [np.ravel(layer) for layer in params]
    return np.concatenate(flat_layers)
def update_confusion_matrix(
    confusion_matrix: Dict[str, int],
    clients_states: Dict[str, bool],
    malicious_clients_idx: List,
    good_clients_idx: List,
):
    """Update TN, FP, FN, TP of confusion matrix.

    Every client in ``clients_states`` whose integer id appears in
    ``malicious_clients_idx`` or ``good_clients_idx`` adds one to the matching
    counter; clients in neither list are ignored. The matrix is modified in
    place and also returned.
    """
    for cid, is_malicious in clients_states.items():
        numeric_id = int(cid)
        if numeric_id in malicious_clients_idx:
            # Flagged as malicious: right if the client really is malicious.
            bucket = "TP" if is_malicious else "FP"
        elif numeric_id in good_clients_idx:
            # Kept as good: a miss if the client is actually malicious.
            bucket = "FN" if is_malicious else "TN"
        else:
            continue
        confusion_matrix[bucket] += 1
    return confusion_matrix
file=\"all_results.csv\"):\n"," \"\"\"Divide csv results into multiple files distinguished by dataset and if strategy is FLANDERS or not (e.g., all_results_mnist_flanders and all_results_mnist_no_flanders).\"\"\"\n"," results = pd.read_csv(results_dir + file, float_precision='round_trip')\n"," datasets = natsorted(results[\"dataset_name\"].unique())\n"," for dataset in datasets:\n"," flanders = results[(results[\"dataset_name\"] == dataset) & (results[\"strategy\"] == \"flanders\")]\n"," no_flanders = results[(results[\"dataset_name\"] == dataset) & (results[\"strategy\"] != \"flanders\")]\n"," flanders.to_csv(results_dir + \"all_results_\" + dataset + \"_flanders.csv\", index=False)\n"," no_flanders.to_csv(results_dir + \"all_results_\" + dataset + \"_no_flanders.csv\", index=False)\n"," "]},{"cell_type":"code","execution_count":95,"metadata":{"id":"fZSDCuT497HV"},"outputs":[],"source":["def print_unique_data(results_df):\n"," for col in [\"attack_fn\", \"num_malicious\", \"dataset_name\", \"strategy\", \"aggregate_fn\"]:\n"," print(f\"Unique values in {col}: {results_df[col].unique()}\")"]},{"cell_type":"code","execution_count":96,"metadata":{"id":"8GcIZNuu8q5Y"},"outputs":[],"source":["def translate_cols(df, attack_dict, dataset_dict, strategy_dict, aggregate_dict):\n"," column_names = [\"attack_fn\", \"dataset_name\", \"strategy\", \"aggregate_fn\"]\n"," for idx, d in enumerate([attack_dict, dataset_dict, strategy_dict, aggregate_dict]):\n"," df[column_names[idx]] = df[column_names[idx]].replace(d)\n"," return df"]},{"cell_type":"code","execution_count":97,"metadata":{"id":"oHcF2pl8sdOG"},"outputs":[],"source":["attack_dict = {\n"," \"gaussian\": \"GAUSS\",\n"," \"lie\": \"LIE\",\n"," \"fang\": \"OPT\",\n"," \"minmax\": \"AGR-MM\",\n"," \"adaptive\": \"MAR-ATK\"\n","}\n","\n","dataset_dict = {\n"," \"mnist\": \"MNIST\",\n"," \"fmnist\": \"FMNIST\",\n"," \"cifar\": \"CIFAR-10\",\n"," \"cifar100\": \"CIFAR-100\"\n","}\n","\n","strategy_dict = {\n"," \"flanders\": 
\"FLANDERS\",\n"," \"fedavg\": \"FedAvg\",\n"," \"fedmedian\": \"FedMedian\",\n"," \"trimmedmean\": \"TrimmedMean\",\n"," \"bulyan\": \"Bulyan\",\n"," \"krum\": \"MultiKrum\",\n"," \"fldetector\": \"FLDetector\"\n","}\n","\n","aggregate_dict = {\n"," \"flwr.server.strategy.aggregate.aggregate\": \"FedAvg\",\n"," \"flwr.server.strategy.aggregate.aggregate_median\": \"FedMedian\",\n"," \"flwr.server.strategy.aggregate.aggregate_trimmed_avg\": \"TrimmedMean\",\n"," \"flwr.server.strategy.aggregate.aggregate_bulyan\": \"Bulyan\",\n"," \"flwr.server.strategy.aggregate.aggregate_krum\": \"MultiKrum\"\n","}"]},{"cell_type":"code","execution_count":98,"metadata":{},"outputs":[],"source":["divide_results_by_dataset(results_dir)"]},{"cell_type":"markdown","metadata":{"id":"y0XCCkuhwydB"},"source":["## MNIST"]},{"cell_type":"markdown","metadata":{"id":"NG2-2cpnyjkY"},"source":["### Use this shortcut"]},{"cell_type":"code","execution_count":99,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":443},"executionInfo":{"elapsed":244,"status":"ok","timestamp":1716376729975,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"lY85nEb6yrXu","outputId":"5439a3bc-684f-492f-b615-53e2252cd94c"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
roundlossaccuracyaucTPTNFPFNattack_fndataset_namenum_maliciousstrategyaggregate_fn
00720.8331210.10450.5034220000GAUSSMNIST0FLANDERSFedAvg
11664.5222520.20890.5631160000GAUSSMNIST0FLANDERSFedAvg
22624.6338260.35600.6447310000GAUSSMNIST0FLANDERSFedAvg
33581.4764720.47730.710941010000GAUSSMNIST0FLANDERSFedAvg
44545.1142050.54300.746970020000GAUSSMNIST0FLANDERSFedAvg
..........................................
356546724.2817250.10250.5004790000AGR-MMMNIST80dncFedAvg
356647724.3501380.10240.5004210000AGR-MMMNIST80dncFedAvg
356748724.5352630.10250.5004790000AGR-MMMNIST80dncFedAvg
356849724.5888810.10280.5005980000AGR-MMMNIST80dncFedAvg
356950724.7838510.10280.5006020000AGR-MMMNIST80dncFedAvg
\n","

7548 rows × 13 columns

\n","
"],"text/plain":[" round loss accuracy auc TP TN FP FN attack_fn \\\n","0 0 720.833121 0.1045 0.503422 0 0 0 0 GAUSS \n","1 1 664.522252 0.2089 0.563116 0 0 0 0 GAUSS \n","2 2 624.633826 0.3560 0.644731 0 0 0 0 GAUSS \n","3 3 581.476472 0.4773 0.710941 0 100 0 0 GAUSS \n","4 4 545.114205 0.5430 0.746970 0 200 0 0 GAUSS \n","... ... ... ... ... .. ... .. .. ... \n","3565 46 724.281725 0.1025 0.500479 0 0 0 0 AGR-MM \n","3566 47 724.350138 0.1024 0.500421 0 0 0 0 AGR-MM \n","3567 48 724.535263 0.1025 0.500479 0 0 0 0 AGR-MM \n","3568 49 724.588881 0.1028 0.500598 0 0 0 0 AGR-MM \n","3569 50 724.783851 0.1028 0.500602 0 0 0 0 AGR-MM \n","\n"," dataset_name num_malicious strategy aggregate_fn \n","0 MNIST 0 FLANDERS FedAvg \n","1 MNIST 0 FLANDERS FedAvg \n","2 MNIST 0 FLANDERS FedAvg \n","3 MNIST 0 FLANDERS FedAvg \n","4 MNIST 0 FLANDERS FedAvg \n","... ... ... ... ... \n","3565 MNIST 80 dnc FedAvg \n","3566 MNIST 80 dnc FedAvg \n","3567 MNIST 80 dnc FedAvg \n","3568 MNIST 80 dnc FedAvg \n","3569 MNIST 80 dnc FedAvg \n","\n","[7548 rows x 13 columns]"]},"execution_count":99,"metadata":{},"output_type":"execute_result"}],"source":["# CSV pre-processing MNIST\n","results_flanders_file = results_dir + \"all_results_mnist_flanders.csv\"\n","results_no_flanders_file = results_dir + \"all_results_mnist_no_flanders.csv\"\n","results_flanders_df = pd.read_csv(results_flanders_file)\n","results_no_flanders_df = pd.read_csv(results_no_flanders_file)\n","results_flanders_df = translate_cols(results_flanders_df, attack_dict ,dataset_dict, strategy_dict, aggregate_dict)\n","results_no_flanders_df = translate_cols(results_no_flanders_df, attack_dict ,dataset_dict, strategy_dict, aggregate_dict)\n","mnist_df = pd.concat([results_flanders_df, results_no_flanders_df])\n","mnist_df"]},{"cell_type":"code","execution_count":100,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":9,"status":"ok","timestamp":1716115669854,"user":{"displayName":"Edoardo 
Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"hg3ysnqiNrms","outputId":"bae8ab71-ce7b-409a-b154-6ad42f9dfc3b"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['GAUSS' 'LIE' 'OPT' 'AGR-MM']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['MNIST']\n","Unique values in strategy: ['FLANDERS' 'FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan' 'dnc']\n","Unique values in aggregate_fn: ['FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan'\n"," 'flanders.strategies.aggregate.aggregate_dnc']\n"]}],"source":["print_unique_data(mnist_df)"]},{"cell_type":"markdown","metadata":{"id":"dE_uqUeuyl6M"},"source":["### Step-by-step processing"]},{"cell_type":"code","execution_count":101,"metadata":{"id":"R9Cpe8bF8a2z"},"outputs":[],"source":["results_flanders_file = results_dir + \"all_results_mnist_flanders.csv\"\n","results_no_flanders_file = results_dir + \"all_results_mnist_no_flanders.csv\""]},{"cell_type":"code","execution_count":102,"metadata":{"id":"8nPsIraZ7nJK"},"outputs":[],"source":["results_flanders_df = pd.read_csv(results_flanders_file)\n","results_no_flanders_df = pd.read_csv(results_no_flanders_file)"]},{"cell_type":"code","execution_count":103,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3,"status":"ok","timestamp":1707513800371,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"oC_C6WVMshle","outputId":"20708ffb-24d9-4d94-fc83-6ab93c8d4ed0"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['gaussian' 'lie' 'fang' 'minmax']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['mnist']\n","Unique values in strategy: ['flanders']\n","Unique values in aggregate_fn: ['flwr.server.strategy.aggregate.aggregate'\n"," 'flwr.server.strategy.aggregate.aggregate_trimmed_avg'\n"," 
'flwr.server.strategy.aggregate.aggregate_median'\n"," 'flwr.server.strategy.aggregate.aggregate_krum'\n"," 'flwr.server.strategy.aggregate.aggregate_bulyan'\n"," 'flanders.strategies.aggregate.aggregate_dnc']\n"]}],"source":["print_unique_data(results_flanders_df)"]},{"cell_type":"code","execution_count":104,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":423,"status":"ok","timestamp":1707478795736,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"2xSStl9-52cc","outputId":"391a3d7c-c4b5-486c-85a2-97d9a5ddb30d"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['gaussian' 'lie' 'fang' 'minmax']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['mnist']\n","Unique values in strategy: ['fedavg' 'trimmedmean' 'fedmedian' 'krum' 'bulyan' 'dnc']\n","Unique values in aggregate_fn: ['flwr.server.strategy.aggregate.aggregate']\n"]}],"source":["print_unique_data(results_no_flanders_df)"]},{"cell_type":"markdown","metadata":{"id":"8dEepZY28raZ"},"source":["Translate strings"]},{"cell_type":"code","execution_count":105,"metadata":{"id":"zNPGc6YJ7E_J"},"outputs":[],"source":["results_flanders_df = translate_cols(results_flanders_df, attack_dict ,dataset_dict, strategy_dict, aggregate_dict)"]},{"cell_type":"code","execution_count":106,"metadata":{"id":"AQaNnF1K7TQc"},"outputs":[],"source":["results_no_flanders_df = translate_cols(results_no_flanders_df, attack_dict ,dataset_dict, strategy_dict, aggregate_dict)"]},{"cell_type":"code","execution_count":107,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":393,"status":"ok","timestamp":1707478246670,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"MXYnAzh8V-9t","outputId":"8f09ab5c-cd3e-4627-9f76-5f474ec09227"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in 
attack_fn: ['GAUSS' 'LIE' 'OPT' 'AGR-MM']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['MNIST']\n","Unique values in strategy: ['FLANDERS']\n","Unique values in aggregate_fn: ['FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan'\n"," 'flanders.strategies.aggregate.aggregate_dnc']\n"]}],"source":["print_unique_data(results_flanders_df)"]},{"cell_type":"code","execution_count":108,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":8,"status":"ok","timestamp":1707472989224,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"7xmPF5X77VQk","outputId":"f8e8331e-cde6-4413-f795-f0fe1a9cdd19"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['GAUSS' 'LIE' 'OPT' 'AGR-MM']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['MNIST']\n","Unique values in strategy: ['FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan' 'dnc']\n","Unique values in aggregate_fn: ['FedAvg']\n"]}],"source":["print_unique_data(results_no_flanders_df)"]},{"cell_type":"markdown","metadata":{"id":"mjTmoV2M9YTk"},"source":["Concatenate the 2 dataframes, namely FLANDERS+f and baselines:"]},{"cell_type":"code","execution_count":109,"metadata":{"id":"apvpT9Ve8wwv"},"outputs":[],"source":["mnist_df = pd.concat([results_flanders_df, results_no_flanders_df])"]},{"cell_type":"code","execution_count":110,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":424},"executionInfo":{"elapsed":5,"status":"ok","timestamp":1707513807441,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"aZuW73BW9Iu3","outputId":"3f558906-0d55-4ccd-e64f-5b62203ae746"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
roundlossaccuracyaucTPTNFPFNattack_fndataset_namenum_maliciousstrategyaggregate_fn
00720.8331210.10450.5034220000GAUSSMNIST0FLANDERSFedAvg
11664.5222520.20890.5631160000GAUSSMNIST0FLANDERSFedAvg
22624.6338260.35600.6447310000GAUSSMNIST0FLANDERSFedAvg
33581.4764720.47730.710941010000GAUSSMNIST0FLANDERSFedAvg
44545.1142050.54300.746970020000GAUSSMNIST0FLANDERSFedAvg
..........................................
356546724.2817250.10250.5004790000AGR-MMMNIST80dncFedAvg
356647724.3501380.10240.5004210000AGR-MMMNIST80dncFedAvg
356748724.5352630.10250.5004790000AGR-MMMNIST80dncFedAvg
356849724.5888810.10280.5005980000AGR-MMMNIST80dncFedAvg
356950724.7838510.10280.5006020000AGR-MMMNIST80dncFedAvg
\n","

7548 rows × 13 columns

\n","
"],"text/plain":[" round loss accuracy auc TP TN FP FN attack_fn \\\n","0 0 720.833121 0.1045 0.503422 0 0 0 0 GAUSS \n","1 1 664.522252 0.2089 0.563116 0 0 0 0 GAUSS \n","2 2 624.633826 0.3560 0.644731 0 0 0 0 GAUSS \n","3 3 581.476472 0.4773 0.710941 0 100 0 0 GAUSS \n","4 4 545.114205 0.5430 0.746970 0 200 0 0 GAUSS \n","... ... ... ... ... .. ... .. .. ... \n","3565 46 724.281725 0.1025 0.500479 0 0 0 0 AGR-MM \n","3566 47 724.350138 0.1024 0.500421 0 0 0 0 AGR-MM \n","3567 48 724.535263 0.1025 0.500479 0 0 0 0 AGR-MM \n","3568 49 724.588881 0.1028 0.500598 0 0 0 0 AGR-MM \n","3569 50 724.783851 0.1028 0.500602 0 0 0 0 AGR-MM \n","\n"," dataset_name num_malicious strategy aggregate_fn \n","0 MNIST 0 FLANDERS FedAvg \n","1 MNIST 0 FLANDERS FedAvg \n","2 MNIST 0 FLANDERS FedAvg \n","3 MNIST 0 FLANDERS FedAvg \n","4 MNIST 0 FLANDERS FedAvg \n","... ... ... ... ... \n","3565 MNIST 80 dnc FedAvg \n","3566 MNIST 80 dnc FedAvg \n","3567 MNIST 80 dnc FedAvg \n","3568 MNIST 80 dnc FedAvg \n","3569 MNIST 80 dnc FedAvg \n","\n","[7548 rows x 13 columns]"]},"execution_count":110,"metadata":{},"output_type":"execute_result"}],"source":["mnist_df"]},{"cell_type":"code","execution_count":111,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3,"status":"ok","timestamp":1707480685917,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"Ub0W-iA69LpR","outputId":"bfbcaf9e-575c-4d02-e9f0-0be7e74accb4"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['GAUSS' 'LIE' 'OPT' 'AGR-MM']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['MNIST']\n","Unique values in strategy: ['FLANDERS' 'FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan' 'dnc']\n","Unique values in aggregate_fn: ['FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan'\n"," 
'flanders.strategies.aggregate.aggregate_dnc']\n"]}],"source":["print_unique_data(mnist_df)"]},{"cell_type":"markdown","metadata":{"id":"E3TZ_fJuTVuU"},"source":["## Fashion MNIST"]},{"cell_type":"markdown","metadata":{"id":"45GAIKG9Tmyb"},"source":["### Use this shortcut"]},{"cell_type":"code","execution_count":112,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":443},"executionInfo":{"elapsed":327,"status":"ok","timestamp":1716376732776,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"Qju5S7VmTpB_","outputId":"fbe0aed8-164c-4343-9949-e9fa7cc7f0a7"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
roundlossaccuracyaucTPTNFPFNattack_fndataset_namenum_maliciousstrategyaggregate_fn
0023082.3338130.06310.4795000000GAUSSFMNIST0FLANDERSFedAvg
1121920.1315610.19770.5542780000GAUSSFMNIST0FLANDERSFedAvg
2217859.0960200.42100.6783330000GAUSSFMNIST0FLANDERSFedAvg
3315559.0449260.49200.717778010000GAUSSFMNIST0FLANDERSFedAvg
4414684.1937220.50010.722278020000GAUSSFMNIST0FLANDERSFedAvg
..........................................
35654623279.5649070.10000.5000000000AGR-MMFMNIST80dncFedAvg
35664723290.9804420.10000.5000000000AGR-MMFMNIST80dncFedAvg
35674823302.2510220.10000.5000000000AGR-MMFMNIST80dncFedAvg
35684923312.5125960.10000.5000000000AGR-MMFMNIST80dncFedAvg
35695023326.1161770.10000.5000000000AGR-MMFMNIST80dncFedAvg
\n","

6884 rows × 13 columns

\n","
"],"text/plain":[" round loss accuracy auc TP TN FP FN attack_fn \\\n","0 0 23082.333813 0.0631 0.479500 0 0 0 0 GAUSS \n","1 1 21920.131561 0.1977 0.554278 0 0 0 0 GAUSS \n","2 2 17859.096020 0.4210 0.678333 0 0 0 0 GAUSS \n","3 3 15559.044926 0.4920 0.717778 0 100 0 0 GAUSS \n","4 4 14684.193722 0.5001 0.722278 0 200 0 0 GAUSS \n","... ... ... ... ... .. ... .. .. ... \n","3565 46 23279.564907 0.1000 0.500000 0 0 0 0 AGR-MM \n","3566 47 23290.980442 0.1000 0.500000 0 0 0 0 AGR-MM \n","3567 48 23302.251022 0.1000 0.500000 0 0 0 0 AGR-MM \n","3568 49 23312.512596 0.1000 0.500000 0 0 0 0 AGR-MM \n","3569 50 23326.116177 0.1000 0.500000 0 0 0 0 AGR-MM \n","\n"," dataset_name num_malicious strategy aggregate_fn \n","0 FMNIST 0 FLANDERS FedAvg \n","1 FMNIST 0 FLANDERS FedAvg \n","2 FMNIST 0 FLANDERS FedAvg \n","3 FMNIST 0 FLANDERS FedAvg \n","4 FMNIST 0 FLANDERS FedAvg \n","... ... ... ... ... \n","3565 FMNIST 80 dnc FedAvg \n","3566 FMNIST 80 dnc FedAvg \n","3567 FMNIST 80 dnc FedAvg \n","3568 FMNIST 80 dnc FedAvg \n","3569 FMNIST 80 dnc FedAvg \n","\n","[6884 rows x 13 columns]"]},"execution_count":112,"metadata":{},"output_type":"execute_result"}],"source":["# CSV pre-processing FMNIST\n","results_flanders_file = results_dir + \"all_results_fmnist_flanders.csv\"\n","results_no_flanders_file = results_dir + \"all_results_fmnist_no_flanders.csv\"\n","results_flanders_df = pd.read_csv(results_flanders_file)\n","results_no_flanders_df = pd.read_csv(results_no_flanders_file)\n","results_flanders_df = translate_cols(results_flanders_df, attack_dict ,dataset_dict, strategy_dict, aggregate_dict)\n","results_no_flanders_df = translate_cols(results_no_flanders_df, attack_dict ,dataset_dict, strategy_dict, aggregate_dict)\n","fmnist_df = pd.concat([results_flanders_df, 
results_no_flanders_df])\n","fmnist_df"]},{"cell_type":"code","execution_count":113,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":1,"status":"ok","timestamp":1716047458204,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"08YXqMTZNpN6","outputId":"9de6cdc8-241e-4867-cd53-645434b99ce2"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['GAUSS' 'LIE' 'OPT' 'AGR-MM']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['FMNIST']\n","Unique values in strategy: ['FLANDERS' 'FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan' 'dnc']\n","Unique values in aggregate_fn: ['FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan']\n"]}],"source":["print_unique_data(fmnist_df)"]},{"cell_type":"markdown","metadata":{"id":"9vVX6wsxT-rc"},"source":["### Step-by-step processing"]},{"cell_type":"code","execution_count":114,"metadata":{"id":"j0ZnLmVnUBT3"},"outputs":[],"source":["results_flanders_file = results_dir + \"all_results_fmnist_flanders.csv\"\n","results_no_flanders_file = results_dir + \"all_results_fmnist_no_flanders.csv\""]},{"cell_type":"code","execution_count":115,"metadata":{"id":"qsYaQiAWUBOw"},"outputs":[],"source":["results_flanders_df = pd.read_csv(results_flanders_file)\n","results_no_flanders_df = pd.read_csv(results_no_flanders_file)"]},{"cell_type":"code","execution_count":116,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":250,"status":"ok","timestamp":1709217712591,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"m1VKWq_jUHyY","outputId":"5bd5b442-4ab4-473d-bc60-b6f5fdc6320d"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['gaussian' 'lie' 'fang' 'minmax']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['fmnist']\n","Unique values in 
strategy: ['flanders']\n","Unique values in aggregate_fn: ['flwr.server.strategy.aggregate.aggregate'\n"," 'flwr.server.strategy.aggregate.aggregate_trimmed_avg'\n"," 'flwr.server.strategy.aggregate.aggregate_median'\n"," 'flwr.server.strategy.aggregate.aggregate_krum'\n"," 'flwr.server.strategy.aggregate.aggregate_bulyan']\n"]}],"source":["print_unique_data(results_flanders_df)"]},{"cell_type":"code","execution_count":117,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":4,"status":"ok","timestamp":1709217720407,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"IxPT9D6DUJN3","outputId":"ded1f8b2-baf5-437b-abe3-25bad7a3c4ad"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['gaussian' 'lie' 'fang' 'minmax']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['fmnist']\n","Unique values in strategy: ['fedavg' 'trimmedmean' 'fedmedian' 'krum' 'bulyan' 'dnc']\n","Unique values in aggregate_fn: ['flwr.server.strategy.aggregate.aggregate']\n"]}],"source":["print_unique_data(results_no_flanders_df)"]},{"cell_type":"markdown","metadata":{"id":"X8k98LNrUaLp"},"source":["Translate strings"]},{"cell_type":"code","execution_count":118,"metadata":{"id":"zHNwpvZMUaLq"},"outputs":[],"source":["results_flanders_df = translate_cols(results_flanders_df, attack_dict ,dataset_dict, strategy_dict, aggregate_dict)"]},{"cell_type":"code","execution_count":119,"metadata":{"id":"9zMOOjCiUaLr"},"outputs":[],"source":["results_no_flanders_df = translate_cols(results_no_flanders_df, attack_dict ,dataset_dict, strategy_dict, aggregate_dict)"]},{"cell_type":"code","execution_count":120,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":4,"status":"ok","timestamp":1709217802421,"user":{"displayName":"Edoardo 
Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"tygeoDz6UaLr","outputId":"2441a31a-ead0-4adf-8b95-73d75c6b3739"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['GAUSS' 'LIE' 'OPT' 'AGR-MM']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['FMNIST']\n","Unique values in strategy: ['FLANDERS']\n","Unique values in aggregate_fn: ['FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan']\n"]}],"source":["print_unique_data(results_flanders_df)"]},{"cell_type":"code","execution_count":121,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3,"status":"ok","timestamp":1709217803343,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"b8Xo1VseUaLr","outputId":"1fbb0a64-0695-4a89-eec1-e0a18ba54c40"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['GAUSS' 'LIE' 'OPT' 'AGR-MM']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['FMNIST']\n","Unique values in strategy: ['FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan' 'dnc']\n","Unique values in aggregate_fn: ['FedAvg']\n"]}],"source":["print_unique_data(results_no_flanders_df)"]},{"cell_type":"markdown","metadata":{"id":"zkmqmUTzUaLr"},"source":["Concatenate the 2 dataframes, namely FLANDERS+f and baselines:"]},{"cell_type":"code","execution_count":122,"metadata":{"id":"m-wRVa9eUaLr"},"outputs":[],"source":["fmnist_df = pd.concat([results_flanders_df, results_no_flanders_df])"]},{"cell_type":"code","execution_count":123,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":423},"executionInfo":{"elapsed":4,"status":"ok","timestamp":1709217813677,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"IhpjH5n1UaLs","outputId":"73f9a359-9f74-4e5e-b518-25af272b2207"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
roundlossaccuracyaucTPTNFPFNattack_fndataset_namenum_maliciousstrategyaggregate_fn
0023082.3338130.06310.4795000000GAUSSFMNIST0FLANDERSFedAvg
1121920.1315610.19770.5542780000GAUSSFMNIST0FLANDERSFedAvg
2217859.0960200.42100.6783330000GAUSSFMNIST0FLANDERSFedAvg
3315559.0449260.49200.717778010000GAUSSFMNIST0FLANDERSFedAvg
4414684.1937220.50010.722278020000GAUSSFMNIST0FLANDERSFedAvg
..........................................
35654623279.5649070.10000.5000000000AGR-MMFMNIST80dncFedAvg
35664723290.9804420.10000.5000000000AGR-MMFMNIST80dncFedAvg
35674823302.2510220.10000.5000000000AGR-MMFMNIST80dncFedAvg
35684923312.5125960.10000.5000000000AGR-MMFMNIST80dncFedAvg
35695023326.1161770.10000.5000000000AGR-MMFMNIST80dncFedAvg
\n","

6884 rows × 13 columns

\n","
"],"text/plain":[" round loss accuracy auc TP TN FP FN attack_fn \\\n","0 0 23082.333813 0.0631 0.479500 0 0 0 0 GAUSS \n","1 1 21920.131561 0.1977 0.554278 0 0 0 0 GAUSS \n","2 2 17859.096020 0.4210 0.678333 0 0 0 0 GAUSS \n","3 3 15559.044926 0.4920 0.717778 0 100 0 0 GAUSS \n","4 4 14684.193722 0.5001 0.722278 0 200 0 0 GAUSS \n","... ... ... ... ... .. ... .. .. ... \n","3565 46 23279.564907 0.1000 0.500000 0 0 0 0 AGR-MM \n","3566 47 23290.980442 0.1000 0.500000 0 0 0 0 AGR-MM \n","3567 48 23302.251022 0.1000 0.500000 0 0 0 0 AGR-MM \n","3568 49 23312.512596 0.1000 0.500000 0 0 0 0 AGR-MM \n","3569 50 23326.116177 0.1000 0.500000 0 0 0 0 AGR-MM \n","\n"," dataset_name num_malicious strategy aggregate_fn \n","0 FMNIST 0 FLANDERS FedAvg \n","1 FMNIST 0 FLANDERS FedAvg \n","2 FMNIST 0 FLANDERS FedAvg \n","3 FMNIST 0 FLANDERS FedAvg \n","4 FMNIST 0 FLANDERS FedAvg \n","... ... ... ... ... \n","3565 FMNIST 80 dnc FedAvg \n","3566 FMNIST 80 dnc FedAvg \n","3567 FMNIST 80 dnc FedAvg \n","3568 FMNIST 80 dnc FedAvg \n","3569 FMNIST 80 dnc FedAvg \n","\n","[6884 rows x 13 columns]"]},"execution_count":123,"metadata":{},"output_type":"execute_result"}],"source":["fmnist_df"]},{"cell_type":"code","execution_count":124,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":6,"status":"ok","timestamp":1709217818750,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-60},"id":"XYwnEFa2UaLs","outputId":"44dcb4de-f892-4555-b5d8-8cfcacd37137"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['GAUSS' 'LIE' 'OPT' 'AGR-MM']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['FMNIST']\n","Unique values in strategy: ['FLANDERS' 'FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan' 'dnc']\n","Unique values in aggregate_fn: ['FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 
'Bulyan']\n"]}],"source":["print_unique_data(fmnist_df)"]},{"cell_type":"markdown","metadata":{"id":"1TUxrAF6w6cY"},"source":["## Unify datasets"]},{"cell_type":"code","execution_count":125,"metadata":{"id":"R2wOP2Eex7X2"},"outputs":[],"source":["all_datasets_df = pd.concat([mnist_df, fmnist_df])"]},{"cell_type":"code","execution_count":126,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":443},"executionInfo":{"elapsed":428,"status":"ok","timestamp":1716376740426,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"jwDN17ygyFK7","outputId":"c3db739f-98b7-4070-f826-4344553e9ab7"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
roundlossaccuracyaucTPTNFPFNattack_fndataset_namenum_maliciousstrategyaggregate_fn
00720.8331210.10450.5034220000GAUSSMNIST0FLANDERSFedAvg
11664.5222520.20890.5631160000GAUSSMNIST0FLANDERSFedAvg
22624.6338260.35600.6447310000GAUSSMNIST0FLANDERSFedAvg
33581.4764720.47730.710941010000GAUSSMNIST0FLANDERSFedAvg
44545.1142050.54300.746970020000GAUSSMNIST0FLANDERSFedAvg
..........................................
35654623279.5649070.10000.5000000000AGR-MMFMNIST80dncFedAvg
35664723290.9804420.10000.5000000000AGR-MMFMNIST80dncFedAvg
35674823302.2510220.10000.5000000000AGR-MMFMNIST80dncFedAvg
35684923312.5125960.10000.5000000000AGR-MMFMNIST80dncFedAvg
35695023326.1161770.10000.5000000000AGR-MMFMNIST80dncFedAvg
\n","

14432 rows × 13 columns

\n","
"],"text/plain":[" round loss accuracy auc TP TN FP FN attack_fn \\\n","0 0 720.833121 0.1045 0.503422 0 0 0 0 GAUSS \n","1 1 664.522252 0.2089 0.563116 0 0 0 0 GAUSS \n","2 2 624.633826 0.3560 0.644731 0 0 0 0 GAUSS \n","3 3 581.476472 0.4773 0.710941 0 100 0 0 GAUSS \n","4 4 545.114205 0.5430 0.746970 0 200 0 0 GAUSS \n","... ... ... ... ... .. ... .. .. ... \n","3565 46 23279.564907 0.1000 0.500000 0 0 0 0 AGR-MM \n","3566 47 23290.980442 0.1000 0.500000 0 0 0 0 AGR-MM \n","3567 48 23302.251022 0.1000 0.500000 0 0 0 0 AGR-MM \n","3568 49 23312.512596 0.1000 0.500000 0 0 0 0 AGR-MM \n","3569 50 23326.116177 0.1000 0.500000 0 0 0 0 AGR-MM \n","\n"," dataset_name num_malicious strategy aggregate_fn \n","0 MNIST 0 FLANDERS FedAvg \n","1 MNIST 0 FLANDERS FedAvg \n","2 MNIST 0 FLANDERS FedAvg \n","3 MNIST 0 FLANDERS FedAvg \n","4 MNIST 0 FLANDERS FedAvg \n","... ... ... ... ... \n","3565 FMNIST 80 dnc FedAvg \n","3566 FMNIST 80 dnc FedAvg \n","3567 FMNIST 80 dnc FedAvg \n","3568 FMNIST 80 dnc FedAvg \n","3569 FMNIST 80 dnc FedAvg \n","\n","[14432 rows x 13 columns]"]},"execution_count":126,"metadata":{},"output_type":"execute_result"}],"source":["all_datasets_df"]},{"cell_type":"code","execution_count":127,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2,"status":"ok","timestamp":1716376741411,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"VWrdDPtPyHHA","outputId":"1ef8b853-63f8-4eac-a785-578684dbf0a6"},"outputs":[{"name":"stdout","output_type":"stream","text":["Unique values in attack_fn: ['GAUSS' 'LIE' 'OPT' 'AGR-MM']\n","Unique values in num_malicious: [ 0 20 60 80]\n","Unique values in dataset_name: ['MNIST' 'FMNIST']\n","Unique values in strategy: ['FLANDERS' 'FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan' 'dnc']\n","Unique values in aggregate_fn: ['FedAvg' 'TrimmedMean' 'FedMedian' 'MultiKrum' 'Bulyan'\n"," 
'flanders.strategies.aggregate.aggregate_dnc']\n"]}],"source":["print_unique_data(all_datasets_df)"]},{"cell_type":"markdown","metadata":{"id":"OlES57TEn2Ng"},"source":["# Tables\n"]},{"cell_type":"markdown","metadata":{"id":"hcHkTXfGbapg"},"source":["## Accuracy"]},{"cell_type":"markdown","metadata":{"id":"7F1YDs12sZbE"},"source":["### Best with improvement w.r.t. baseline"]},{"cell_type":"code","execution_count":128,"metadata":{"id":"GuQM8bzXnIGx"},"outputs":[],"source":["def accuracy_table(input_df, b):\n"," # Define strategies and attacks\n"," strategies = ['FedAvg', 'FLANDERS + FedAvg', 'FedMedian', 'FLANDERS + FedMedian', 'TrimmedMean', 'FLANDERS + TrimmedMean', 'MultiKrum', 'FLANDERS + MultiKrum', 'Bulyan', 'FLANDERS + Bulyan']\n"," attacks = ['GAUSS', 'LIE', 'OPT', 'AGR-MM']\n"," dataset_names = [\"MNIST\", \"FMNIST\"]\n","\n"," # Create MultiIndex for the columns\n"," columns = pd.MultiIndex.from_product([dataset_names, attacks], names=['Dataset', 'Attack'])\n","\n"," # Create an empty DataFrame with the defined columns and strategies\n"," df = pd.DataFrame(index=strategies, columns=columns)\n","\n"," filtered_df = input_df[(input_df['num_malicious'] == b) & (input_df['round'] >= 3)]\n"," baseline_df = filtered_df[filtered_df['strategy'] != 'FLANDERS']\n"," flanders_df = filtered_df[filtered_df['strategy'] == 'FLANDERS']\n","\n"," # Populate the DataFrame\n"," for strategy in ['FedAvg', 'TrimmedMean', 'FedMedian', 'MultiKrum', 'Bulyan']:\n"," for dataset in dataset_names:\n"," for attack in attacks:\n"," df.loc[strategy, (dataset, attack)] = round(baseline_df[(baseline_df['strategy']==strategy) & (baseline_df['attack_fn']==attack) & (baseline_df['dataset_name']==dataset)]['accuracy'].max(), 2)\n"," df.loc[f\"FLANDERS + {strategy}\", (dataset, attack)] = round(flanders_df[(flanders_df['aggregate_fn']==strategy) & (flanders_df['attack_fn']==attack) & (flanders_df['dataset_name']==dataset)]['accuracy'].max(), 2)\n","\n"," return 
df\n"]},{"cell_type":"code","execution_count":129,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":457},"executionInfo":{"elapsed":1243,"status":"ok","timestamp":1715943873268,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"PruXJg2vA87b","outputId":"98d83851-626b-4877-fd09-8f2f01200c65"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
DatasetMNISTFMNIST
AttackGAUSSGAUSS
FedAvg0.860.68
FLANDERS + FedAvg0.840.64
FedMedian0.830.71
FLANDERS + FedMedian0.760.73
TrimmedMean0.850.69
FLANDERS + TrimmedMean0.780.7
MultiKrum0.680.66
FLANDERS + MultiKrum0.740.73
Bulyan0.860.62
FLANDERS + Bulyan0.870.65
\n","
"],"text/plain":["Dataset MNIST FMNIST\n","Attack GAUSS GAUSS\n","FedAvg 0.86 0.68\n","FLANDERS + FedAvg 0.84 0.64\n","FedMedian 0.83 0.71\n","FLANDERS + FedMedian 0.76 0.73\n","TrimmedMean 0.85 0.69\n","FLANDERS + TrimmedMean 0.78 0.7\n","MultiKrum 0.68 0.66\n","FLANDERS + MultiKrum 0.74 0.73\n","Bulyan 0.86 0.62\n","FLANDERS + Bulyan 0.87 0.65"]},"execution_count":129,"metadata":{},"output_type":"execute_result"}],"source":["# Table 19\n","acc_0 = accuracy_table(all_datasets_df, 0).dropna(axis=1, how='all')\n","acc_0"]},{"cell_type":"code","execution_count":130,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":457},"executionInfo":{"elapsed":858,"status":"ok","timestamp":1715944158330,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"8bV7hABWbyMS","outputId":"6d24f87d-da15-40b2-8880-3b6dcefda1fd"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
DatasetMNISTFMNIST
AttackGAUSSLIEOPTAGR-MMGAUSSLIEOPTAGR-MM
FedAvg0.20.170.670.450.250.170.570.11
FLANDERS + FedAvg0.880.870.480.880.660.670.570.64
FedMedian0.80.660.790.590.660.650.670.6
FLANDERS + FedMedian0.850.850.660.830.710.690.630.73
TrimmedMean0.860.520.730.610.690.540.620.58
FLANDERS + TrimmedMean0.810.850.780.830.690.70.630.73
MultiKrum0.780.770.810.820.740.650.70.67
FLANDERS + MultiKrum0.820.860.840.820.730.70.730.71
Bulyan0.820.840.840.830.710.720.690.76
FLANDERS + Bulyan0.90.840.790.850.650.650.660.65
\n","
"],"text/plain":["Dataset MNIST FMNIST \n","Attack GAUSS LIE OPT AGR-MM GAUSS LIE OPT AGR-MM\n","FedAvg 0.2 0.17 0.67 0.45 0.25 0.17 0.57 0.11\n","FLANDERS + FedAvg 0.88 0.87 0.48 0.88 0.66 0.67 0.57 0.64\n","FedMedian 0.8 0.66 0.79 0.59 0.66 0.65 0.67 0.6\n","FLANDERS + FedMedian 0.85 0.85 0.66 0.83 0.71 0.69 0.63 0.73\n","TrimmedMean 0.86 0.52 0.73 0.61 0.69 0.54 0.62 0.58\n","FLANDERS + TrimmedMean 0.81 0.85 0.78 0.83 0.69 0.7 0.63 0.73\n","MultiKrum 0.78 0.77 0.81 0.82 0.74 0.65 0.7 0.67\n","FLANDERS + MultiKrum 0.82 0.86 0.84 0.82 0.73 0.7 0.73 0.71\n","Bulyan 0.82 0.84 0.84 0.83 0.71 0.72 0.69 0.76\n","FLANDERS + Bulyan 0.9 0.84 0.79 0.85 0.65 0.65 0.66 0.65"]},"execution_count":130,"metadata":{},"output_type":"execute_result"}],"source":["# Table 15\n","acc_20 = accuracy_table(all_datasets_df, 20)\n","acc_20"]},{"cell_type":"markdown","metadata":{},"source":["Bulyan is NaN because it cannot work when the number of malicious clients is > 25%"]},{"cell_type":"code","execution_count":131,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":477},"executionInfo":{"elapsed":1126,"status":"ok","timestamp":1716115710006,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"xSvgEwLoPmh3","outputId":"1c65e66e-7374-46f0-a2a5-0964bc40e49a"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
DatasetMNISTFMNIST
AttackGAUSSLIEOPTAGR-MMGAUSSLIEOPTAGR-MM
FedAvg0.190.150.20.160.280.10.190.1
FLANDERS + FedAvg0.760.880.850.850.690.670.710.67
FedMedian0.80.190.160.290.650.10.10.1
FLANDERS + FedMedian0.80.860.830.860.710.690.710.71
TrimmedMean0.250.20.330.10.330.10.170.1
FLANDERS + TrimmedMean0.780.870.840.830.70.710.730.74
MultiKrum0.790.140.220.150.710.10.120.1
FLANDERS + MultiKrum0.880.880.860.780.720.710.730.69
BulyanNaNNaNNaNNaNNaNNaNNaNNaN
FLANDERS + Bulyan0.890.870.90.850.680.640.60.69
\n","
"],"text/plain":["Dataset MNIST FMNIST \n","Attack GAUSS LIE OPT AGR-MM GAUSS LIE OPT AGR-MM\n","FedAvg 0.19 0.15 0.2 0.16 0.28 0.1 0.19 0.1\n","FLANDERS + FedAvg 0.76 0.88 0.85 0.85 0.69 0.67 0.71 0.67\n","FedMedian 0.8 0.19 0.16 0.29 0.65 0.1 0.1 0.1\n","FLANDERS + FedMedian 0.8 0.86 0.83 0.86 0.71 0.69 0.71 0.71\n","TrimmedMean 0.25 0.2 0.33 0.1 0.33 0.1 0.17 0.1\n","FLANDERS + TrimmedMean 0.78 0.87 0.84 0.83 0.7 0.71 0.73 0.74\n","MultiKrum 0.79 0.14 0.22 0.15 0.71 0.1 0.12 0.1\n","FLANDERS + MultiKrum 0.88 0.88 0.86 0.78 0.72 0.71 0.73 0.69\n","Bulyan NaN NaN NaN NaN NaN NaN NaN NaN\n","FLANDERS + Bulyan 0.89 0.87 0.9 0.85 0.68 0.64 0.6 0.69"]},"execution_count":131,"metadata":{},"output_type":"execute_result"}],"source":["# Table 17\n","acc_60 = accuracy_table(all_datasets_df, 60)\n","acc_60"]},{"cell_type":"code","execution_count":132,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":477},"executionInfo":{"elapsed":1188,"status":"ok","timestamp":1716050662469,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"dM_AMm_jcCye","outputId":"90533ef1-500f-40a3-f565-2c2ca1d68aaa"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
DatasetMNISTFMNIST
AttackGAUSSLIEOPTAGR-MMGAUSSLIEOPTAGR-MM
FedAvg0.210.160.310.130.240.10.180.1
FLANDERS + FedAvg0.850.860.880.850.690.70.690.66
FedMedian0.340.170.140.090.30.10.140.1
FLANDERS + FedMedian0.870.840.80.80.730.740.720.72
TrimmedMean0.170.150.210.140.210.10.120.1
FLANDERS + TrimmedMean0.810.850.810.820.740.730.70.69
MultiKrum0.820.210.320.110.720.10.150.1
FLANDERS + MultiKrum0.870.830.870.850.680.730.720.7
BulyanNaNNaNNaNNaNNaNNaNNaNNaN
FLANDERS + Bulyan0.840.840.830.80.690.720.690.68
\n","
"],"text/plain":["Dataset MNIST FMNIST \n","Attack GAUSS LIE OPT AGR-MM GAUSS LIE OPT AGR-MM\n","FedAvg 0.21 0.16 0.31 0.13 0.24 0.1 0.18 0.1\n","FLANDERS + FedAvg 0.85 0.86 0.88 0.85 0.69 0.7 0.69 0.66\n","FedMedian 0.34 0.17 0.14 0.09 0.3 0.1 0.14 0.1\n","FLANDERS + FedMedian 0.87 0.84 0.8 0.8 0.73 0.74 0.72 0.72\n","TrimmedMean 0.17 0.15 0.21 0.14 0.21 0.1 0.12 0.1\n","FLANDERS + TrimmedMean 0.81 0.85 0.81 0.82 0.74 0.73 0.7 0.69\n","MultiKrum 0.82 0.21 0.32 0.11 0.72 0.1 0.15 0.1\n","FLANDERS + MultiKrum 0.87 0.83 0.87 0.85 0.68 0.73 0.72 0.7\n","Bulyan NaN NaN NaN NaN NaN NaN NaN NaN\n","FLANDERS + Bulyan 0.84 0.84 0.83 0.8 0.69 0.72 0.69 0.68"]},"execution_count":132,"metadata":{},"output_type":"execute_result"}],"source":["# Table 3\n","acc_80 = accuracy_table(all_datasets_df, 80)\n","acc_80"]},{"cell_type":"markdown","metadata":{"id":"CZX8c37MsgFL"},"source":["### Best w.r.t. number of attackers"]},{"cell_type":"code","execution_count":133,"metadata":{"id":"xgIKM1obsmd2"},"outputs":[],"source":["def accuracy_table_attackers(input_df, aggregate_fn):\n"," # Define strategies and attacks\n"," attacks = ['GAUSS', 'LIE', 'OPT', 'AGR-MM']\n"," dataset_names = [\"MNIST\", \"FMNIST\"]\n"," num_malicious = [0, 20, 60, 80]\n","\n"," # Create MultiIndex for the columns\n"," columns = pd.MultiIndex.from_product([dataset_names, num_malicious], names=['Dataset', '# Malicious'])\n","\n"," #######\n"," #columns = pd.MultiIndex.from_product([['MNIST', 'CIFAR-10'], ['GAUSS', 'LIE', 'OPT', 'AGR-MM'], ['LAST', 'BEST']])\n"," #######\n","\n","\n"," # Create an empty DataFrame with the defined columns and strategies\n"," df = pd.DataFrame(index=attacks, columns=columns)\n","\n"," filtered_df = input_df[(input_df['aggregate_fn'] == aggregate_fn) & (input_df['round'] >= 3)]\n"," baseline_df = filtered_df[filtered_df['strategy'] != 'FLANDERS']\n"," flanders_df = filtered_df[filtered_df['strategy'] == 'FLANDERS']\n","\n"," # Populate the DataFrame\n"," for dataset in 
dataset_names:\n"," for attack in attacks:\n"," for b in num_malicious:\n"," if b == 0:\n"," df.loc[attack, (dataset, b)] = round(flanders_df[(flanders_df['num_malicious']==b) & (flanders_df['attack_fn']=='GAUSS') & (flanders_df['dataset_name']==dataset)]['accuracy'].max(), 2)\n"," else:\n"," df.loc[attack, (dataset, b)] = round(flanders_df[(flanders_df['num_malicious']==b) & (flanders_df['attack_fn']==attack) & (flanders_df['dataset_name']==dataset)]['accuracy'].max(), 2)\n","\n"," return df"]},{"cell_type":"code","execution_count":134,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":206},"executionInfo":{"elapsed":305,"status":"ok","timestamp":1715954054792,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"0n9vhQiQuxk_","outputId":"500dfb31-f525-4289-f337-2c945bf50cd2"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
DatasetMNISTFMNIST
# Malicious02060800206080
GAUSS0.740.820.880.870.730.730.720.68
LIE0.740.860.880.830.730.70.710.73
OPT0.740.840.860.870.730.730.730.72
AGR-MM0.740.820.780.850.730.710.690.7
\n","
"],"text/plain":["Dataset MNIST FMNIST \n","# Malicious 0 20 60 80 0 20 60 80\n","GAUSS 0.74 0.82 0.88 0.87 0.73 0.73 0.72 0.68\n","LIE 0.74 0.86 0.88 0.83 0.73 0.7 0.71 0.73\n","OPT 0.74 0.84 0.86 0.87 0.73 0.73 0.73 0.72\n","AGR-MM 0.74 0.82 0.78 0.85 0.73 0.71 0.69 0.7"]},"execution_count":134,"metadata":{},"output_type":"execute_result"}],"source":["# Table 20\n","acc_att = accuracy_table_attackers(all_datasets_df, 'MultiKrum')\n","acc_att"]},{"cell_type":"markdown","metadata":{"id":"g6yDorrubUw1"},"source":["## Precision and Recall"]},{"cell_type":"code","execution_count":135,"metadata":{"id":"y0PlDlG0bKek"},"outputs":[],"source":["def pr_table(input_df, b):\n"," strategies = ['FLANDERS']\n"," attacks = ['GAUSS', 'LIE', 'OPT', 'AGR-MM']\n"," dataset_names = [\"MNIST\", \"FMNIST\"]\n","\n"," # Create MultiIndex for the columns\n"," columns = pd.MultiIndex.from_product([strategies, attacks, ['P', 'R']], names=['Strategy', 'Attack', 'P/R'])\n","\n"," # Create an empty DataFrame with the defined columns and strategies\n"," df = pd.DataFrame(index=dataset_names, columns=columns)\n","\n"," filtered_df = input_df[(input_df['num_malicious'] == b) & (input_df['round'] == 50) & (input_df['aggregate_fn']=='FedAvg')]\n"," flanders_df = filtered_df[filtered_df['strategy'] == 'FLANDERS']\n"," strat_dfs = [flanders_df]\n","\n"," # Populate the DataFrame\n"," for dataset in dataset_names:\n"," for attack in attacks:\n"," for idx, strategy in enumerate(strategies):\n"," tp = strat_dfs[idx][(strat_dfs[idx]['attack_fn']==attack) & (strat_dfs[idx]['dataset_name']==dataset)]['TP'].iloc[0]\n"," fp = strat_dfs[idx][(strat_dfs[idx]['attack_fn']==attack) & (strat_dfs[idx]['dataset_name']==dataset)]['FP'].iloc[0]\n"," fn = strat_dfs[idx][(strat_dfs[idx]['attack_fn']==attack) & (strat_dfs[idx]['dataset_name']==dataset)]['FN'].iloc[0]\n"," df.loc[dataset, (strategy, attack, 'P')] = round(tp / (tp+fp), 2)\n"," df.loc[dataset, (strategy, attack, 'R')] = round(tp / (tp+fn), 2)\n","\n"," 
return df"]},{"cell_type":"code","execution_count":136,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":257},"executionInfo":{"elapsed":248,"status":"ok","timestamp":1716367268457,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"rpQyziNRh3dn","outputId":"26e2c0f9-144d-4cad-d13f-1e2351aa2081"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
StrategyFLANDERS
AttackGAUSSLIEOPTAGR-MM
P/RPRPRPRPR
MNIST1.01.01.01.00.150.151.01.0
FMNIST1.01.01.01.00.160.161.01.0
\n","
"],"text/plain":["Strategy FLANDERS \n","Attack GAUSS LIE OPT AGR-MM \n","P/R P R P R P R P R\n","MNIST 1.0 1.0 1.0 1.0 0.15 0.15 1.0 1.0\n","FMNIST 1.0 1.0 1.0 1.0 0.16 0.16 1.0 1.0"]},"execution_count":136,"metadata":{},"output_type":"execute_result"}],"source":["# Table 1\n","pr_20 = pr_table(all_datasets_df, 20)\n","pr_20"]},{"cell_type":"code","execution_count":137,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":257},"executionInfo":{"elapsed":327,"status":"ok","timestamp":1716367273542,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"ccW0Ups3iMvZ","outputId":"ba945e17-1bbe-414f-9aa0-b03fb0fe107f"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
StrategyFLANDERS
AttackGAUSSLIEOPTAGR-MM
P/RPRPRPRPR
MNIST1.01.01.01.01.01.01.01.0
FMNIST1.01.01.01.01.01.01.01.0
\n","
"],"text/plain":["Strategy FLANDERS \n","Attack GAUSS LIE OPT AGR-MM \n","P/R P R P R P R P R\n","MNIST 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n","FMNIST 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0"]},"execution_count":137,"metadata":{},"output_type":"execute_result"}],"source":["# Table 2\n","pr_60 = pr_table(all_datasets_df, 60)\n","pr_60"]},{"cell_type":"code","execution_count":138,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":257},"executionInfo":{"elapsed":308,"status":"ok","timestamp":1716376750779,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"05a0Gv5piS2v","outputId":"0cf6555a-4cc8-4aa5-f9b6-c8286afa130a"},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
StrategyFLANDERS
AttackGAUSSLIEOPTAGR-MM
P/RPRPRPRPR
MNIST1.01.01.01.01.01.01.01.0
FMNIST1.01.01.01.01.01.01.01.0
\n","
"],"text/plain":["Strategy FLANDERS \n","Attack GAUSS LIE OPT AGR-MM \n","P/R P R P R P R P R\n","MNIST 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n","FMNIST 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0"]},"execution_count":138,"metadata":{},"output_type":"execute_result"}],"source":["# Table 3\n","pr_80 = pr_table(all_datasets_df, 80)\n","pr_80"]},{"cell_type":"markdown","metadata":{"id":"bN7dTn2u0r6K"},"source":["# Plots"]},{"cell_type":"markdown","metadata":{"id":"xZ0wiadBsVUh"},"source":["## Accuracy over rounds"]},{"cell_type":"code","execution_count":139,"metadata":{"id":"LQ_uYJCtjdJS"},"outputs":[],"source":["df_mnist_acc_flanders = all_datasets_df[(all_datasets_df['strategy']=='FLANDERS') & (all_datasets_df['num_malicious']==80) & (all_datasets_df['dataset_name']=='MNIST') & (all_datasets_df['aggregate_fn']=='MultiKrum')]\n","df_mnist_acc_fedavg = all_datasets_df[(all_datasets_df['strategy']=='FedAvg') & (all_datasets_df['num_malicious']==80) & (all_datasets_df['dataset_name']=='MNIST')]\n","df_no_attack = all_datasets_df[(all_datasets_df['strategy']=='FedAvg') & (all_datasets_df['num_malicious']==0) & (all_datasets_df['dataset_name']=='MNIST')]"]},{"cell_type":"code","execution_count":140,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":408},"executionInfo":{"elapsed":651,"status":"ok","timestamp":1714668544212,"user":{"displayName":"Edoardo Gabrielli","userId":"12318890431187689267"},"user_tz":-120},"id":"NDkD_Qnd1iT7","outputId":"928cfa90-04a2-4624-9825-adb0d124aaf8"},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["# Figure 3\n","num_plots = 2\n","plt.style.use('default')\n","fig, axs = plt.subplots(1, num_plots, figsize=(11, 4))\n","\n","data = [df_mnist_acc_flanders, df_mnist_acc_fedavg]\n","\n","for i in range(num_plots):\n"," if i == 0:\n"," acc = df_no_attack[df_no_attack[\"attack_fn\"]=='GAUSS']['accuracy'].to_list()\n"," axs[i].plot(acc, label=\"No Attack\", linestyle='--', color='slategray')\n"," for attack in ['GAUSS', 'LIE', 'OPT', 'AGR-MM']:\n"," acc = data[i][data[i]['attack_fn']==attack]['accuracy'].to_list()\n"," x = [i for i in range(len(data))]\n"," axs[i].plot(acc, label=attack)\n"," axs[i].set_ylim((0,1.0))\n"," axs[i].set_xlabel('Round', fontsize=16)\n"," axs[i].set_ylabel('Accuracy', fontsize=16)\n"," axs[i].legend(prop={'size': 12})\n"," axs[i].tick_params(axis='both', which='major', labelsize=16)\n"," axs[i].tick_params(axis='both', which='minor', labelsize=16)\n","\n","plt.show()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]}],"metadata":{"colab":{"authorship_tag":"ABX9TyODCHCYl18UhHkKwq6LlRvG","collapsed_sections":["P_3Z05w0wvNB","dE_uqUeuyl6M","9vVX6wsxT-rc","RctMDJMZyPq2","R9VNz7Cv9RHn","6V4padUiYeac","pZ863s6JJbph","EJDtdXqLJX0H"],"provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.18"}},"nbformat":4,"nbformat_minor":0} diff --git a/baselines/flanders/pyproject.toml b/baselines/flanders/pyproject.toml new file mode 100644 index 000000000000..416247f9c7bb --- /dev/null +++ b/baselines/flanders/pyproject.toml @@ -0,0 +1,151 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.masonry.api" + +[tool.poetry] +name = "flanders" +version = "1.0.0" +description = "FLANDERS" +license = "Apache-2.0" +authors = 
["Edoardo Gabrielli "] +readme = "README.md" +homepage = "https://flower.dev" +repository = "https://github.com/adap/flower" +documentation = "https://flower.dev" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[tool.poetry.dependencies] +python = ">=3.10, <3.12.0" +hydra-core = "1.3.2" # don't change this +flwr = {extras = ["simulation"], version = "1.6.0" } +torch = [ + { platform = "darwin", version = "2.1.1" }, + { platform = "linux", url = "https://download.pytorch.org/whl/cu118/torch-2.1.1%2Bcu118-cp310-cp310-linux_x86_64.whl" } + ] +torchvision = [ + { platform = "darwin", version = "0.16.1"}, + { platform = "linux", url = "https://download.pytorch.org/whl/cu118/torchvision-0.16.1%2Bcu118-cp310-cp310-linux_x86_64.whl" } + ] +pandas = "^2.1.3" +scikit-learn = "1.3.2" +ipykernel = "^6.27.1" +natsort = "^8.4.0" +seaborn = "^0.13.0" + +[tool.poetry.dev-dependencies] +isort = "==5.11.5" +black = "==23.1.0" +docformatter = "==1.5.1" +mypy = "==1.4.1" +pylint = "==2.8.2" +flake8 = "==3.9.2" +pytest = "==6.2.4" +pytest-watch = "==4.2.0" +ruff = "==0.0.272" 
+types-requests = "==2.27.7" +virtualenv = "20.21.0" + +[tool.isort] +line_length = 88 +indent = " " +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true + +[tool.black] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311"] + +[tool.pytest.ini_options] +minversion = "6.2" +addopts = "-qq" +testpaths = [ + "flwr_baselines", +] + +[tool.mypy] +ignore_missing_imports = true +strict = false +plugins = "numpy.typing.mypy_plugin" + +[tool.pylint."MESSAGES CONTROL"] +disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +good-names = "i,j,k,_,x,y,X,Y" +signature-mutators="hydra.main.main" + +[tool.pylint.typecheck] +generated-members="numpy.*, torch.*, tensorflow.*" + +[[tool.mypy.overrides]] +module = [ + "importlib.metadata.*", + "importlib_metadata.*", +] +follow_imports = "skip" +follow_imports_for_stubs = true +disallow_untyped_calls = false + +[[tool.mypy.overrides]] +module = "torch.*" +follow_imports = "skip" +follow_imports_for_stubs = true + +[tool.docformatter] +wrap-summaries = 88 +wrap-descriptions = 88 + +[tool.ruff] +target-version = "py38" +line-length = 88 +select = ["D", "E", "F", "W", "B", "ISC", "C4"] +fixable = ["D", "E", "F", "W", "B", "ISC", "C4"] +ignore = ["B024", "B027"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "proto", +] + +[tool.ruff.pydocstyle] +convention = "numpy" diff --git a/baselines/flanders/run.sh b/baselines/flanders/run.sh new file mode 100644 index 000000000000..435c358c4ee7 --- /dev/null +++ b/baselines/flanders/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +python -m flanders.main --multirun server.num_rounds=50 dataset=mnist strategy=flanders aggregate_fn=fedavg,trimmedmean,fedmedian,krum,bulyan server.pool_size=100 
server.num_malicious=0,20,60,80 server.attack_fn=gaussian,lie,fang,minmax server.warmup_rounds=2 client_resources.num_cpus=0.1 client_resources.num_gpus=0.1 + +python -m flanders.main --multirun server.num_rounds=50 dataset=mnist strategy=fedavg,trimmedmean,fedmedian,krum,bulyan server.pool_size=100 server.num_malicious=0,20,60,80 server.attack_fn=gaussian,lie,fang,minmax server.warmup_rounds=2 client_resources.num_cpus=0.1 client_resources.num_gpus=0.1 + +python -m flanders.main --multirun server.num_rounds=50 dataset=fmnist strategy=flanders aggregate_fn=fedavg,trimmedmean,fedmedian,krum,bulyan server.pool_size=100 server.num_malicious=0,20,60,80 server.attack_fn=gaussian,lie,fang,minmax server.warmup_rounds=2 client_resources.num_cpus=0.1 client_resources.num_gpus=0.1 + +python -m flanders.main --multirun server.num_rounds=50 dataset=fmnist strategy=fedavg,trimmedmean,fedmedian,krum,bulyan server.pool_size=100 server.num_malicious=0,20,60,80 server.attack_fn=gaussian,lie,fang,minmax server.warmup_rounds=2 client_resources.num_cpus=0.1 client_resources.num_gpus=0.1 \ No newline at end of file