diff --git a/baselines/heterofl/LICENSE b/baselines/heterofl/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/baselines/heterofl/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/baselines/heterofl/README.md b/baselines/heterofl/README.md new file mode 100644 index 000000000000..6e9c32077e9b --- /dev/null +++ b/baselines/heterofl/README.md @@ -0,0 +1,200 @@ +--- +title: "HeteroFL: Computation And Communication Efficient Federated Learning For Heterogeneous Clients" +url: https://openreview.net/forum?id=TNkPBBYFkXg +labels: [system heterogeneity, image classification] +dataset: [MNIST, CIFAR-10] +--- + +# HeteroFL: Computation And Communication Efficient Federated Learning For Heterogeneous Clients + +**Paper:** [openreview.net/forum?id=TNkPBBYFkXg](https://openreview.net/forum?id=TNkPBBYFkXg) + +**Authors:** Enmao Diao, Jie Ding, Vahid Tarokh + +**Abstract:** Federated Learning (FL) is a method of training machine learning models on private data distributed over a large number of possibly heterogeneous clients such as mobile phones and IoT devices. In this work, we propose a new federated learning framework named HeteroFL to address heterogeneous clients equipped with very different computation and communication capabilities. Our solution can enable the training of heterogeneous local models with varying computation complexities and still produce a single global inference model. For the first time, our method challenges the underlying assumption of existing work that local models have to share the same architecture as the global model. We demonstrate several strategies to enhance FL training and conduct extensive empirical evaluations, including five computation complexity levels of three model architecture on three datasets. We show that adaptively distributing subnetworks according to clients’ capabilities is both computation and communication efficient. + + +## About this baseline + +**What’s implemented:** The code in this directory is an implementation of HeteroFL in PyTorch using Flower. The code incorporates references from the authors' implementation. Implementation of custom model split and aggregation as suggested by [@negedng](https://github.com/negedng), is available [here](https://github.com/msck72/heterofl_custom_aggregation). By modifying the configuration in the `base.yaml`, the results in the paper can be replicated, with both fixed and dynamic computational complexities among clients. + +**Key Terminology:** ++ *Model rate* defines the computational complexity of a client. Authors have defined five different computation complexity levels {a, b, c, d, e} with the hidden channel shrinkage ratio r = 0.5. + ++ *Model split mode* specifies whether the computational complexities of clients are fixed (throughout the experiment), or whether they are dynamic (change their mode_rate/computational-complexity every-round). + ++ *Model mode* determines the proportionality of clients with various computation complexity levels, for example, a4-b2-e4 determines at each round, proportion of clients with computational complexity level a = 4 / (4 + 2 + 4) * num_clients, similarly, proportion of clients with computational complexity level b = 2 / (4 + 2 + 4) * num_clients and so on. + +**Implementation Insights:** +*ModelRateManager* manages the model rate of client in simulation, which changes the model rate based on the model mode of the setup and *ClientManagerHeterofl* keeps track of model rates of the clients, so configure fit knows which/how-much subset of the model that needs to be sent to the client. + +**Datasets:** The code utilized benchmark MNIST and CIFAR-10 datasets from Pytorch's torchvision for its experimentation. + +**Hardware Setup:** The experiments were run on Google colab pro with 50GB RAM and T4 TPU. For MNIST dataset & CNN model, it approximately takes 1.5 hours to complete 200 rounds while for CIFAR10 dataset & ResNet18 model it takes around 3-4 hours to complete 400 rounds (may vary based on the model-mode of the setup). + +**Contributors:** M S Chaitanya Kumar [(github.com/msck72)](https://github.com/msck72) + + +## Experimental Setup + +**Task:** Image Classification. +**Model:** This baseline uses two models: ++ Convolutional Neural Network(CNN) model is used for MNIST dataset. ++ PreResNet (preactivated ResNet) model is used for CIFAR10 dataset. + +These models use static batch normalization (sBN) and they incorporate a Scaler module following each convolutional layer. + +**Dataset:** This baseline includes MNIST and CIFAR10 datasets. + +| Dataset | #Classes | IID Partition | non-IID Partition | +| :---: | :---: | :---: | :---: | +| MNIST
CIFAR10 | 10| Distribution of equal number of data examples among n clients | Distribution of data examples such that each client has at most 2 (customizable) classes | + + +**Training Hyperparameters:** + +| Description | Data Setting | MNIST | CIFAR-10 | +| :---: | :---: | :---:| :---: | +Total Clients | both | 100 | 100 | +Clients Per Round | both | 100 | 100 +Local Epcohs | both | 5 | 5 +Num. ROunds | IID
non-IID| 200
400 | 400
800 +Optimizer | both | SGD | SGD +Momentum | both | 0.9 | 0.9 +Weight-decay | both | 5.00e-04 | 5.00e-04 +Learning Rate | both | 0.01 | 0.1 +Decay Schedule | IID
non-IID| [100]
[150, 250] | [200]
[300,500] +Hidden Layers | both | [64 , 128 , 256 , 512] | [64 , 128 , 256 , 512] + + +The hyperparameters of Fedavg baseline are available in [Liang et al (2020)](https://arxiv.org/abs/2001.01523). + +## Environment Setup + +To construct the Python environment, simply run: + +```bash +# Set python version +pyenv install 3.10.6 +pyenv local 3.10.6 + +# Tell poetry to use python 3.10 +poetry env use 3.10.6 + +# install the base Poetry environment +poetry install + +# activate the environment +poetry shell +``` + + +## Running the Experiments +To run HeteroFL experiments in poetry activated environment: +```bash +# The main experiment implemented in your baseline using default hyperparameters (that should be setup in the Hydra configs) +# should run (including dataset download and necessary partitioning) by executing the command: + +python -m heterofl.main # Which runs the heterofl with arguments availbale in heterfl/conf/base.yaml + +# We could override the settings that were specified in base.yaml using the command-line-arguments +# Here's an example for changing the dataset name, non-iid and model +python -m heterofl.main dataset.dataset_name='CIFAR10' dataset.iid=False model.model_name='resnet18' + +# Similarly, another example for changing num_rounds, model_split_mode, and model_mode +python -m heterofl.main num_rounds=400 control.model_split_mode='dynamic' control.model_mode='a1-b1' + +# Similarly, another example for changing num_rounds, model_split_mode, and model_mode +python -m heterofl.main num_rounds=400 control.model_split_mode='dynamic' control.model_mode='a1-b1' + +``` +To run FedAvg experiments: +```bash +python -m heterofl.main --config-name fedavg +# Similarly to the commands illustrated above, we can modify the default settings in the fedavg.yaml file. +``` + +## Expected Results + +```bash +# running the multirun for IID-MNIST with various model-modes using default config +python -m heterofl.main --multirun control.model_mode='a1','a1-e1','a1-b1-c1-d1-e1' + +# running the multirun for IID-CIFAR10 dataset with various model-modes by modifying default config +python -m heterofl.main --multirun control.model_mode='a1','a1-e1','a1-b1-c1-d1-e1' dataset.dataset_name='CIFAR10' model.model_name='resnet18' num_rounds=400 optim_scheduler.lr=0.1 strategy.milestones=[150, 250] + +# running the multirun for non-IID-MNIST with various model-modes by modifying default config +python -m heterofl.main --multirun control.model_mode='a1','a1-e1','a1-b1-c1-d1-e1' dataset.iid=False num_rounds=400 optim_scheduler.milestones=[200] + +# similarly, we can perform for various model-modes, datasets. But we cannot multirun with both non-iid and iid at once for reproducing the tables below, since the number of rounds and milestones for MultiStepLR are different for non-iid and iid. The tables below are the reproduced results of various multiruns. + +#To reproduce the fedavg results +#for MNIST dataset +python -m heterofl.main --config-name fedavg --multirun dataset.iid=True,False +# for CIFAR10 dataset +python -m heterofl.main --config-name fedavg --multirun num_rounds=1800 dataset.dataset_name='CIFAR10' dataset.iid=True,False dataset.batch_size.train=50 dataset.batch_size.test=128 model.model_name='CNNCifar' optim_scheduler.lr=0.1 +``` +
+ +Results of the combination of various computation complexity levels for **MNIST** dataset with **dynamic** scenario(where a client does not belong to a fixed computational complexity level): + +| Model | Ratio | Parameters | FLOPS | Space(MB) | IID-accuracy | non-IId local-acc | non-IID global-acc | +| :--: | :----: | :-----: | :-------: | :-------: | :----------: | :---------------: | :----------------: | +| a | 1 | 1556.874 K | 80.504 M | 5.939 | 99.47 | 99.82 | 98.87 | +| a-e | 0.502 | 781.734 K | 40.452 M | 2.982 | 99.49 | 99.86 | 98.9 | +| a-b-c-d-e | 0.267 | 415.807 K | 21.625 M | 1.586 | 99.23 | 99.84 | 98.5 | +| b | 1 | 391.37 K | 20.493 M | 1.493 | 99.54 | 99.81 | 98.81 | +| b-e | 0.508 | 198.982 K | 10.447 M | 0.759 | 99.48 | 99.87 | 98.98 | +| b-c-d-e | 0.334 | 130.54 K | 6.905 M | 0.498 | 99.34 | 99.81 | 98.73 | +| c | 1 | 98.922 K | 5.307 M | 0.377 | 99.37 | 99.64 | 97.14 | +| c-e | 0.628 | 62.098 K | 3.363 M | 0.237 | 99.16 | 99.72 | 97.68 | +| c-d-e | 0.441 | 43.5965 K | 2.375 M | 0.166 | 99.28 | 99.69 | 97.27 | +| d | 1 | 25.274 K | 1.418 M | 0.096 | 99.07 | 99.77 | 97.58 | +| d-e | 0.63 | 15.934 K | 0.909 M | 0.0608 | 99.12 | 99.65 | 97.33 | +| e | 1 | 6.594 K | 0.4005 M | 0.025 | 98.46 | 99.53 | 96.5 | +| FedAvg | 1 | 633.226 K | 1.264128 M | 2.416 | 97.85 | 97.76 | 97.74 | + + +
+ +Results of the combination of various computation complexity levels for **CIFAR10** dataset with **dynamic** scenario(where a client does not belong to a fixed computational complexity level): +> *The HeteroFL paper reports a model with 1.8M parameters for their FedAvg baseline. However, as stated by the paper authors, those results are borrowed from [Liang et al (2020)](https://arxiv.org/abs/2001.01523), which uses a small CNN with fewer parameters (~64K as shown in this table below). We believe the HeteroFL authors made a mistake when reporting the number of parameters. We borrowed the model from Liang et al (2020)'s [repo](https://github.com/pliang279/LG-FedAvg/blob/master/models/Nets.py). As in the paper, FedAvg was run for 1800 rounds.* + + +| Model | Ratio | Parameters | FLOPS | Space(MB) | IID-acc | non-IId local-acc
Final   Best| non-IID global-acc
Final    Best| +| :--: | :----: | :-----: | :-------: | :-------: | :----------: | :-----: | :------: | + a | 1 | 9622 K | 330.2 M | 36.705 | 90.83 | 89.04    92.41 | 48.72    59.29 | + a-e | 0.502 | 4830 K | 165.9 M | 18.426 | 89.98 | 87.98    91.25 | 50.16    57.66 | + a-b-c-d-e | 0.267 | 2565 K | 88.4 M | 9.785 | 87.46 | 89.75    91.19 | 46.96    55.6 | + b | 1 | 2409 K | 83.3 M | 9.189 | 88.59 | 89.31    92.07 | 49.85    60.79 | + b-e | 0.508 | 1224 K | 42.4 M | 4.667 | 89.23 | 90.93    92.3 | 55.46    61.98 | + b-c-d-e | 0.332 | 801 K | 27.9 M | 3.054 | 87.61 | 89.23    91.83 | 51.59    59.4 | + c | 1 | 604 K | 21.2 M | 2.303 | 85.74 | 89.83    91.75 | 44.03    58.26 | + c-e | 0.532 | 321 K | 11.4 M | 1.225 | 87.32 | 89.28    91.56 | 53.43    59.5 | + c-d-e | 0.438 | 265 K | 9.4 M | 1.010 | 85.59 | 91.48    92.05 | 58.26    61.79 | + d | 1 | 152 K | 5.5 M | 0.579 | 82.91 | 90.81    91.47 | 55.95    58.34 | + d-e | 0.626 | 95 K | 3.5 M | 0.363 | 82.77 | 88.79    90.13 | 48.49    54.18 | + e | 1 | 38 K | 1.5 M | 0.146 | 76.53 | 90.05    90.91 | 54.68    57.05 | +|FedAvg | 1 | 64 K| 1.3 M | 0.2446 | 70.65 | 53.12    58.6 | 52.93    58.47 | + + + diff --git a/baselines/heterofl/heterofl/__init__.py b/baselines/heterofl/heterofl/__init__.py new file mode 100644 index 000000000000..a5e567b59135 --- /dev/null +++ b/baselines/heterofl/heterofl/__init__.py @@ -0,0 +1 @@ +"""Template baseline package.""" diff --git a/baselines/heterofl/heterofl/client.py b/baselines/heterofl/heterofl/client.py new file mode 100644 index 000000000000..cf325cb7e85b --- /dev/null +++ b/baselines/heterofl/heterofl/client.py @@ -0,0 +1,133 @@ +"""Defines the MNIST Flower Client and a function to instantiate it.""" + +from typing import Callable, Dict, List, Optional, Tuple + +import flwr as fl +import torch +from flwr.common.typing import NDArrays + +from heterofl.models import create_model, get_parameters, set_parameters, test, train + +# from torch.utils.data import DataLoader + + +class FlowerNumPyClient(fl.client.NumPyClient): + """Standard Flower client for training.""" + + def __init__( + self, + # cid: str, + net: torch.nn.Module, + dataloader, + model_rate: Optional[float], + client_train_settings: Dict, + ): + # self.cid = cid + self.net = net + self.trainloader = dataloader["trainloader"] + self.label_split = dataloader["label_split"] + self.valloader = dataloader["valloader"] + self.model_rate = model_rate + self.client_train_settings = client_train_settings + self.client_train_settings["device"] = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" + ) + # print( + # "Client_with model rate = {} , cid of client = {}".format( + # self.model_rate, self.cid + # ) + # ) + + def get_parameters(self, config) -> NDArrays: + """Return the parameters of the current net.""" + # print(f"[Client {self.cid}] get_parameters") + return get_parameters(self.net) + + def fit(self, parameters, config) -> Tuple[NDArrays, int, Dict]: + """Implement distributed fit function for a given client.""" + # print(f"cid = {self.cid}") + set_parameters(self.net, parameters) + if "lr" in config: + self.client_train_settings["lr"] = config["lr"] + train( + self.net, + self.trainloader, + self.label_split, + self.client_train_settings, + ) + return get_parameters(self.net), len(self.trainloader), {} + + def evaluate(self, parameters, config) -> Tuple[float, int, Dict]: + """Implement distributed evaluation for a given client.""" + set_parameters(self.net, parameters) + loss, accuracy = test( + self.net, self.valloader, device=self.client_train_settings["device"] + ) + return float(loss), len(self.valloader), {"accuracy": float(accuracy)} + + +def gen_client_fn( + model_config: Dict, + client_to_model_rate_mapping: Optional[List[float]], + client_train_settings: Dict, + data_loaders, +) -> Callable[[str], FlowerNumPyClient]: # pylint: disable=too-many-arguments + """Generate the client function that creates the Flower Clients. + + Parameters + ---------- + model_config : Dict + Dict that contains all the information required to + create a model (data_shape , hidden_layers , classes_size...) + client_to_model_rate: List[float] + List tha contains model_rates of clients. + model_rate of client with cid i = client_to_model_rate_mapping[i] + client_train_settings : Dict + Dict that contains information regarding optimizer , lr , + momentum , device required by the client to train + trainloaders: List[DataLoader] + A list of DataLoaders, each pointing to the dataset training partition + belonging to a particular client. + label_split: torch.tensor + A Tensor of tensors that conatins the labels of the partitioned dataset. + label_split of client with cid i = label_split[i] + valloaders: List[DataLoader] + A list of DataLoaders, each pointing to the dataset validation partition + belonging to a particular client. + + Returns + ------- + Callable[[str], FlowerClient] + A tuple containing the client function that creates Flower Clients + """ + + def client_fn(cid: str) -> FlowerNumPyClient: + """Create a Flower client representing a single organization.""" + # Note: each client gets a different trainloader/valloader, so each client + # will train and evaluate on their own unique data + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + client_dataloader = { + "trainloader": data_loaders["trainloaders"][int(cid)], + "valloader": data_loaders["valloaders"][int(cid)], + "label_split": data_loaders["label_split"][int(cid)], + } + # trainloader = data_loaders["trainloaders"][int(cid)] + # valloader = data_loaders["valloaders"][int(cid)] + model_rate = None + if client_to_model_rate_mapping is not None: + model_rate = client_to_model_rate_mapping[int(cid)] + + return FlowerNumPyClient( + # cid=cid, + net=create_model( + model_config, + model_rate=model_rate, + device=device, + ), + dataloader=client_dataloader, + model_rate=model_rate, + client_train_settings=client_train_settings, + ) + + return client_fn diff --git a/baselines/heterofl/heterofl/client_manager_heterofl.py b/baselines/heterofl/heterofl/client_manager_heterofl.py new file mode 100644 index 000000000000..be5b2227a159 --- /dev/null +++ b/baselines/heterofl/heterofl/client_manager_heterofl.py @@ -0,0 +1,207 @@ +"""HeteroFL ClientManager.""" + +import random +import threading +from logging import INFO +from typing import Dict, List, Optional + +import flwr as fl +import torch +from flwr.common.logger import log +from flwr.server.client_proxy import ClientProxy +from flwr.server.criterion import Criterion + +# from heterofl.utils import ModelRateManager + + +class ClientManagerHeteroFL(fl.server.ClientManager): + """Provides a pool of available clients.""" + + def __init__( + self, + model_rate_manager=None, + clients_to_model_rate_mapping=None, + client_label_split: Optional[list[torch.tensor]] = None, + ) -> None: + super().__init__() + self.clients: Dict[str, ClientProxy] = {} + + self.is_simulation = False + if model_rate_manager is not None and clients_to_model_rate_mapping is not None: + self.is_simulation = True + + self.model_rate_manager = model_rate_manager + + # have a common array in simulation to access in the client_fn and server side + if self.is_simulation is True: + self.clients_to_model_rate_mapping = clients_to_model_rate_mapping + ans = self.model_rate_manager.create_model_rate_mapping( + len(clients_to_model_rate_mapping) + ) + # copy self.clients_to_model_rate_mapping , ans + for i, model_rate in enumerate(ans): + self.clients_to_model_rate_mapping[i] = model_rate + + # shall handle in case of not_simulation... + self.client_label_split = client_label_split + + self._cv = threading.Condition() + + def __len__(self) -> int: + """Return the length of clients Dict. + + Returns + ------- + len : int + Length of Dict (self.clients). + """ + return len(self.clients) + + def num_available(self) -> int: + """Return the number of available clients. + + Returns + ------- + num_available : int + The number of currently available clients. + """ + return len(self) + + def wait_for(self, num_clients: int, timeout: int = 86400) -> bool: + """Wait until at least `num_clients` are available. + + Blocks until the requested number of clients is available or until a + timeout is reached. Current timeout default: 1 day. + + Parameters + ---------- + num_clients : int + The number of clients to wait for. + timeout : int + The time in seconds to wait for, defaults to 86400 (24h). + + Returns + ------- + success : bool + """ + with self._cv: + return self._cv.wait_for( + lambda: len(self.clients) >= num_clients, timeout=timeout + ) + + def register(self, client: ClientProxy) -> bool: + """Register Flower ClientProxy instance. + + Parameters + ---------- + client : flwr.server.client_proxy.ClientProxy + + Returns + ------- + success : bool + Indicating if registration was successful. False if ClientProxy is + already registered or can not be registered for any reason. + """ + if client.cid in self.clients: + return False + + self.clients[client.cid] = client + + # in case of not a simulation, this type of method can be used + # if self.is_simulation is False: + # prop = client.get_properties(None, timeout=86400) + # self.clients_to_model_rate_mapping[int(client.cid)] = prop["model_rate"] + # self.client_label_split[int(client.cid)] = prop["label_split"] + + with self._cv: + self._cv.notify_all() + + return True + + def unregister(self, client: ClientProxy) -> None: + """Unregister Flower ClientProxy instance. + + This method is idempotent. + + Parameters + ---------- + client : flwr.server.client_proxy.ClientProxy + """ + if client.cid in self.clients: + del self.clients[client.cid] + + with self._cv: + self._cv.notify_all() + + def all(self) -> Dict[str, ClientProxy]: + """Return all available clients.""" + return self.clients + + def get_client_to_model_mapping(self, cid) -> float: + """Return model rate of client with cid.""" + return self.clients_to_model_rate_mapping[int(cid)] + + def get_all_clients_to_model_mapping(self) -> List[float]: + """Return all available clients to model rate mapping.""" + return self.clients_to_model_rate_mapping.copy() + + def update(self, server_round: int) -> None: + """Update the client to model rate mapping.""" + if self.is_simulation is True: + if ( + server_round == 1 and self.model_rate_manager.model_split_mode == "fix" + ) or (self.model_rate_manager.model_split_mode == "dynamic"): + ans = self.model_rate_manager.create_model_rate_mapping( + self.num_available() + ) + # copy self.clients_to_model_rate_mapping , ans + for i, model_rate in enumerate(ans): + self.clients_to_model_rate_mapping[i] = model_rate + print( + "clients to model rate mapping ", self.clients_to_model_rate_mapping + ) + return + + # to be handled in case of not a simulation, i.e. to get the properties + # again from the clients as they can change the model_rate + # for i in range(self.num_available): + # # need to test this , accumilates the + # # changing model rate of the client + # self.clients_to_model_rate_mapping[i] = + # self.clients[str(i)].get_properties['model_rate'] + # return + + def sample( + self, + num_clients: int, + min_num_clients: Optional[int] = None, + criterion: Optional[Criterion] = None, + ) -> List[ClientProxy]: + """Sample a number of Flower ClientProxy instances.""" + # Block until at least num_clients are connected. + if min_num_clients is None: + min_num_clients = num_clients + self.wait_for(min_num_clients) + # Sample clients which meet the criterion + available_cids = list(self.clients) + if criterion is not None: + available_cids = [ + cid for cid in available_cids if criterion.select(self.clients[cid]) + ] + + if num_clients > len(available_cids): + log( + INFO, + "Sampling failed: number of available clients" + " (%s) is less than number of requested clients (%s).", + len(available_cids), + num_clients, + ) + return [] + + random_indices = torch.randperm(len(available_cids))[:num_clients] + # Use the random indices to select clients + sampled_cids = [available_cids[i] for i in random_indices] + sampled_cids = random.sample(available_cids, num_clients) + print(f"Sampled CIDS = {sampled_cids}") + return [self.clients[cid] for cid in sampled_cids] diff --git a/baselines/heterofl/heterofl/conf/base.yaml b/baselines/heterofl/heterofl/conf/base.yaml new file mode 100644 index 000000000000..42edf419cc38 --- /dev/null +++ b/baselines/heterofl/heterofl/conf/base.yaml @@ -0,0 +1,47 @@ +num_clients: 100 +num_epochs: 5 +num_rounds: 800 +seed: 0 +client_resources: + num_cpus: 1 + num_gpus: 0.08 + +control: + model_split_mode: 'dynamic' + model_mode: 'a1-b1-c1-d1-e1' + +dataset: + dataset_name: 'CIFAR10' + iid: False + shard_per_user : 2 # only used in case of non-iid (i.e. iid = false) + balance: false + batch_size: + train: 10 + test: 50 + shuffle: + train: true + test: false + + +model: + model_name: resnet18 # use 'conv' for MNIST + hidden_layers: [64 , 128 , 256 , 512] + norm: bn + scale: 1 + mask: 1 + + +optim_scheduler: + optimizer: SGD + lr: 0.1 + momentum: 0.9 + weight_decay: 5.00e-04 + scheduler: MultiStepLR + milestones: [300, 500] + +strategy: + _target_: heterofl.strategy.HeteroFL + fraction_fit: 0.1 + fraction_evaluate: 0.1 + min_fit_clients: 10 + min_evaluate_clients: 10 diff --git a/baselines/heterofl/heterofl/conf/fedavg.yaml b/baselines/heterofl/heterofl/conf/fedavg.yaml new file mode 100644 index 000000000000..d67d0950654a --- /dev/null +++ b/baselines/heterofl/heterofl/conf/fedavg.yaml @@ -0,0 +1,41 @@ +num_clients: 100 +num_epochs: 1 +num_rounds: 800 +seed: 0 +clip: False +enable_train_on_train_data_while_testing: False +client_resources: + num_cpus: 1 + num_gpus: 0.4 + +dataset: + dataset_name: 'MNIST' + iid: False + shard_per_user : 2 + balance: False + batch_size: + train: 10 + test: 10 + shuffle: + train: true + test: false + + +model: + model_name: MLP #use CNNCifar for CIFAR10 + +optim_scheduler: + optimizer: SGD + lr: 0.05 + lr_decay_rate: 1.0 + momentum: 0.5 + weight_decay: 0 + scheduler: MultiStepLR + milestones: [] + +strategy: + _target_: flwr.server.strategy.FedAvg + fraction_fit: 0.1 + fraction_evaluate: 0.1 + min_fit_clients: 10 + min_evaluate_clients: 10 diff --git a/baselines/heterofl/heterofl/dataset.py b/baselines/heterofl/heterofl/dataset.py new file mode 100644 index 000000000000..0e0f4b726842 --- /dev/null +++ b/baselines/heterofl/heterofl/dataset.py @@ -0,0 +1,83 @@ +"""Utilities for creation of DataLoaders for clients and server.""" + +from typing import List, Optional, Tuple + +import torch +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from heterofl.dataset_preparation import _partition_data + + +def load_datasets( # pylint: disable=too-many-arguments + strategy_name: str, + config: DictConfig, + num_clients: int, + seed: Optional[int] = 42, +) -> Tuple[ + DataLoader, List[DataLoader], List[torch.tensor], List[DataLoader], DataLoader +]: + """Create the dataloaders to be fed into the model. + + Parameters + ---------- + config: DictConfig + Parameterises the dataset partitioning process + num_clients : int + The number of clients that hold a part of the data + seed : int, optional + Used to set a fix seed to replicate experiments, by default 42 + + Returns + ------- + Tuple[DataLoader, DataLoader, DataLoader, DataLoader] + The entire trainset Dataloader for testing purposes, + The DataLoader for training, the DataLoader for validation, + the DataLoader for testing. + """ + print(f"Dataset partitioning config: {config}") + trainset, datasets, label_split, client_testsets, testset = _partition_data( + num_clients, + dataset_name=config.dataset_name, + strategy_name=strategy_name, + iid=config.iid, + dataset_division={ + "shard_per_user": config.shard_per_user, + "balance": config.balance, + }, + seed=seed, + ) + # Split each partition into train/val and create DataLoader + entire_trainloader = DataLoader( + trainset, batch_size=config.batch_size.train, shuffle=config.shuffle.train + ) + + trainloaders = [] + valloaders = [] + for dataset in datasets: + trainloaders.append( + DataLoader( + dataset, + batch_size=config.batch_size.train, + shuffle=config.shuffle.train, + ) + ) + + for client_testset in client_testsets: + valloaders.append( + DataLoader( + client_testset, + batch_size=config.batch_size.test, + shuffle=config.shuffle.test, + ) + ) + + return ( + entire_trainloader, + trainloaders, + label_split, + valloaders, + DataLoader( + testset, batch_size=config.batch_size.test, shuffle=config.shuffle.test + ), + ) diff --git a/baselines/heterofl/heterofl/dataset_preparation.py b/baselines/heterofl/heterofl/dataset_preparation.py new file mode 100644 index 000000000000..525e815e9e98 --- /dev/null +++ b/baselines/heterofl/heterofl/dataset_preparation.py @@ -0,0 +1,357 @@ +"""Functions for dataset download and processing.""" + +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch.utils.data import ConcatDataset, Dataset, Subset, random_split +from torchvision import transforms + +import heterofl.datasets as dt + + +def _download_data(dataset_name: str, strategy_name: str) -> Tuple[Dataset, Dataset]: + root = "./data/{}".format(dataset_name) + if dataset_name == "MNIST": + trainset = dt.MNIST( + root=root, + split="train", + subset="label", + transform=dt.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + testset = dt.MNIST( + root=root, + split="test", + subset="label", + transform=dt.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + elif dataset_name == "CIFAR10": + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + if strategy_name == "heterofl": + normalize = transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) + ) + trainset = dt.CIFAR10( + root=root, + split="train", + subset="label", + transform=dt.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ), + ) + testset = dt.CIFAR10( + root=root, + split="test", + subset="label", + transform=dt.Compose( + [ + transforms.ToTensor(), + normalize, + ] + ), + ) + else: + raise ValueError(f"{dataset_name} is not valid") + + return trainset, testset + + +# pylint: disable=too-many-arguments +def _partition_data( + num_clients: int, + dataset_name: str, + strategy_name: str, + iid: Optional[bool] = False, + dataset_division=None, + seed: Optional[int] = 42, +) -> Tuple[Dataset, List[Dataset], List[torch.tensor], List[Dataset], Dataset]: + trainset, testset = _download_data(dataset_name, strategy_name) + + if dataset_name in ("MNIST", "CIFAR10"): + classes_size = 10 + + if dataset_division["balance"]: + trainset = _balance_classes(trainset, seed) + + if iid: + datasets, label_split = iid_partition(trainset, num_clients, seed=seed) + client_testsets, _ = iid_partition(testset, num_clients, seed=seed) + else: + datasets, label_split = non_iid( + {"dataset": trainset, "classes_size": classes_size}, + num_clients, + dataset_division["shard_per_user"], + ) + client_testsets, _ = non_iid( + { + "dataset": testset, + "classes_size": classes_size, + }, + num_clients, + dataset_division["shard_per_user"], + label_split, + ) + + tensor_label_split = [] + for i in label_split: + tensor_label_split.append(torch.Tensor(i)) + label_split = tensor_label_split + + return trainset, datasets, label_split, client_testsets, testset + + +def iid_partition( + dataset: Dataset, num_clients: int, seed: Optional[int] = 42 +) -> Tuple[List[Dataset], List[torch.tensor]]: + """IID partition of dataset among clients.""" + partition_size = int(len(dataset) / num_clients) + lengths = [partition_size] * num_clients + + divided_dataset = random_split( + dataset, lengths, torch.Generator().manual_seed(seed) + ) + label_split = [] + for i in range(num_clients): + label_split.append( + torch.unique(torch.Tensor([target for _, target in divided_dataset[i]])) + ) + + return divided_dataset, label_split + + +def non_iid( + dataset_info, + num_clients: int, + shard_per_user: int, + label_split=None, + seed=42, +) -> Tuple[List[Dataset], List]: + """Non-IID partition of dataset among clients. + + Adopted from authors (of heterofl) implementation. + """ + data_split: Dict[int, List] = {i: [] for i in range(num_clients)} + + label_idx_split, shard_per_class = _split_dataset_targets_idx( + dataset_info["dataset"], + shard_per_user, + num_clients, + dataset_info["classes_size"], + ) + + if label_split is None: + label_split = list(range(dataset_info["classes_size"])) * shard_per_class + label_split = torch.tensor(label_split)[ + torch.randperm( + len(label_split), generator=torch.Generator().manual_seed(seed) + ) + ].tolist() + label_split = np.array(label_split).reshape((num_clients, -1)).tolist() + + for i, _ in enumerate(label_split): + label_split[i] = np.unique(label_split[i]).tolist() + + for i in range(num_clients): + for label_i in label_split[i]: + idx = torch.arange(len(label_idx_split[label_i]))[ + torch.randperm( + len(label_idx_split[label_i]), + generator=torch.Generator().manual_seed(seed), + )[0] + ].item() + data_split[i].extend(label_idx_split[label_i].pop(idx)) + + return ( + _get_dataset_from_idx(dataset_info["dataset"], data_split, num_clients), + label_split, + ) + + +def _split_dataset_targets_idx(dataset, shard_per_user, num_clients, classes_size): + label = np.array(dataset.target) if hasattr(dataset, "target") else dataset.targets + label_idx_split: Dict = {} + for i, _ in enumerate(label): + label_i = label[i].item() + if label_i not in label_idx_split: + label_idx_split[label_i] = [] + label_idx_split[label_i].append(i) + + shard_per_class = int(shard_per_user * num_clients / classes_size) + + for label_i in label_idx_split: + label_idx = label_idx_split[label_i] + num_leftover = len(label_idx) % shard_per_class + leftover = label_idx[-num_leftover:] if num_leftover > 0 else [] + new_label_idx = ( + np.array(label_idx[:-num_leftover]) + if num_leftover > 0 + else np.array(label_idx) + ) + new_label_idx = new_label_idx.reshape((shard_per_class, -1)).tolist() + + for i, leftover_label_idx in enumerate(leftover): + new_label_idx[i] = np.concatenate([new_label_idx[i], [leftover_label_idx]]) + label_idx_split[label_i] = new_label_idx + return label_idx_split, shard_per_class + + +def _get_dataset_from_idx(dataset, data_split, num_clients): + divided_dataset = [None for i in range(num_clients)] + for i in range(num_clients): + divided_dataset[i] = Subset(dataset, data_split[i]) + return divided_dataset + + +def _balance_classes( + trainset: Dataset, + seed: Optional[int] = 42, +) -> Dataset: + class_counts = np.bincount(trainset.target) + targets = torch.Tensor(trainset.target) + smallest = np.min(class_counts) + idxs = targets.argsort() + tmp = [Subset(trainset, idxs[: int(smallest)])] + tmp_targets = [targets[idxs[: int(smallest)]]] + for count in np.cumsum(class_counts): + tmp.append(Subset(trainset, idxs[int(count) : int(count + smallest)])) + tmp_targets.append(targets[idxs[int(count) : int(count + smallest)]]) + unshuffled = ConcatDataset(tmp) + unshuffled_targets = torch.cat(tmp_targets) + shuffled_idxs = torch.randperm( + len(unshuffled), generator=torch.Generator().manual_seed(seed) + ) + shuffled = Subset(unshuffled, shuffled_idxs) + shuffled.targets = unshuffled_targets[shuffled_idxs] + + return shuffled + + +def _sort_by_class( + trainset: Dataset, +) -> Dataset: + class_counts = np.bincount(trainset.targets) + idxs = trainset.targets.argsort() # sort targets in ascending order + + tmp = [] # create subset of smallest class + tmp_targets = [] # same for targets + + start = 0 + for count in np.cumsum(class_counts): + tmp.append( + Subset(trainset, idxs[start : int(count + start)]) + ) # add rest of classes + tmp_targets.append(trainset.targets[idxs[start : int(count + start)]]) + start += count + sorted_dataset = ConcatDataset(tmp) # concat dataset + sorted_dataset.targets = torch.cat(tmp_targets) # concat targets + return sorted_dataset + + +# pylint: disable=too-many-locals, too-many-arguments +def _power_law_split( + sorted_trainset: Dataset, + num_partitions: int, + num_labels_per_partition: int = 2, + min_data_per_partition: int = 10, + mean: float = 0.0, + sigma: float = 2.0, +) -> Dataset: + """Partition the dataset following a power-law distribution. It follows the. + + implementation of Li et al 2020: https://arxiv.org/abs/1812.06127 with default + values set accordingly. + + Parameters + ---------- + sorted_trainset : Dataset + The training dataset sorted by label/class. + num_partitions: int + Number of partitions to create + num_labels_per_partition: int + Number of labels to have in each dataset partition. For + example if set to two, this means all training examples in + a given partition will be long to the same two classes. default 2 + min_data_per_partition: int + Minimum number of datapoints included in each partition, default 10 + mean: float + Mean value for LogNormal distribution to construct power-law, default 0.0 + sigma: float + Sigma value for LogNormal distribution to construct power-law, default 2.0 + + Returns + ------- + Dataset + The partitioned training dataset. + """ + targets = sorted_trainset.targets + full_idx = list(range(len(targets))) + + class_counts = np.bincount(sorted_trainset.targets) + labels_cs = np.cumsum(class_counts) + labels_cs = [0] + labels_cs[:-1].tolist() + + partitions_idx: List[List[int]] = [] + num_classes = len(np.bincount(targets)) + hist = np.zeros(num_classes, dtype=np.int32) + + # assign min_data_per_partition + min_data_per_class = int(min_data_per_partition / num_labels_per_partition) + for u_id in range(num_partitions): + partitions_idx.append([]) + for cls_idx in range(num_labels_per_partition): + # label for the u_id-th client + cls = (u_id + cls_idx) % num_classes + # record minimum data + indices = list( + full_idx[ + labels_cs[cls] + + hist[cls] : labels_cs[cls] + + hist[cls] + + min_data_per_class + ] + ) + partitions_idx[-1].extend(indices) + hist[cls] += min_data_per_class + + # add remaining images following power-law + probs = np.random.lognormal( + mean, + sigma, + (num_classes, int(num_partitions / num_classes), num_labels_per_partition), + ) + remaining_per_class = class_counts - hist + # obtain how many samples each partition should be assigned for each of the + # labels it contains + # pylint: disable=too-many-function-args + probs = ( + remaining_per_class.reshape(-1, 1, 1) + * probs + / np.sum(probs, (1, 2), keepdims=True) + ) + + for u_id in range(num_partitions): + for cls_idx in range(num_labels_per_partition): + cls = (u_id + cls_idx) % num_classes + count = int(probs[cls, u_id // num_classes, cls_idx]) + + # add count of specific class to partition + indices = full_idx[ + labels_cs[cls] + hist[cls] : labels_cs[cls] + hist[cls] + count + ] + partitions_idx[u_id].extend(indices) + hist[cls] += count + + # construct subsets + partitions = [Subset(sorted_trainset, p) for p in partitions_idx] + return partitions diff --git a/baselines/heterofl/heterofl/datasets/__init__.py b/baselines/heterofl/heterofl/datasets/__init__.py new file mode 100644 index 000000000000..91251db77302 --- /dev/null +++ b/baselines/heterofl/heterofl/datasets/__init__.py @@ -0,0 +1,9 @@ +"""Dataset module. + +The entire datasets module is adopted from authors implementation. +""" +from .cifar import CIFAR10 +from .mnist import MNIST +from .utils import Compose + +__all__ = ("MNIST", "CIFAR10", "Compose") diff --git a/baselines/heterofl/heterofl/datasets/cifar.py b/baselines/heterofl/heterofl/datasets/cifar.py new file mode 100644 index 000000000000..c75194bc8ee7 --- /dev/null +++ b/baselines/heterofl/heterofl/datasets/cifar.py @@ -0,0 +1,150 @@ +"""CIFAR10 dataset class, adopted from authors implementation.""" +import os +import pickle + +import anytree +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset + +from heterofl.datasets.utils import ( + download_url, + extract_file, + make_classes_counts, + make_flat_index, + make_tree, +) +from heterofl.utils import check_exists, load, makedir_exist_ok, save + + +# pylint: disable=too-many-instance-attributes +class CIFAR10(Dataset): + """CIFAR10 dataset.""" + + data_name = "CIFAR10" + file = [ + ( + "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz", + "c58f30108f718f92721af3b95e74349a", + ) + ] + + def __init__(self, root, split, subset, transform=None): + self.root = os.path.expanduser(root) + self.split = split + self.subset = subset + self.transform = transform + if not check_exists(self.processed_folder): + self.process() + self.img, self.target = load( + os.path.join(self.processed_folder, "{}.pt".format(self.split)) + ) + self.target = self.target[self.subset] + self.classes_counts = make_classes_counts(self.target) + self.classes_to_labels, self.classes_size = load( + os.path.join(self.processed_folder, "meta.pt") + ) + self.classes_to_labels, self.classes_size = ( + self.classes_to_labels[self.subset], + self.classes_size[self.subset], + ) + + def __getitem__(self, index): + """Get the item with index.""" + img, target = Image.fromarray(self.img[index]), torch.tensor(self.target[index]) + inp = {"img": img, self.subset: target} + if self.transform is not None: + inp = self.transform(inp) + return inp["img"], inp["label"] + + def __len__(self): + """Length of the dataset.""" + return len(self.img) + + @property + def processed_folder(self): + """Return path of processed folder.""" + return os.path.join(self.root, "processed") + + @property + def raw_folder(self): + """Return path of raw folder.""" + return os.path.join(self.root, "raw") + + def process(self): + """Save the dataset accordingly.""" + if not check_exists(self.raw_folder): + self.download() + train_set, test_set, meta = self.make_data() + save(train_set, os.path.join(self.processed_folder, "train.pt")) + save(test_set, os.path.join(self.processed_folder, "test.pt")) + save(meta, os.path.join(self.processed_folder, "meta.pt")) + + def download(self): + """Download dataset from the url.""" + makedir_exist_ok(self.raw_folder) + for url, md5 in self.file: + filename = os.path.basename(url) + download_url(url, self.raw_folder, filename, md5) + extract_file(os.path.join(self.raw_folder, filename)) + + def __repr__(self): + """Represent CIFAR10 as string.""" + fmt_str = ( + f"Dataset {self.__class__.__name__}\nSize: {self.__len__()}\n" + f"Root: {self.root}\nSplit: {self.split}\nSubset: {self.subset}\n" + f"Transforms: {self.transform.__repr__()}" + ) + return fmt_str + + def make_data(self): + """Make data.""" + train_filenames = [ + "data_batch_1", + "data_batch_2", + "data_batch_3", + "data_batch_4", + "data_batch_5", + ] + test_filenames = ["test_batch"] + train_img, train_label = _read_pickle_file( + os.path.join(self.raw_folder, "cifar-10-batches-py"), train_filenames + ) + test_img, test_label = _read_pickle_file( + os.path.join(self.raw_folder, "cifar-10-batches-py"), test_filenames + ) + train_target, test_target = {"label": train_label}, {"label": test_label} + with open( + os.path.join(self.raw_folder, "cifar-10-batches-py", "batches.meta"), "rb" + ) as fle: + data = pickle.load(fle, encoding="latin1") + classes = data["label_names"] + classes_to_labels = {"label": anytree.Node("U", index=[])} + for cls in classes: + make_tree(classes_to_labels["label"], [cls]) + classes_size = {"label": make_flat_index(classes_to_labels["label"])} + return ( + (train_img, train_target), + (test_img, test_target), + (classes_to_labels, classes_size), + ) + + +def _read_pickle_file(path, filenames): + img, label = [], [] + for filename in filenames: + file_path = os.path.join(path, filename) + with open(file_path, "rb") as file: + entry = pickle.load(file, encoding="latin1") + img.append(entry["data"]) + if "labels" in entry: + label.extend(entry["labels"]) + else: + label.extend(entry["fine_labels"]) + # label.extend(entry["labels"]) if "labels" in entry else label.extend( + # entry["fine_labels"] + # ) + img = np.vstack(img).reshape(-1, 3, 32, 32) + img = img.transpose((0, 2, 3, 1)) + return img, label diff --git a/baselines/heterofl/heterofl/datasets/mnist.py b/baselines/heterofl/heterofl/datasets/mnist.py new file mode 100644 index 000000000000..feae2ea987b4 --- /dev/null +++ b/baselines/heterofl/heterofl/datasets/mnist.py @@ -0,0 +1,167 @@ +"""MNIST dataset class, adopted from authors implementation.""" +import codecs +import os + +import anytree +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset + +from heterofl.datasets.utils import ( + download_url, + extract_file, + make_classes_counts, + make_flat_index, + make_tree, +) +from heterofl.utils import check_exists, load, makedir_exist_ok, save + + +# pylint: disable=too-many-instance-attributes +class MNIST(Dataset): + """MNIST dataset.""" + + data_name = "MNIST" + file = [ + ( + "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", + "f68b3c2dcbeaaa9fbdd348bbdeb94873", + ), + ( + "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", + "9fb629c4189551a2d022fa330f9573f3", + ), + ( + "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", + "d53e105ee54ea40749a09fcbcd1e9432", + ), + ( + "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", + "ec29112dd5afa0611ce80d1b7f02629c", + ), + ] + + def __init__(self, root, split, subset, transform=None): + self.root = os.path.expanduser(root) + self.split = split + self.subset = subset + self.transform = transform + if not check_exists(self.processed_folder): + self.process() + self.img, self.target = load( + os.path.join(self.processed_folder, "{}.pt".format(self.split)) + ) + self.target = self.target[self.subset] + self.classes_counts = make_classes_counts(self.target) + self.classes_to_labels, self.classes_size = load( + os.path.join(self.processed_folder, "meta.pt") + ) + self.classes_to_labels, self.classes_size = ( + self.classes_to_labels[self.subset], + self.classes_size[self.subset], + ) + + def __getitem__(self, index): + """Get the item with index.""" + img, target = Image.fromarray(self.img[index]), torch.tensor(self.target[index]) + inp = {"img": img, self.subset: target} + if self.transform is not None: + inp = self.transform(inp) + return inp["img"], inp["label"] + + def __len__(self): + """Length of the dataset.""" + return len(self.img) + + @property + def processed_folder(self): + """Return path of processed folder.""" + return os.path.join(self.root, "processed") + + @property + def raw_folder(self): + """Return path of raw folder.""" + return os.path.join(self.root, "raw") + + def process(self): + """Save the dataset accordingly.""" + if not check_exists(self.raw_folder): + self.download() + train_set, test_set, meta = self.make_data() + save(train_set, os.path.join(self.processed_folder, "train.pt")) + save(test_set, os.path.join(self.processed_folder, "test.pt")) + save(meta, os.path.join(self.processed_folder, "meta.pt")) + + def download(self): + """Download and save the dataset accordingly.""" + makedir_exist_ok(self.raw_folder) + for url, md5 in self.file: + filename = os.path.basename(url) + download_url(url, self.raw_folder, filename, md5) + extract_file(os.path.join(self.raw_folder, filename)) + + def __repr__(self): + """Represent CIFAR10 as string.""" + fmt_str = ( + f"Dataset {self.__class__.__name__}\nSize: {self.__len__()}\n" + f"Root: {self.root}\nSplit: {self.split}\nSubset: {self.subset}\n" + f"Transforms: {self.transform.__repr__()}" + ) + return fmt_str + + def make_data(self): + """Make data.""" + train_img = _read_image_file( + os.path.join(self.raw_folder, "train-images-idx3-ubyte") + ) + test_img = _read_image_file( + os.path.join(self.raw_folder, "t10k-images-idx3-ubyte") + ) + train_label = _read_label_file( + os.path.join(self.raw_folder, "train-labels-idx1-ubyte") + ) + test_label = _read_label_file( + os.path.join(self.raw_folder, "t10k-labels-idx1-ubyte") + ) + train_target, test_target = {"label": train_label}, {"label": test_label} + classes_to_labels = {"label": anytree.Node("U", index=[])} + classes = list(map(str, list(range(10)))) + for cls in classes: + make_tree(classes_to_labels["label"], [cls]) + classes_size = {"label": make_flat_index(classes_to_labels["label"])} + return ( + (train_img, train_target), + (test_img, test_target), + (classes_to_labels, classes_size), + ) + + +def _get_int(num): + return int(codecs.encode(num, "hex"), 16) + + +def _read_image_file(path): + with open(path, "rb") as file: + data = file.read() + assert _get_int(data[:4]) == 2051 + length = _get_int(data[4:8]) + num_rows = _get_int(data[8:12]) + num_cols = _get_int(data[12:16]) + parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape( + (length, num_rows, num_cols) + ) + return parsed + + +def _read_label_file(path): + with open(path, "rb") as file: + data = file.read() + assert _get_int(data[:4]) == 2049 + length = _get_int(data[4:8]) + parsed = ( + np.frombuffer(data, dtype=np.uint8, offset=8) + .reshape(length) + .astype(np.int64) + ) + return parsed diff --git a/baselines/heterofl/heterofl/datasets/utils.py b/baselines/heterofl/heterofl/datasets/utils.py new file mode 100644 index 000000000000..6b71811ed50d --- /dev/null +++ b/baselines/heterofl/heterofl/datasets/utils.py @@ -0,0 +1,244 @@ +"""Contains utility functions required for datasests. + +Adopted from authors implementation. +""" +import glob +import gzip +import hashlib +import os +import tarfile +import zipfile +from collections import Counter + +import anytree +import numpy as np +from PIL import Image +from six.moves import urllib +from tqdm import tqdm + +from heterofl.utils import makedir_exist_ok + +IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif"] + + +def find_classes(drctry): + """Find the classes in a directory.""" + classes = [d.name for d in os.scandir(drctry) if d.is_dir()] + classes.sort() + classes_to_labels = {classes[i]: i for i in range(len(classes))} + return classes_to_labels + + +def pil_loader(path): + """Load image from path using PIL.""" + with open(path, "rb") as file: + img = Image.open(file) + return img.convert("RGB") + + +# def accimage_loader(path): +# """Load image from path using accimage_loader.""" +# import accimage + +# try: +# return accimage.Image(path) +# except IOError: +# return pil_loader(path) + + +def default_loader(path): + """Load image from path using default loader.""" + # if get_image_backend() == "accimage": + # return accimage_loader(path) + + return pil_loader(path) + + +def has_file_allowed_extension(filename, extensions): + """Check whether file possesses any of the extensions listed.""" + filename_lower = filename.lower() + return any(filename_lower.endswith(ext) for ext in extensions) + + +def make_classes_counts(label): + """Count number of classes.""" + label = np.array(label) + if label.ndim > 1: + label = label.sum(axis=tuple(range(1, label.ndim))) + classes_counts = Counter(label) + return classes_counts + + +def _make_bar_updater(pbar): + def bar_update(count, block_size, total_size): + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + + +def _calculate_md5(path, chunk_size=1024 * 1024): + md5 = hashlib.md5() + with open(path, "rb") as file: + for chunk in iter(lambda: file.read(chunk_size), b""): + md5.update(chunk) + return md5.hexdigest() + + +def _check_md5(path, md5, **kwargs): + return md5 == _calculate_md5(path, **kwargs) + + +def _check_integrity(path, md5=None): + if not os.path.isfile(path): + return False + if md5 is None: + return True + return _check_md5(path, md5) + + +def download_url(url, root, filename, md5): + """Download files from the url.""" + path = os.path.join(root, filename) + makedir_exist_ok(root) + if os.path.isfile(path) and _check_integrity(path, md5): + print("Using downloaded and verified file: " + path) + else: + try: + print("Downloading " + url + " to " + path) + urllib.request.urlretrieve( + url, path, reporthook=_make_bar_updater(tqdm(unit="B", unit_scale=True)) + ) + except OSError: + if url[:5] == "https": + url = url.replace("https:", "http:") + print( + "Failed download. Trying https -> http instead." + " Downloading " + url + " to " + path + ) + urllib.request.urlretrieve( + url, + path, + reporthook=_make_bar_updater(tqdm(unit="B", unit_scale=True)), + ) + if not _check_integrity(path, md5): + raise RuntimeError("Not valid downloaded file") + + +def extract_file(src, dest=None, delete=False): + """Extract the file.""" + print("Extracting {}".format(src)) + dest = os.path.dirname(src) if dest is None else dest + filename = os.path.basename(src) + if filename.endswith(".zip"): + with zipfile.ZipFile(src, "r") as zip_f: + zip_f.extractall(dest) + elif filename.endswith(".tar"): + with tarfile.open(src) as tar_f: + tar_f.extractall(dest) + elif filename.endswith(".tar.gz") or filename.endswith(".tgz"): + with tarfile.open(src, "r:gz") as tar_f: + tar_f.extractall(dest) + elif filename.endswith(".gz"): + with open(src.replace(".gz", ""), "wb") as out_f, gzip.GzipFile(src) as zip_f: + out_f.write(zip_f.read()) + if delete: + os.remove(src) + + +def make_data(root, extensions): + """Get all the files in the root directory that follows the given extensions.""" + path = [] + files = glob.glob("{}/**/*".format(root), recursive=True) + for file in files: + if has_file_allowed_extension(file, extensions): + path.append(os.path.normpath(file)) + return path + + +# pylint: disable=dangerous-default-value +def make_img(path, classes_to_labels, extensions=IMG_EXTENSIONS): + """Make image.""" + img, label = [], [] + classes = [] + leaf_nodes = classes_to_labels.leaves + for node in leaf_nodes: + classes.append(node.name) + for cls in sorted(classes): + folder = os.path.join(path, cls) + if not os.path.isdir(folder): + continue + for root, _, filenames in sorted(os.walk(folder)): + for filename in sorted(filenames): + if has_file_allowed_extension(filename, extensions): + cur_path = os.path.join(root, filename) + img.append(cur_path) + label.append( + anytree.find_by_attr(classes_to_labels, cls).flat_index + ) + return img, label + + +def make_tree(root, name, attribute=None): + """Create a tree of name.""" + if len(name) == 0: + return + if attribute is None: + attribute = {} + this_name = name[0] + next_name = name[1:] + this_attribute = {k: attribute[k][0] for k in attribute} + next_attribute = {k: attribute[k][1:] for k in attribute} + this_node = anytree.find_by_attr(root, this_name) + this_index = root.index + [len(root.children)] + if this_node is None: + this_node = anytree.Node( + this_name, parent=root, index=this_index, **this_attribute + ) + make_tree(this_node, next_name, next_attribute) + return + + +def make_flat_index(root, given=None): + """Make flat index for each leaf node in the tree.""" + if given: + classes_size = 0 + for node in anytree.PreOrderIter(root): + if len(node.children) == 0: + node.flat_index = given.index(node.name) + classes_size = ( + given.index(node.name) + 1 + if given.index(node.name) + 1 > classes_size + else classes_size + ) + else: + classes_size = 0 + for node in anytree.PreOrderIter(root): + if len(node.children) == 0: + node.flat_index = classes_size + classes_size += 1 + return classes_size + + +class Compose: + """Custom Compose class.""" + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, inp): + """Apply transforms when called.""" + for transform in self.transforms: + inp["img"] = transform(inp["img"]) + return inp + + def __repr__(self): + """Represent Compose as string.""" + format_string = self.__class__.__name__ + "(" + for transform in self.transforms: + format_string += "\n" + format_string += " {0}".format(transform) + format_string += "\n)" + return format_string diff --git a/baselines/heterofl/heterofl/main.py b/baselines/heterofl/heterofl/main.py new file mode 100644 index 000000000000..3973841cb60e --- /dev/null +++ b/baselines/heterofl/heterofl/main.py @@ -0,0 +1,204 @@ +"""Runs federated learning for given configuration in base.yaml.""" +import pickle +from pathlib import Path + +import flwr as fl +import hydra +import torch +from hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf + +from heterofl import client, models, server +from heterofl.client_manager_heterofl import ClientManagerHeteroFL +from heterofl.dataset import load_datasets +from heterofl.model_properties import get_model_properties +from heterofl.utils import ModelRateManager, get_global_model_rate, preprocess_input + + +# pylint: disable=too-many-locals,protected-access +@hydra.main(config_path="conf", config_name="base.yaml", version_base=None) +def main(cfg: DictConfig) -> None: + """Run the baseline. + + Parameters + ---------- + cfg : DictConfig + An omegaconf object that stores the hydra config. + """ + # print config structured as YAML + print(OmegaConf.to_yaml(cfg)) + torch.manual_seed(cfg.seed) + + data_loaders = {} + + ( + data_loaders["entire_trainloader"], + data_loaders["trainloaders"], + data_loaders["label_split"], + data_loaders["valloaders"], + data_loaders["testloader"], + ) = load_datasets( + "heterofl" if "heterofl" in cfg.strategy._target_ else "fedavg", + config=cfg.dataset, + num_clients=cfg.num_clients, + seed=cfg.seed, + ) + + model_config = preprocess_input(cfg.model, cfg.dataset) + + model_split_rate = None + model_mode = None + client_to_model_rate_mapping = None + model_rate_manager = None + history = None + + if "HeteroFL" in cfg.strategy._target_: + # send this array(client_model_rate_mapping) as + # an argument to client_manager and client + model_split_rate = {"a": 1, "b": 0.5, "c": 0.25, "d": 0.125, "e": 0.0625} + # model_split_mode = cfg.control.model_split_mode + model_mode = cfg.control.model_mode + + client_to_model_rate_mapping = [float(0) for _ in range(cfg.num_clients)] + model_rate_manager = ModelRateManager( + cfg.control.model_split_mode, model_split_rate, model_mode + ) + + model_config["global_model_rate"] = model_split_rate[ + get_global_model_rate(model_mode) + ] + + test_model = models.create_model( + model_config, + model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else None, + track=True, + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + ) + + get_model_properties( + model_config, + model_split_rate, + model_mode + "" if model_mode is not None else None, + data_loaders["entire_trainloader"], + cfg.dataset.batch_size.train, + ) + + # prepare function that will be used to spawn each client + client_train_settings = { + "epochs": cfg.num_epochs, + "optimizer": cfg.optim_scheduler.optimizer, + "lr": cfg.optim_scheduler.lr, + "momentum": cfg.optim_scheduler.momentum, + "weight_decay": cfg.optim_scheduler.weight_decay, + "scheduler": cfg.optim_scheduler.scheduler, + "milestones": cfg.optim_scheduler.milestones, + } + + if "clip" in cfg: + client_train_settings["clip"] = cfg.clip + + optim_scheduler_settings = { + "optimizer": cfg.optim_scheduler.optimizer, + "lr": cfg.optim_scheduler.lr, + "momentum": cfg.optim_scheduler.momentum, + "weight_decay": cfg.optim_scheduler.weight_decay, + "scheduler": cfg.optim_scheduler.scheduler, + "milestones": cfg.optim_scheduler.milestones, + } + + client_fn = client.gen_client_fn( + model_config=model_config, + client_to_model_rate_mapping=client_to_model_rate_mapping, + client_train_settings=client_train_settings, + data_loaders=data_loaders, + ) + + evaluate_fn = server.gen_evaluate_fn( + data_loaders, + torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + test_model, + models.create_model( + model_config, + model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else None, + track=False, + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + ) + .state_dict() + .keys(), + enable_train_on_train_data=cfg.enable_train_on_train_data_while_testing + if "enable_train_on_train_data_while_testing" in cfg + else True, + ) + client_resources = { + "num_cpus": cfg.client_resources.num_cpus, + "num_gpus": cfg.client_resources.num_gpus if torch.cuda.is_available() else 0, + } + + if "HeteroFL" in cfg.strategy._target_: + strategy_heterofl = instantiate( + cfg.strategy, + model_name=cfg.model.model_name, + net=models.create_model( + model_config, + model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else None, + device="cpu", + ), + optim_scheduler_settings=optim_scheduler_settings, + global_model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else 1.0, + evaluate_fn=evaluate_fn, + min_available_clients=cfg.num_clients, + ) + + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=cfg.num_clients, + config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), + client_resources=client_resources, + client_manager=ClientManagerHeteroFL( + model_rate_manager, + client_to_model_rate_mapping, + client_label_split=data_loaders["label_split"], + ), + strategy=strategy_heterofl, + ) + else: + strategy_fedavg = instantiate( + cfg.strategy, + # on_fit_config_fn=lambda server_round: { + # "lr": cfg.optim_scheduler.lr + # * pow(cfg.optim_scheduler.lr_decay_rate, server_round) + # }, + evaluate_fn=evaluate_fn, + min_available_clients=cfg.num_clients, + ) + + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=cfg.num_clients, + config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), + client_resources=client_resources, + strategy=strategy_fedavg, + ) + + # save the results + save_path = HydraConfig.get().runtime.output_dir + + # save the results as a python pickle + with open(str(Path(save_path) / "results.pkl"), "wb") as file_handle: + pickle.dump({"history": history}, file_handle, protocol=pickle.HIGHEST_PROTOCOL) + + # save the model + torch.save(test_model.state_dict(), str(Path(save_path) / "model.pth")) + + +if __name__ == "__main__": + main() diff --git a/baselines/heterofl/heterofl/model_properties.py b/baselines/heterofl/heterofl/model_properties.py new file mode 100644 index 000000000000..0739fe4fde22 --- /dev/null +++ b/baselines/heterofl/heterofl/model_properties.py @@ -0,0 +1,123 @@ +"""Determine number of model parameters, space it requires.""" +import numpy as np +import torch +import torch.nn as nn + +from heterofl.models import create_model + + +def get_model_properties( + model_config, model_split_rate, model_mode, data_loader, batch_size +): + """Calculate space occupied & number of parameters of model.""" + model_mode = model_mode.split("-") if model_mode is not None else None + # model = create_model(model_config, model_rate=model_split_rate(i[0])) + + total_flops = 0 + total_model_parameters = 0 + ttl_prcntg = 0 + if model_mode is None: + total_flops = _calculate_model_memory(create_model(model_config), data_loader) + total_model_parameters = _count_parameters(create_model(model_config)) + else: + for i in model_mode: + total_flops += _calculate_model_memory( + create_model(model_config, model_rate=model_split_rate[i[0]]), + data_loader, + ) * int(i[1]) + total_model_parameters += _count_parameters( + create_model(model_config, model_rate=model_split_rate[i[0]]) + ) * int(i[1]) + ttl_prcntg += int(i[1]) + + total_flops = total_flops / ttl_prcntg if ttl_prcntg != 0 else total_flops + total_flops /= batch_size + total_model_parameters = ( + total_model_parameters / ttl_prcntg + if ttl_prcntg != 0 + else total_model_parameters + ) + + space = total_model_parameters * 32.0 / 8 / (1024**2.0) + print("num_of_parameters = ", total_model_parameters / 1000, " K") + print("total_flops = ", total_flops / 1000000, " M") + print("space = ", space) + + return total_model_parameters, total_flops, space + + +def _calculate_model_memory(model, data_loader): + def register_hook(module): + def hook(module, inp, output): + # temp = _make_flops(module, inp, output) + # print(temp) + for _ in module.named_parameters(): + flops.append(_make_flops(module, inp, output)) + + if ( + not isinstance(module, nn.Sequential) + and not isinstance(module, nn.ModuleList) + and not isinstance(module, nn.ModuleDict) + and module != model + ): + hooks.append(module.register_forward_hook(hook)) + + hooks = [] + flops = [] + model.apply(register_hook) + + one_dl = next(iter(data_loader)) + input_dict = {"img": one_dl[0], "label": one_dl[1]} + with torch.no_grad(): + model(input_dict) + + for hook in hooks: + hook.remove() + + return sum(fl for fl in flops) + + +def _count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def _make_flops(module, inp, output): + if isinstance(inp, tuple): + return _make_flops(module, inp[0], output) + if isinstance(output, tuple): + return _make_flops(module, inp, output[0]) + flops = _compute_flops(module, inp, output) + return flops + + +def _compute_flops(module, inp, out): + flops = 0 + if isinstance(module, nn.Conv2d): + flops = _compute_conv2d_flops(module, inp, out) + elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)): + flops = np.prod(inp.shape).item() + if isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)) and module.affine: + flops *= 2 + elif isinstance(module, nn.Linear): + flops = np.prod(inp.size()[:-1]).item() * inp.size()[-1] * out.size()[-1] + # else: + # print(f"[Flops]: {type(module).__name__} is not supported!") + return flops + + +def _compute_conv2d_flops(module, inp, out): + batch_size = inp.size()[0] + in_c = inp.size()[1] + out_c, out_h, out_w = out.size()[1:] + groups = module.groups + filters_per_channel = out_c // groups + conv_per_position_flops = ( + module.kernel_size[0] * module.kernel_size[1] * in_c * filters_per_channel + ) + active_elements_count = batch_size * out_h * out_w + total_conv_flops = conv_per_position_flops * active_elements_count + bias_flops = 0 + if module.bias is not None: + bias_flops = out_c * active_elements_count + total_flops = total_conv_flops + bias_flops + return total_flops diff --git a/baselines/heterofl/heterofl/models.py b/baselines/heterofl/heterofl/models.py new file mode 100644 index 000000000000..9426ee8b2789 --- /dev/null +++ b/baselines/heterofl/heterofl/models.py @@ -0,0 +1,839 @@ +"""Conv & resnet18 model architecture, training, testing functions. + +Classes Conv, Block, Resnet18 are adopted from authors implementation. +""" +import copy +from typing import List, OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F +from flwr.common import parameters_to_ndarrays +from torch import nn + +from heterofl.utils import make_optimizer + + +class Conv(nn.Module): + """Convolutional Neural Network architecture with sBN.""" + + def __init__( + self, + model_config, + ): + super().__init__() + self.model_config = model_config + + blocks = [ + nn.Conv2d( + model_config["data_shape"][0], model_config["hidden_size"][0], 3, 1, 1 + ), + self._get_scale(), + self._get_norm(0), + nn.ReLU(inplace=True), + nn.MaxPool2d(2), + ] + for i in range(len(model_config["hidden_size"]) - 1): + blocks.extend( + [ + nn.Conv2d( + model_config["hidden_size"][i], + model_config["hidden_size"][i + 1], + 3, + 1, + 1, + ), + self._get_scale(), + self._get_norm(i + 1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2), + ] + ) + blocks = blocks[:-1] + blocks.extend( + [ + nn.AdaptiveAvgPool2d(1), + nn.Flatten(), + nn.Linear( + model_config["hidden_size"][-1], model_config["classes_size"] + ), + ] + ) + self.blocks = nn.Sequential(*blocks) + + def _get_norm(self, j: int): + """Return the relavant norm.""" + if self.model_config["norm"] == "bn": + norm = nn.BatchNorm2d( + self.model_config["hidden_size"][j], + momentum=None, + track_running_stats=self.model_config["track"], + ) + elif self.model_config["norm"] == "in": + norm = nn.GroupNorm( + self.model_config["hidden_size"][j], self.model_config["hidden_size"][j] + ) + elif self.model_config["norm"] == "ln": + norm = nn.GroupNorm(1, self.model_config["hidden_size"][j]) + elif self.model_config["norm"] == "gn": + norm = nn.GroupNorm(4, self.model_config["hidden_size"][j]) + elif self.model_config["norm"] == "none": + norm = nn.Identity() + else: + raise ValueError("Not valid norm") + + return norm + + def _get_scale(self): + """Return the relavant scaler.""" + if self.model_config["scale"]: + scaler = _Scaler(self.model_config["rate"]) + else: + scaler = nn.Identity() + return scaler + + def forward(self, input_dict): + """Forward pass of the Conv. + + Parameters + ---------- + input_dict : Dict + Conatins input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + # output = {"loss": torch.tensor(0, device=self.device, dtype=torch.float32)} + output = {} + out = self.blocks(input_dict["img"]) + if "label_split" in input_dict and self.model_config["mask"]: + label_mask = torch.zeros( + self.model_config["classes_size"], device=out.device + ) + label_mask[input_dict["label_split"]] = 1 + out = out.masked_fill(label_mask == 0, 0) + output["score"] = out + output["loss"] = F.cross_entropy(out, input_dict["label"], reduction="mean") + return output + + +def conv( + model_rate, + model_config, + device="cpu", +): + """Create the Conv model.""" + model_config["hidden_size"] = [ + int(np.ceil(model_rate * x)) for x in model_config["hidden_layers"] + ] + scaler_rate = model_rate / model_config["global_model_rate"] + model_config["rate"] = scaler_rate + model = Conv(model_config) + model.apply(_init_param) + return model.to(device) + + +class Block(nn.Module): + """Block.""" + + expansion = 1 + + def __init__(self, in_planes, planes, stride, model_config): + super().__init__() + if model_config["norm"] == "bn": + n_1 = nn.BatchNorm2d( + in_planes, momentum=None, track_running_stats=model_config["track"] + ) + n_2 = nn.BatchNorm2d( + planes, momentum=None, track_running_stats=model_config["track"] + ) + elif model_config["norm"] == "in": + n_1 = nn.GroupNorm(in_planes, in_planes) + n_2 = nn.GroupNorm(planes, planes) + elif model_config["norm"] == "ln": + n_1 = nn.GroupNorm(1, in_planes) + n_2 = nn.GroupNorm(1, planes) + elif model_config["norm"] == "gn": + n_1 = nn.GroupNorm(4, in_planes) + n_2 = nn.GroupNorm(4, planes) + elif model_config["norm"] == "none": + n_1 = nn.Identity() + n_2 = nn.Identity() + else: + raise ValueError("Not valid norm") + self.n_1 = n_1 + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.n_2 = n_2 + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False + ) + if model_config["scale"]: + self.scaler = _Scaler(model_config["rate"]) + else: + self.scaler = nn.Identity() + + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False, + ) + + def forward(self, x): + """Forward pass of the Block. + + Parameters + ---------- + x : Dict + Dict that contains Input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + out = F.relu(self.n_1(self.scaler(x))) + shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x + out = self.conv1(out) + out = self.conv2(F.relu(self.n_2(self.scaler(out)))) + out += shortcut + return out + + +# pylint: disable=too-many-instance-attributes +class ResNet(nn.Module): + """Implementation of a Residual Neural Network (ResNet) model with sBN.""" + + def __init__( + self, + model_config, + block, + num_blocks, + ): + self.model_config = model_config + super().__init__() + self.in_planes = model_config["hidden_size"][0] + self.conv1 = nn.Conv2d( + model_config["data_shape"][0], + model_config["hidden_size"][0], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + + self.layer1 = self._make_layer( + block, + model_config["hidden_size"][0], + num_blocks[0], + stride=1, + ) + self.layer2 = self._make_layer( + block, + model_config["hidden_size"][1], + num_blocks[1], + stride=2, + ) + self.layer3 = self._make_layer( + block, + model_config["hidden_size"][2], + num_blocks[2], + stride=2, + ) + self.layer4 = self._make_layer( + block, + model_config["hidden_size"][3], + num_blocks[3], + stride=2, + ) + + # self.layers = [layer1, layer2, layer3, layer4] + + if model_config["norm"] == "bn": + n_4 = nn.BatchNorm2d( + model_config["hidden_size"][3] * block.expansion, + momentum=None, + track_running_stats=model_config["track"], + ) + elif model_config["norm"] == "in": + n_4 = nn.GroupNorm( + model_config["hidden_size"][3] * block.expansion, + model_config["hidden_size"][3] * block.expansion, + ) + elif model_config["norm"] == "ln": + n_4 = nn.GroupNorm(1, model_config["hidden_size"][3] * block.expansion) + elif model_config["norm"] == "gn": + n_4 = nn.GroupNorm(4, model_config["hidden_size"][3] * block.expansion) + elif model_config["norm"] == "none": + n_4 = nn.Identity() + else: + raise ValueError("Not valid norm") + self.n_4 = n_4 + if model_config["scale"]: + self.scaler = _Scaler(model_config["rate"]) + else: + self.scaler = nn.Identity() + self.linear = nn.Linear( + model_config["hidden_size"][3] * block.expansion, + model_config["classes_size"], + ) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for strd in strides: + layers.append(block(self.in_planes, planes, strd, self.model_config.copy())) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, input_dict): + """Forward pass of the ResNet. + + Parameters + ---------- + input_dict : Dict + Dict that contains Input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + output = {} + x = input_dict["img"] + out = self.conv1(x) + # for layer in self.layers: + # out = layer(out) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.relu(self.n_4(self.scaler(out))) + out = F.adaptive_avg_pool2d(out, 1) + out = out.view(out.size(0), -1) + out = self.linear(out) + if "label_split" in input_dict and self.model_config["mask"]: + label_mask = torch.zeros( + self.model_config["classes_size"], device=out.device + ) + label_mask[input_dict["label_split"]] = 1 + out = out.masked_fill(label_mask == 0, 0) + output["score"] = out + output["loss"] = F.cross_entropy(output["score"], input_dict["label"]) + return output + + +def resnet18( + model_rate, + model_config, + device="cpu", +): + """Create the ResNet18 model.""" + model_config["hidden_size"] = [ + int(np.ceil(model_rate * x)) for x in model_config["hidden_layers"] + ] + scaler_rate = model_rate / model_config["global_model_rate"] + model_config["rate"] = scaler_rate + model = ResNet(model_config, block=Block, num_blocks=[1, 1, 1, 2]) + model.apply(_init_param) + return model.to(device) + + +class MLP(nn.Module): + """Multi Layer Perceptron.""" + + def __init__(self): + super().__init__() + self.layer_input = nn.Linear(784, 512) + self.relu = nn.ReLU() + self.dropout = nn.Dropout() + self.layer_hidden1 = nn.Linear(512, 256) + self.layer_hidden2 = nn.Linear(256, 256) + self.layer_hidden3 = nn.Linear(256, 128) + self.layer_out = nn.Linear(128, 10) + self.softmax = nn.Softmax(dim=1) + self.weight_keys = [ + ["layer_input.weight", "layer_input.bias"], + ["layer_hidden1.weight", "layer_hidden1.bias"], + ["layer_hidden2.weight", "layer_hidden2.bias"], + ["layer_hidden3.weight", "layer_hidden3.bias"], + ["layer_out.weight", "layer_out.bias"], + ] + + def forward(self, input_dict): + """Forward pass of the Conv. + + Parameters + ---------- + input_dict : Dict + Conatins input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + output = {} + x = input_dict["img"] + x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1]) + x = self.layer_input(x) + x = self.relu(x) + + x = self.layer_hidden1(x) + x = self.relu(x) + + x = self.layer_hidden2(x) + x = self.relu(x) + + x = self.layer_hidden3(x) + x = self.relu(x) + + x = self.layer_out(x) + out = self.softmax(x) + output["score"] = out + output["loss"] = F.cross_entropy(out, input_dict["label"], reduction="mean") + return output + + +class CNNCifar(nn.Module): + """Convolutional Neural Network architecture for cifar dataset.""" + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 100) + self.fc3 = nn.Linear(100, 10) + + self.weight_keys = [ + ["fc1.weight", "fc1.bias"], + ["fc2.weight", "fc2.bias"], + ["fc3.weight", "fc3.bias"], + ["conv2.weight", "conv2.bias"], + ["conv1.weight", "conv1.bias"], + ] + + def forward(self, input_dict): + """Forward pass of the Conv. + + Parameters + ---------- + input_dict : Dict + Conatins input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + output = {} + x = input_dict["img"] + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + out = F.log_softmax(x, dim=1) + output["score"] = out + output["loss"] = F.cross_entropy(out, input_dict["label"], reduction="mean") + return output + + +def create_model(model_config, model_rate=None, track=False, device="cpu"): + """Create the model based on the configuration given in hydra.""" + model = None + model_config = model_config.copy() + model_config["track"] = track + + if model_config["model"] == "MLP": + model = MLP() + model.to(device) + elif model_config["model"] == "CNNCifar": + model = CNNCifar() + model.to(device) + elif model_config["model"] == "conv": + model = conv(model_rate=model_rate, model_config=model_config, device=device) + elif model_config["model"] == "resnet18": + model = resnet18( + model_rate=model_rate, model_config=model_config, device=device + ) + return model + + +def _init_param(m_param): + if isinstance(m_param, (nn.BatchNorm2d, nn.InstanceNorm2d)): + m_param.weight.data.fill_(1) + m_param.bias.data.zero_() + elif isinstance(m_param, nn.Linear): + m_param.bias.data.zero_() + return m_param + + +class _Scaler(nn.Module): + def __init__(self, rate): + super().__init__() + self.rate = rate + + def forward(self, inp): + """Forward of Scalar nn.Module.""" + output = inp / self.rate if self.training else inp + return output + + +def get_parameters(net) -> List[np.ndarray]: + """Return the parameters of model as numpy.NDArrays.""" + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + +def set_parameters(net, parameters: List[np.ndarray]): + """Set the model parameters with given parameters.""" + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) + + +def train(model, train_loader, label_split, settings): + """Train a model with given settings. + + Parameters + ---------- + model : nn.Module + The neural network to train. + train_loader : DataLoader + The DataLoader containing the data to train the network on. + label_split : torch.tensor + Tensor containing the labels of the data. + settings: Dict + Dictionary conatining the information about eopchs, optimizer, + lr, momentum, weight_decay, device to train on. + """ + # criterion = torch.nn.CrossEntropyLoss() + optimizer = make_optimizer( + settings["optimizer"], + model.parameters(), + learning_rate=settings["lr"], + momentum=settings["momentum"], + weight_decay=settings["weight_decay"], + ) + + model.train() + for _ in range(settings["epochs"]): + for images, labels in train_loader: + input_dict = {} + input_dict["img"] = images.to(settings["device"]) + input_dict["label"] = labels.to(settings["device"]) + input_dict["label_split"] = label_split.type(torch.int).to( + settings["device"] + ) + optimizer.zero_grad() + output = model(input_dict) + output["loss"].backward() + if ("clip" not in settings) or ( + "clip" in settings and settings["clip"] is True + ): + torch.nn.utils.clip_grad_norm_(model.parameters(), 1) + optimizer.step() + + +def test(model, test_loader, label_split=None, device="cpu"): + """Evaluate the network on the test set. + + Parameters + ---------- + model : nn.Module + The neural network to test. + test_loader : DataLoader + The DataLoader containing the data to test the network on. + device : torch.device + The device on which the model should be tested, either 'cpu' or 'cuda'. + + Returns + ------- + Tuple[float, float] + The loss and the accuracy of the input model on the given data. + """ + model.eval() + size = len(test_loader.dataset) + num_batches = len(test_loader) + test_loss, correct = 0, 0 + + with torch.no_grad(): + model.train(False) + for images, labels in test_loader: + input_dict = {} + input_dict["img"] = images.to(device) + input_dict["label"] = labels.to(device) + if label_split is not None: + input_dict["label_split"] = label_split.type(torch.int).to(device) + output = model(input_dict) + test_loss += output["loss"].item() + correct += ( + (output["score"].argmax(1) == input_dict["label"]) + .type(torch.float) + .sum() + .item() + ) + + test_loss /= num_batches + correct /= size + return test_loss, correct + + +def param_model_rate_mapping( + model_name, parameters, clients_model_rate, global_model_rate=1 +): + """Map the model rate to subset of global parameters(as list of indices). + + Parameters + ---------- + model_name : str + The name of the neural network of global model. + parameters : Dict + state_dict of the global model. + client_model_rate : List[float] + List of model rates of active clients. + global_model_rate: float + Model rate of the global model. + + Returns + ------- + Dict + model rate to parameters indices relative to global model mapping. + """ + unique_client_model_rate = list(set(clients_model_rate)) + print(unique_client_model_rate) + + if "conv" in model_name: + idx = _mr_to_param_idx_conv( + parameters, unique_client_model_rate, global_model_rate + ) + elif "resnet" in model_name: + idx = _mr_to_param_idx_resnet18( + parameters, unique_client_model_rate, global_model_rate + ) + else: + raise ValueError("Not valid model name") + + # add model rate as key to the params calculated + param_idx_model_rate_mapping = OrderedDict() + for i, _ in enumerate(unique_client_model_rate): + param_idx_model_rate_mapping[unique_client_model_rate[i]] = idx[i] + + return param_idx_model_rate_mapping + + +def _mr_to_param_idx_conv(parameters, unique_client_model_rate, global_model_rate): + idx_i = [None for _ in range(len(unique_client_model_rate))] + idx = [OrderedDict() for _ in range(len(unique_client_model_rate))] + output_weight_name = [k for k in parameters.keys() if "weight" in k][-1] + output_bias_name = [k for k in parameters.keys() if "bias" in k][-1] + for k, val in parameters.items(): + parameter_type = k.split(".")[-1] + for index, _ in enumerate(unique_client_model_rate): + if "weight" in parameter_type or "bias" in parameter_type: + scaler_rate = unique_client_model_rate[index] / global_model_rate + _get_key_k_idx_conv( + idx, + idx_i, + { + "index": index, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + output_names={ + "output_weight_name": output_weight_name, + "output_bias_name": output_bias_name, + }, + scaler_rate=scaler_rate, + ) + else: + pass + return idx + + +def _get_key_k_idx_conv( + idx, + idx_i, + param_info, + output_names, + scaler_rate, +): + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + input_size = param_info["val"].size(1) + output_size = param_info["val"].size(0) + if idx_i[param_info["index"]] is None: + idx_i[param_info["index"]] = torch.arange( + input_size, device=param_info["val"].device + ) + input_idx_i_m = idx_i[param_info["index"]] + if param_info["k"] == output_names["output_weight_name"]: + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + ) + else: + local_output_size = int(np.ceil(output_size * (scaler_rate))) + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + )[:local_output_size] + idx[param_info["index"]][param_info["k"]] = output_idx_i_m, input_idx_i_m + idx_i[param_info["index"]] = output_idx_i_m + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + if param_info["k"] == output_names["output_bias_name"]: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + + +def _mr_to_param_idx_resnet18(parameters, unique_client_model_rate, global_model_rate): + idx_i = [None for _ in range(len(unique_client_model_rate))] + idx = [OrderedDict() for _ in range(len(unique_client_model_rate))] + for k, val in parameters.items(): + parameter_type = k.split(".")[-1] + for index, _ in enumerate(unique_client_model_rate): + if "weight" in parameter_type or "bias" in parameter_type: + scaler_rate = unique_client_model_rate[index] / global_model_rate + _get_key_k_idx_resnet18( + idx, + idx_i, + { + "index": index, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + scaler_rate=scaler_rate, + ) + else: + pass + return idx + + +def _get_key_k_idx_resnet18( + idx, + idx_i, + param_info, + scaler_rate, +): + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + input_size = param_info["val"].size(1) + output_size = param_info["val"].size(0) + if "conv1" in param_info["k"] or "conv2" in param_info["k"]: + if idx_i[param_info["index"]] is None: + idx_i[param_info["index"]] = torch.arange( + input_size, device=param_info["val"].device + ) + input_idx_i_m = idx_i[param_info["index"]] + local_output_size = int(np.ceil(output_size * (scaler_rate))) + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + )[:local_output_size] + idx_i[param_info["index"]] = output_idx_i_m + elif "shortcut" in param_info["k"]: + input_idx_i_m = idx[param_info["index"]][ + param_info["k"].replace("shortcut", "conv1") + ][1] + output_idx_i_m = idx_i[param_info["index"]] + elif "linear" in param_info["k"]: + input_idx_i_m = idx_i[param_info["index"]] + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + ) + else: + raise ValueError("Not valid k") + idx[param_info["index"]][param_info["k"]] = (output_idx_i_m, input_idx_i_m) + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + input_size = param_info["val"].size(0) + if "linear" in param_info["k"]: + input_idx_i_m = torch.arange(input_size, device=param_info["val"].device) + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + + +def param_idx_to_local_params(global_parameters, client_param_idx): + """Get the local parameters from the list of param indices. + + Parameters + ---------- + global_parameters : Dict + The state_dict of global model. + client_param_idx : List + Local parameters indices with respect to global model. + + Returns + ------- + Dict + state dict of local model. + """ + local_parameters = OrderedDict() + for k, val in global_parameters.items(): + parameter_type = k.split(".")[-1] + if "weight" in parameter_type or "bias" in parameter_type: + if "weight" in parameter_type: + if val.dim() > 1: + local_parameters[k] = copy.deepcopy( + val[torch.meshgrid(client_param_idx[k])] + ) + else: + local_parameters[k] = copy.deepcopy(val[client_param_idx[k]]) + else: + local_parameters[k] = copy.deepcopy(val[client_param_idx[k]]) + else: + local_parameters[k] = copy.deepcopy(val) + return local_parameters + + +def get_state_dict_from_param(model, parameters): + """Get the state dict from model & parameters as np.NDarrays. + + Parameters + ---------- + model : nn.Module + The neural network. + parameters : np.NDarray + Parameters of the model as np.NDarrays. + + Returns + ------- + Dict + state dict of model. + """ + # Load the parameters into the model + for param_tensor, param_ndarray in zip( + model.state_dict(), parameters_to_ndarrays(parameters) + ): + model.state_dict()[param_tensor].copy_(torch.from_numpy(param_ndarray)) + # Step 3: Obtain the state_dict of the model + state_dict = model.state_dict() + return state_dict diff --git a/baselines/heterofl/heterofl/server.py b/baselines/heterofl/heterofl/server.py new file mode 100644 index 000000000000..f82db0a59fff --- /dev/null +++ b/baselines/heterofl/heterofl/server.py @@ -0,0 +1,101 @@ +"""Flower Server.""" +import time +from collections import OrderedDict +from typing import Callable, Dict, Optional, Tuple + +import torch +from flwr.common.typing import NDArrays, Scalar +from torch import nn + +from heterofl.models import test +from heterofl.utils import save_model + + +def gen_evaluate_fn( + data_loaders, + device: torch.device, + model: nn.Module, + keys, + enable_train_on_train_data: bool, +) -> Callable[ + [int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]] +]: + """Generate the function for centralized evaluation. + + Parameters + ---------- + data_loaders : + A dictionary containing dataloaders for testing and + label split of each client. + device : torch.device + The device to test the model on. + model : + Model for testing. + keys : + keys of the model that it is trained on. + + Returns + ------- + Callable[ [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]] ] + The centralized evaluation function. + """ + intermediate_keys = keys + + def evaluate( + server_round: int, parameters_ndarrays: NDArrays, config: Dict[str, Scalar] + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + # pylint: disable=unused-argument + """Use the entire test set for evaluation.""" + # if server_round % 5 != 0 and server_round < 395: + # return 1, {} + + net = model + params_dict = zip(intermediate_keys, parameters_ndarrays) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=False) + net.to(device) + + if server_round % 100 == 0: + save_model(net, f"model_after_round_{server_round}.pth") + + if enable_train_on_train_data is True: + print("start of testing") + start_time = time.time() + with torch.no_grad(): + net.train(True) + for images, labels in data_loaders["entire_trainloader"]: + input_dict = {} + input_dict["img"] = images.to(device) + input_dict["label"] = labels.to(device) + net(input_dict) + print(f"end of stat, time taken = {time.time() - start_time}") + + local_metrics = {} + local_metrics["loss"] = 0 + local_metrics["accuracy"] = 0 + for i, clnt_tstldr in enumerate(data_loaders["valloaders"]): + client_test_res = test( + net, + clnt_tstldr, + data_loaders["label_split"][i].type(torch.int), + device=device, + ) + local_metrics["loss"] += client_test_res[0] + local_metrics["accuracy"] += client_test_res[1] + + global_metrics = {} + global_metrics["loss"], global_metrics["accuracy"] = test( + net, data_loaders["testloader"], device=device + ) + + # return statistics + print(f"global accuracy = {global_metrics['accuracy']}") + print(f"local_accuracy = {local_metrics['accuracy']}") + return global_metrics["loss"], { + "global_accuracy": global_metrics["accuracy"], + "local_loss": local_metrics["loss"], + "local_accuracy": local_metrics["accuracy"], + } + + return evaluate diff --git a/baselines/heterofl/heterofl/strategy.py b/baselines/heterofl/heterofl/strategy.py new file mode 100644 index 000000000000..70dbd19594df --- /dev/null +++ b/baselines/heterofl/heterofl/strategy.py @@ -0,0 +1,467 @@ +"""Flower strategy for HeteroFL.""" +import copy +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union + +import flwr as fl +import torch +from flwr.common import ( + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from torch import nn + +from heterofl.client_manager_heterofl import ClientManagerHeteroFL +from heterofl.models import ( + get_parameters, + get_state_dict_from_param, + param_idx_to_local_params, + param_model_rate_mapping, +) +from heterofl.utils import make_optimizer, make_scheduler + + +# pylint: disable=too-many-instance-attributes +class HeteroFL(fl.server.strategy.Strategy): + """HeteroFL strategy. + + Distribute subsets of a global model to clients according to their + + computational complexity and aggregate received models from clients. + """ + + # pylint: disable=too-many-arguments + def __init__( + self, + model_name: str, + net: nn.Module, + optim_scheduler_settings: Dict, + global_model_rate: float = 1.0, + evaluate_fn=None, + fraction_fit: float = 1.0, + fraction_evaluate: float = 1.0, + min_fit_clients: int = 2, + min_evaluate_clients: int = 2, + min_available_clients: int = 2, + ) -> None: + super().__init__() + self.fraction_fit = fraction_fit + self.fraction_evaluate = fraction_evaluate + self.min_fit_clients = min_fit_clients + self.min_evaluate_clients = min_evaluate_clients + self.min_available_clients = min_available_clients + self.evaluate_fn = evaluate_fn + # # created client_to_model_mapping + # self.client_to_model_rate_mapping: Dict[str, ClientProxy] = {} + + self.model_name = model_name + self.net = net + self.global_model_rate = global_model_rate + # info required for configure and aggregate + # to be filled in initialize + self.local_param_model_rate: OrderedDict = OrderedDict() + # to be filled in initialize + self.active_cl_labels: List[torch.tensor] = [] + # to be filled in configure + self.active_cl_mr: OrderedDict = OrderedDict() + # required for scheduling the lr + self.optimizer = make_optimizer( + optim_scheduler_settings["optimizer"], + self.net.parameters(), + learning_rate=optim_scheduler_settings["lr"], + momentum=optim_scheduler_settings["momentum"], + weight_decay=optim_scheduler_settings["weight_decay"], + ) + self.scheduler = make_scheduler( + optim_scheduler_settings["scheduler"], + self.optimizer, + milestones=optim_scheduler_settings["milestones"], + ) + + def __repr__(self) -> str: + """Return a string representation of the HeteroFL object.""" + return "HeteroFL" + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters.""" + # self.make_client_to_model_rate_mapping(client_manager) + # net = conv(model_rate = 1) + if not isinstance(client_manager, ClientManagerHeteroFL): + raise ValueError( + "Not valid client manager, use ClientManagerHeterFL instead" + ) + clnt_mngr_heterofl: ClientManagerHeteroFL = client_manager + + ndarrays = get_parameters(self.net) + self.local_param_model_rate = param_model_rate_mapping( + self.model_name, + self.net.state_dict(), + clnt_mngr_heterofl.get_all_clients_to_model_mapping(), + self.global_model_rate, + ) + + if clnt_mngr_heterofl.client_label_split is not None: + self.active_cl_labels = clnt_mngr_heterofl.client_label_split.copy() + + return fl.common.ndarrays_to_parameters(ndarrays) + + def configure_fit( + self, + server_round: int, + parameters: Parameters, + client_manager: ClientManager, + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + print(f"in configure fit , server round no. = {server_round}") + if not isinstance(client_manager, ClientManagerHeteroFL): + raise ValueError( + "Not valid client manager, use ClientManagerHeterFL instead" + ) + clnt_mngr_heterofl: ClientManagerHeteroFL = client_manager + # Sample clients + # no need to change this + clientts_selection_config = {} + ( + clientts_selection_config["sample_size"], + clientts_selection_config["min_num_clients"], + ) = self.num_fit_clients(clnt_mngr_heterofl.num_available()) + + # for sampling we pass the criterion to select the required clients + clients = clnt_mngr_heterofl.sample( + num_clients=clientts_selection_config["sample_size"], + min_num_clients=clientts_selection_config["min_num_clients"], + ) + + # update client model rate mapping + clnt_mngr_heterofl.update(server_round) + + global_parameters = get_state_dict_from_param(self.net, parameters) + + self.active_cl_mr = OrderedDict() + + # Create custom configs + fit_configurations = [] + learning_rate = self.optimizer.param_groups[0]["lr"] + print(f"lr = {learning_rate}") + for client in clients: + model_rate = clnt_mngr_heterofl.get_client_to_model_mapping(client.cid) + client_param_idx = self.local_param_model_rate[model_rate] + local_param = param_idx_to_local_params( + global_parameters=global_parameters, client_param_idx=client_param_idx + ) + self.active_cl_mr[client.cid] = model_rate + # local param are in the form of state_dict, + # so converting them only to values of tensors + local_param_fitres = [val.cpu() for val in local_param.values()] + fit_configurations.append( + ( + client, + FitIns( + ndarrays_to_parameters(local_param_fitres), + {"lr": learning_rate}, + ), + ) + ) + + self.scheduler.step() + return fit_configurations + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results using weighted average. + + Adopted from authors implementation. + """ + print("in aggregate fit") + gl_model = self.net.state_dict() + + param_idx = [] + for res in results: + param_idx.append( + copy.deepcopy( + self.local_param_model_rate[self.active_cl_mr[res[0].cid]] + ) + ) + + local_param_as_parameters = [fit_res.parameters for _, fit_res in results] + local_parameters_as_ndarrays = [ + parameters_to_ndarrays(local_param_as_parameters[i]) + for i in range(len(local_param_as_parameters)) + ] + local_parameters: List[OrderedDict] = [ + OrderedDict() for _ in range(len(local_param_as_parameters)) + ] + for i in range(len(results)): + j = 0 + for k, _ in gl_model.items(): + local_parameters[i][k] = local_parameters_as_ndarrays[i][j] + j += 1 + + if "conv" in self.model_name: + self._aggregate_conv(param_idx, local_parameters, results) + + elif "resnet" in self.model_name: + self._aggregate_resnet18(param_idx, local_parameters, results) + else: + raise ValueError("Not valid model name") + + return ndarrays_to_parameters([v for k, v in gl_model.items()]), {} + + def _aggregate_conv(self, param_idx, local_parameters, results): + gl_model = self.net.state_dict() + count = OrderedDict() + output_bias_name = [k for k in gl_model.keys() if "bias" in k][-1] + output_weight_name = [k for k in gl_model.keys() if "weight" in k][-1] + for k, val in gl_model.items(): + parameter_type = k.split(".")[-1] + count[k] = val.new_zeros(val.size(), dtype=torch.float32) + tmp_v = val.new_zeros(val.size(), dtype=torch.float32) + for clnt, _ in enumerate(local_parameters): + if "weight" in parameter_type or "bias" in parameter_type: + self._agg_layer_conv( + { + "cid": int(results[clnt][0].cid), + "param_idx": param_idx, + "local_parameters": local_parameters, + }, + { + "tmp_v": tmp_v, + "count": count, + }, + { + "clnt": clnt, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + { + "output_weight_name": output_weight_name, + "output_bias_name": output_bias_name, + }, + ) + else: + tmp_v += local_parameters[clnt][k] + count[k] += 1 + tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(count[k][count[k] > 0]) + val[count[k] > 0] = tmp_v[count[k] > 0].to(val.dtype) + + def _agg_layer_conv( + self, + clnt_params, + tmp_v_count, + param_info, + output_names, + ): + # pi = param_info + param_idx = clnt_params["param_idx"] + clnt = param_info["clnt"] + k = param_info["k"] + tmp_v = tmp_v_count["tmp_v"] + count = tmp_v_count["count"] + + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + if k == output_names["output_weight_name"]: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = list(param_idx[clnt][k]) + param_idx[clnt][k][0] = param_idx[clnt][k][0][label_split] + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k][label_split] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + else: + if k == output_names["output_bias_name"]: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = param_idx[clnt][k][label_split] + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k][ + label_split + ] + count[k][param_idx[clnt][k]] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + + def _aggregate_resnet18(self, param_idx, local_parameters, results): + gl_model = self.net.state_dict() + count = OrderedDict() + for k, val in gl_model.items(): + parameter_type = k.split(".")[-1] + count[k] = val.new_zeros(val.size(), dtype=torch.float32) + tmp_v = val.new_zeros(val.size(), dtype=torch.float32) + for clnt, _ in enumerate(local_parameters): + if "weight" in parameter_type or "bias" in parameter_type: + self._agg_layer_resnet18( + { + "cid": int(results[clnt][0].cid), + "param_idx": param_idx, + "local_parameters": local_parameters, + }, + tmp_v, + count, + { + "clnt": clnt, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + ) + else: + tmp_v += local_parameters[clnt][k] + count[k] += 1 + tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(count[k][count[k] > 0]) + val[count[k] > 0] = tmp_v[count[k] > 0].to(val.dtype) + + def _agg_layer_resnet18(self, clnt_params, tmp_v, count, param_info): + param_idx = clnt_params["param_idx"] + k = param_info["k"] + clnt = param_info["clnt"] + + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + if "linear" in k: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = list(param_idx[clnt][k]) + param_idx[clnt][k][0] = param_idx[clnt][k][0][label_split] + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k][label_split] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + else: + if "linear" in k: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = param_idx[clnt][k][label_split] + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k][ + label_split + ] + count[k][param_idx[clnt][k]] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + # if self.fraction_evaluate == 0.0: + # return [] + # config = {} + # evaluate_ins = EvaluateIns(parameters, config) + + # # Sample clients + # sample_size, min_num_clients = self.num_evaluation_clients( + # client_manager.num_available() + # ) + # clients = client_manager.sample( + # num_clients=sample_size, min_num_clients=min_num_clients + # ) + + # global_parameters = get_state_dict_from_param(self.net, parameters) + + # self.active_cl_mr = OrderedDict() + + # # Create custom configs + # evaluate_configurations = [] + # for idx, client in enumerate(clients): + # model_rate = client_manager.get_client_to_model_mapping(client.cid) + # client_param_idx = self.local_param_model_rate[model_rate] + # local_param = + # param_idx_to_local_params(global_parameters, client_param_idx) + # self.active_cl_mr[client.cid] = model_rate + # # local param are in the form of state_dict, + # # so converting them only to values of tensors + # local_param_fitres = [v.cpu() for v in local_param.values()] + # evaluate_configurations.append( + # (client, EvaluateIns(ndarrays_to_parameters(local_param_fitres), {})) + # ) + # return evaluate_configurations + + return [] + + # return self.configure_fit(server_round , parameters , client_manager) + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using weighted average.""" + # if not results: + # return None, {} + + # loss_aggregated = weighted_loss_avg( + # [ + # (evaluate_res.num_examples, evaluate_res.loss) + # for _, evaluate_res in results + # ] + # ) + + # accuracy_aggregated = 0 + # for cp, y in results: + # print(f"{cp.cid}-->{y.metrics['accuracy']}", end=" ") + # accuracy_aggregated += y.metrics["accuracy"] + # accuracy_aggregated /= len(results) + + # metrics_aggregated = {"accuracy": accuracy_aggregated} + # print(f"\npaneer lababdar {metrics_aggregated}") + # return loss_aggregated, metrics_aggregated + + return None, {} + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function.""" + if self.evaluate_fn is None: + # No evaluation function provided + return None + parameters_ndarrays = parameters_to_ndarrays(parameters) + eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {}) + if eval_res is None: + return None + loss, metrics = eval_res + return loss, metrics + + def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]: + """Return sample size and required number of clients.""" + num_clients = int(num_available_clients * self.fraction_fit) + return max(num_clients, self.min_fit_clients), self.min_available_clients + + def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]: + """Use a fraction of available clients for evaluation.""" + num_clients = int(num_available_clients * self.fraction_evaluate) + return max(num_clients, self.min_evaluate_clients), self.min_available_clients diff --git a/baselines/heterofl/heterofl/utils.py b/baselines/heterofl/heterofl/utils.py new file mode 100644 index 000000000000..3bcb7f3d8ea7 --- /dev/null +++ b/baselines/heterofl/heterofl/utils.py @@ -0,0 +1,218 @@ +"""Contains utility functions.""" +import errno +import os +from pathlib import Path + +import numpy as np +import torch +from hydra.core.hydra_config import HydraConfig + + +def preprocess_input(cfg_model, cfg_data): + """Preprocess the input to get input shape, other derivables. + + Parameters + ---------- + cfg_model : DictConfig + Retrieve model-related information from the base.yaml configuration in Hydra. + cfg_data : DictConfig + Retrieve data-related information required to construct the model. + + Returns + ------- + Dict + Dictionary contained derived information from config. + """ + model_config = {} + # if cfg_model.model_name == "conv": + # model_config["model_name"] = + # elif for others... + model_config["model"] = cfg_model.model_name + if cfg_data.dataset_name == "MNIST": + model_config["data_shape"] = [1, 28, 28] + model_config["classes_size"] = 10 + elif cfg_data.dataset_name == "CIFAR10": + model_config["data_shape"] = [3, 32, 32] + model_config["classes_size"] = 10 + + if "hidden_layers" in cfg_model: + model_config["hidden_layers"] = cfg_model.hidden_layers + if "norm" in cfg_model: + model_config["norm"] = cfg_model.norm + if "scale" in cfg_model: + model_config["scale"] = cfg_model.scale + if "mask" in cfg_model: + model_config["mask"] = cfg_model.mask + + return model_config + + +def make_optimizer(optimizer_name, parameters, learning_rate, weight_decay, momentum): + """Make the optimizer with given config. + + Parameters + ---------- + optimizer_name : str + Name of the optimizer. + parameters : Dict + Parameters of the model. + learning_rate: float + Learning rate of the optimizer. + weight_decay: float + weight_decay of the optimizer. + + Returns + ------- + torch.optim.Optimizer + Optimizer. + """ + optimizer = None + if optimizer_name == "SGD": + optimizer = torch.optim.SGD( + parameters, lr=learning_rate, momentum=momentum, weight_decay=weight_decay + ) + return optimizer + + +def make_scheduler(scheduler_name, optimizer, milestones): + """Make the scheduler with given config. + + Parameters + ---------- + scheduler_name : str + Name of the scheduler. + optimizer : torch.optim.Optimizer + Parameters of the model. + milestones: List[int] + List of epoch indices. Must be increasing. + + Returns + ------- + torch.optim.lr_scheduler.Scheduler + scheduler. + """ + scheduler = None + if scheduler_name == "MultiStepLR": + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones=milestones + ) + return scheduler + + +def get_global_model_rate(model_mode): + """Give the global model rate from string(cfg.control.model_mode) . + + Parameters + ---------- + model_mode : str + Contains the division of computational complexties among clients. + + Returns + ------- + str + global model computational complexity. + """ + model_mode = "" + model_mode + model_mode = model_mode.split("-")[0][0] + return model_mode + + +class ModelRateManager: + """Control the model rate of clients in case of simulation.""" + + def __init__(self, model_split_mode, model_split_rate, model_mode): + self.model_split_mode = model_split_mode + self.model_split_rate = model_split_rate + self.model_mode = model_mode + self.model_mode = self.model_mode.split("-") + + def create_model_rate_mapping(self, num_users): + """Change the client to model rate mapping accordingly.""" + client_model_rate = [] + + if self.model_split_mode == "fix": + mode_rate, proportion = [], [] + for comp_level_prop in self.model_mode: + mode_rate.append(self.model_split_rate[comp_level_prop[0]]) + proportion.append(int(comp_level_prop[1:])) + num_users_proportion = num_users // sum(proportion) + for i, comp_level in enumerate(mode_rate): + client_model_rate += np.repeat( + comp_level, num_users_proportion * proportion[i] + ).tolist() + client_model_rate = client_model_rate + [ + client_model_rate[-1] for _ in range(num_users - len(client_model_rate)) + ] + # return client_model_rate + + elif self.model_split_mode == "dynamic": + mode_rate, proportion = [], [] + + for comp_level_prop in self.model_mode: + mode_rate.append(self.model_split_rate[comp_level_prop[0]]) + proportion.append(int(comp_level_prop[1:])) + + proportion = (np.array(proportion) / sum(proportion)).tolist() + + rate_idx = torch.multinomial( + torch.tensor(proportion), num_samples=num_users, replacement=True + ).tolist() + client_model_rate = np.array(mode_rate)[rate_idx] + + # return client_model_rate + + else: + raise ValueError("Not valid model split mode") + + return client_model_rate + + +def save_model(model, path): + """To save the model in the given path.""" + # print('in save model') + current_path = HydraConfig.get().runtime.output_dir + model_save_path = Path(current_path) / path + torch.save(model.state_dict(), model_save_path) + + +# """ The following functions(check_exists, makedir_exit_ok, save, load) +# are adopted from authors (of heterofl) implementation.""" + + +def check_exists(path): + """Check if the given path exists.""" + return os.path.exists(path) + + +def makedir_exist_ok(path): + """Create a directory.""" + try: + os.makedirs(path) + except OSError as os_err: + if os_err.errno == errno.EEXIST: + pass + else: + raise + + +def save(inp, path, protocol=2, mode="torch"): + """Save the inp in a given path.""" + dirname = os.path.dirname(path) + makedir_exist_ok(dirname) + if mode == "torch": + torch.save(inp, path, pickle_protocol=protocol) + elif mode == "numpy": + np.save(path, inp, allow_pickle=True) + else: + raise ValueError("Not valid save mode") + + +# pylint: disable=no-else-return +def load(path, mode="torch"): + """Load the file from given path.""" + if mode == "torch": + return torch.load(path, map_location=lambda storage, loc: storage) + elif mode == "numpy": + return np.load(path, allow_pickle=True) + else: + raise ValueError("Not valid save mode") diff --git a/baselines/heterofl/pyproject.toml b/baselines/heterofl/pyproject.toml new file mode 100644 index 000000000000..0f72edf20345 --- /dev/null +++ b/baselines/heterofl/pyproject.toml @@ -0,0 +1,145 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.masonry.api" + +[tool.poetry] +name = "heterofl" # <----- Ensure it matches the name of your baseline directory containing all the source code +version = "1.0.0" +description = "HeteroFL : Computation And Communication Efficient Federated Learning For Heterogeneous Clients" +license = "Apache-2.0" +authors = ["M S Chaitanya Kumar ", "The Flower Authors "] +readme = "README.md" +homepage = "https://flower.dev" +repository = "https://github.com/adap/flower" +documentation = "https://flower.dev" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[tool.poetry.dependencies] +python = ">=3.10.0, <3.11.0" +flwr = { extras = ["simulation"], version = "1.5.0" } +hydra-core = "1.3.2" # don't change this +torch = { url = "https://download.pytorch.org/whl/cu118/torch-2.1.0%2Bcu118-cp310-cp310-linux_x86_64.whl"} +torchvision = { url = "https://download.pytorch.org/whl/cu118/torchvision-0.16.0%2Bcu118-cp310-cp310-linux_x86_64.whl"} +anytree = "^2.12.1" +types-six = "^1.16.21.9" +tqdm = "4.66.1" + +[tool.poetry.dev-dependencies] +isort = "==5.11.5" +black = "==23.1.0" +docformatter = "==1.5.1" +mypy = "==1.4.1" +pylint = "==2.8.2" +flake8 = "==3.9.2" +pytest = "==6.2.4" +pytest-watch = "==4.2.0" +ruff = "==0.0.272" +types-requests = "==2.27.7" +virtualenv = "20.21.0" + +[tool.isort] +line_length = 88 +indent = " " +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true + +[tool.black] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311"] + +[tool.pytest.ini_options] +minversion = "6.2" +addopts = "-qq" +testpaths = [ + "flwr_baselines", +] + +[tool.mypy] +ignore_missing_imports = true +strict = false +plugins = "numpy.typing.mypy_plugin" + +[tool.pylint."MESSAGES CONTROL"] +disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +good-names = "i,j,k,_,x,y,X,Y" +signature-mutators="hydra.main.main" + + +[tool.pylint.typecheck] +generated-members="numpy.*, torch.*, tensorflow.*" + + +[[tool.mypy.overrides]] +module = [ + "importlib.metadata.*", + "importlib_metadata.*", +] +follow_imports = "skip" +follow_imports_for_stubs = true +disallow_untyped_calls = false + +[[tool.mypy.overrides]] +module = "torch.*" +follow_imports = "skip" +follow_imports_for_stubs = true + +[tool.docformatter] +wrap-summaries = 88 +wrap-descriptions = 88 + +[tool.ruff] +target-version = "py38" +line-length = 88 +select = ["D", "E", "F", "W", "B", "ISC", "C4"] +fixable = ["D", "E", "F", "W", "B", "ISC", "C4"] +ignore = ["B024", "B027"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "proto", +] + +[tool.ruff.pydocstyle] +convention = "numpy" diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 7168386eaf0a..507489e76e7b 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -14,6 +14,8 @@ - FedNova [#2179](https://github.com/adap/flower/pull/2179) + - HeteroFL [#2439](https://github.com/adap/flower/pull/2439) + ## v1.6.0 (2023-11-28) ### Thanks to our contributors