From 5355672a0a127265acb8a94228e7a7170d67ed5b Mon Sep 17 00:00:00 2001
From: Joyce Yuan <17791324+joyce-yuan@users.noreply.github.com>
Date: Tue, 22 Oct 2024 22:36:03 -0400
Subject: [PATCH 1/4] small bug fix to make our dataset classes subclass torch
 dataset (#122)

---
 src/utils/data_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py
index 0dde718..7889b8f 100644
--- a/src/utils/data_utils.py
+++ b/src/utils/data_utils.py
@@ -9,7 +9,7 @@ from utils.corruptions import corrupt_mapping
 
 
-class CacheDataset:
+class CacheDataset(Dataset):
     """
     Caches the entire dataset in memory.
     """
@@ -27,7 +27,7 @@ def __len__(self):
         return len(self.data)
 
 
-class TransformDataset:
+class TransformDataset(Dataset):
     """
     Applies a transformation to the dataset.
     """
@@ -45,7 +45,7 @@ def __len__(self):
         return len(self.dset)
 
 # Custom dataset wrapper to apply corruption
-class CorruptDataset:
+class CorruptDataset(Dataset):
     def __init__(self, dset: CacheDataset, corruption_fn_name, severity: int = 1):
         print("Initialized CorruptDataset with corruption_fn_name: ", corruption_fn_name)
         self.dset = dset  # Original dataset

From b71d53a57fe711cc67a5b09b696f3a61aa8e2c88 Mon Sep 17 00:00:00 2001
From: Joyce Yuan <17791324+joyce-yuan@users.noreply.github.com>
Date: Thu, 24 Oct 2024 08:55:06 -0700
Subject: [PATCH 2/4] post hoc plot utils that generate summary metrics and
 plots (#123)

* post hoc plot utils that generate summary metrics and plots

* fixed some naming, logging test and train avg and std
---
 src/utils/post_hoc_plot_utils.py | 193 +++++++++++++++++++++++++++++++
 1 file changed, 193 insertions(+)
 create mode 100644 src/utils/post_hoc_plot_utils.py

diff --git a/src/utils/post_hoc_plot_utils.py b/src/utils/post_hoc_plot_utils.py
new file mode 100644
index 0000000..352c7b1
--- /dev/null
+++ b/src/utils/post_hoc_plot_utils.py
@@ -0,0 +1,193 @@
+import os
+import pandas as pd
+import numpy as np
+from sklearn.metrics import auc
+from typing import List, Dict, Tuple, Optional
+import matplotlib.pyplot as plt
+import json
+
+# Load Logs
+def load_logs(node_id: str, metric_type: str, logs_dir: str) -> pd.DataFrame:
+    """Loads the csv logs for a given node and metric (train/test, acc/loss)"""
+    file_path = os.path.join(logs_dir, f'node_{node_id}/csv/{metric_type}.csv')
+    return pd.read_csv(file_path)
+
+def get_all_nodes(logs_dir: str) -> List[str]:
+    """Return all node directories in the log folder"""
+    return [d for d in os.listdir(logs_dir) if (os.path.isdir(os.path.join(logs_dir, d)) and d != "node_0" and d.startswith('node'))]
+
+# Calculate Metrics Per User
+def calculate_auc(df: pd.DataFrame, metric: str = 'acc') -> float:
+    """Calculate AUC for the given dataframe's accuracy or loss."""
+    return auc(df['iteration'], df[metric])
+
+def best_accuracy(df: pd.DataFrame, metric: str = 'acc') -> float:
+    """Find the best (highest) accuracy for the given metric."""
+    return df[metric].max()
+
+def best_loss(df: pd.DataFrame, metric: str) -> float:
+    """Find the lowest loss for a given metric."""
+    return df[metric].min()
+
+def compute_per_user_metrics(node_id: str, logs_dir: str) -> Dict[str, float]:
+    """Computes AUC, best accuracy, and best loss for train/test."""
+    train_acc = load_logs(node_id, 'train_acc', logs_dir)
+    test_acc = load_logs(node_id, 'test_acc', logs_dir)
+    train_loss = load_logs(node_id, 'train_loss', logs_dir)
+    test_loss = load_logs(node_id, 'test_loss', logs_dir)
+
+    metrics = {
+        'train_auc_acc': calculate_auc(train_acc, 'train_acc'),
+        'test_auc_acc': calculate_auc(test_acc, 'test_acc'),
+        'train_auc_loss': calculate_auc(train_loss, 'train_loss'),
+        'test_auc_loss': calculate_auc(test_loss, 'test_loss'),
+        'best_train_acc': best_accuracy(train_acc, 'train_acc'),
+        'best_test_acc': best_accuracy(test_acc, 'test_acc'),
+        'best_train_loss': best_loss(train_loss, 'train_loss'),
+        'best_test_loss': best_loss(test_loss, 'test_loss')
+    }
+
+    return metrics
+
+def aggregate_metrics_across_users(logs_dir: str, output_dir: Optional[str] = None) -> Tuple[pd.Series, pd.Series, pd.DataFrame]:
+    """Aggregate metrics across all users and save the results to CSV files."""
+    nodes = get_all_nodes(logs_dir)
+    all_metrics: List[Dict[str, float]] = []
+
+    # Ensure the output directory exists
+    if not output_dir:
+        output_dir = os.path.join(logs_dir, 'aggregated_metrics')
+    os.makedirs(output_dir, exist_ok=True)
+
+    for node in nodes:
+        node_id = node.split('_')[-1]
+        metrics = compute_per_user_metrics(node_id, logs_dir)
+        metrics['node'] = node
+        all_metrics.append(metrics)
+
+    # Convert to DataFrame for easier processing
+    df_metrics = pd.DataFrame(all_metrics)
+
+    # Calculate average and standard deviation ('node' is a string column, so
+    # restrict the statistics to numeric columns)
+    avg_metrics = df_metrics.mean(numeric_only=True)
+    std_metrics = df_metrics.std(numeric_only=True)
+
+    # Save the DataFrame with per-user metrics
+    df_metrics.to_csv(os.path.join(output_dir, 'per_user_metrics.csv'), index=False)
+
+    # Save the average and standard deviation statistics
+    summary_stats = pd.DataFrame({'Average': avg_metrics, 'Standard Deviation': std_metrics})
+    summary_stats.to_csv(os.path.join(output_dir, 'summary_statistics.csv'))
+
+    return avg_metrics, std_metrics, df_metrics
+
+def compute_per_user_round_data(node_id: str, logs_dir: str, metrics_map: Optional[Dict[str, str]] = None) -> Dict[str, np.ndarray]:
+    """Extract per-round data (accuracy and loss) for train/test from the logs."""
+    if metrics_map is None:
+        metrics_map = {
+            'train_acc': 'train_acc',
+            'test_acc': 'test_acc',
+            'train_loss': 'train_loss',
+            'test_loss': 'test_loss',
+        }
+
+    per_round_data = {}
+    for key, file_name in metrics_map.items():
+        data = load_logs(node_id, file_name, logs_dir)
+        per_round_data[key] = data[file_name].values
+        if 'rounds' not in per_round_data:
+            per_round_data['rounds'] = data['iteration'].values
+
+    return per_round_data
+
+# Per Round Aggregation
+def aggregate_per_round_data(logs_dir: str, metrics_map: Optional[Dict[str, str]] = None) -> Dict[str, pd.DataFrame]:
+    """Aggregate the per-round data for all users."""
+    if metrics_map is None:
+        metrics_map = {
+            'train_acc': 'train_acc',
+            'test_acc': 'test_acc',
+            'train_loss': 'train_loss',
+            'test_loss': 'test_loss',
+        }
+
+    nodes = get_all_nodes(logs_dir)
+    all_users_data: Dict[str, List[np.ndarray]] = {metric: [] for metric in metrics_map}
+    all_users_data['rounds'] = []
+
+    for node in nodes:
+        node_id = node.split('_')[-1]
+        user_data = compute_per_user_round_data(node_id, logs_dir, metrics_map)
+
+        # Collect data for all users
+        for key in metrics_map:
+            all_users_data[key].append(user_data[key])
+
+    # Convert lists of arrays into DataFrames for easier aggregation
+    rounds = user_data['rounds']  # All users should have the same rounds
+    all_users_data['rounds'] = rounds
+
+    for key in metrics_map:
+        all_users_data[key] = pd.DataFrame(all_users_data[key]).transpose()
+
+    return all_users_data
+
+# Plotting
+def plot_metric_per_round(metric_df: pd.DataFrame, rounds: np.ndarray, metric_name: str, ylabel: str, output_dir: str) -> None:
+    """Plot per-round data for each user and aggregate (mean and std)."""
+    plt.figure(figsize=(10, 6))
+
+    # Plot per-user data
+    for col in metric_df.columns:
+        plt.plot(rounds, metric_df[col], alpha=0.6, label=f'User {col+1}')
+
+    # Compute mean and std
+    mean_metric = metric_df.mean(axis=1)
+    std_metric = metric_df.std(axis=1)
+
+    # Save the mean and std
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    mean_metric.to_csv(f'{output_dir}{metric_name}_avg.csv', index=False)
+    std_metric.to_csv(f'{output_dir}{metric_name}_std.csv', index=False)
+
+
+    # Plot the mean with standard deviation as a shaded area
+    plt.plot(rounds, mean_metric, label='Average', color='black', linestyle='--')
+    plt.fill_between(rounds, mean_metric - std_metric, mean_metric + std_metric, color='gray', alpha=0.2, label='Std dev')
+
+    plt.xlabel('Rounds (Iterations)')
+    plt.ylabel(ylabel)
+    plt.title(f'{ylabel} per User and Aggregate')
+    plt.legend()
+    plt.savefig(f'{output_dir}{metric_name}_per_round.png')
+    plt.close()
+
+def plot_all_metrics(logs_dir: str, metrics_map: Optional[Dict[str, str]] = None) -> None:
+    """Generates plots for all metrics over rounds with aggregation."""
+    if metrics_map is None:
+        metrics_map = {
+            'test_acc': 'Test Accuracy',
+            'train_acc': 'Train Accuracy',
+            'test_loss': 'Test Loss',
+            'train_loss': 'Train Loss'
+        }
+
+    all_users_data = aggregate_per_round_data(logs_dir)
+
+    for key, display_name in metrics_map.items():
+        plot_metric_per_round(
+            metric_df=all_users_data[key],
+            rounds=all_users_data['rounds'],
+            metric_name=key,
+            ylabel=display_name,
+            output_dir=f'{logs_dir}plots/'
+        )
+
+    print("Plots saved as PNG files.")
+
+if __name__ == "__main__":
+    # Define the path where your experiment logs are saved
+    logs_dir = '/u/jyuan24/sonar/src/expt_dump/1_malicious_exp/cifar10_40users_1250_data_poison_8_malicious_seed1/logs/'
+    avg_metrics, std_metrics, df_metrics = aggregate_metrics_across_users(logs_dir)
+    plot_all_metrics(logs_dir)
\ No newline at end of file

From 9234a62f339961557e27b46a62fd6a51e5f99079 Mon Sep 17 00:00:00 2001
From: Joyce Yuan <17791324+joyce-yuan@users.noreply.github.com>
Date: Thu, 24 Oct 2024 18:28:20 -0700
Subject: [PATCH 3/4] Update automate exp and docs (#124)

* updated main exp, with fixes to other files, and also added docs

* small changes to grpc doc

* bug fixes
---
 docs/getting-started/experiments.md | 58 +++++++++++++++++++++++
 docs/getting-started/grpc.md        | 10 ++--
 mkdocs.yml                          |  1 +
 src/configs/sys_config.py           | 15 +++++-
 src/main_exp.py                     | 71 +++++++++++++++++++++++++----
 src/utils/post_hoc_plot_utils.py    | 11 +++--
 6 files changed, 146 insertions(+), 20 deletions(-)
 create mode 100644 docs/getting-started/experiments.md

diff --git a/docs/getting-started/experiments.md b/docs/getting-started/experiments.md
new file mode 100644
index 0000000..edeac6f
--- /dev/null
+++ b/docs/getting-started/experiments.md
@@ -0,0 +1,58 @@
+# Automating Experiments
+
+In this tutorial, we will discuss how to automate running multiple experiments by customizing our experiment script. Note that we currently only support automation on one machine with the gRPC protocol. If you have not already read the [Getting Started](./getting-started.md) guide, we recommend you do so before proceeding.
+
+## Running the Code
+The `main_exp.py` file automates running experiments on one machine using gRPC. You can run this file with the command:
+``` bash
+python main_exp.py -host randomhost42.mit.edu
+```
+
+## Customizing the Experiments
+To customize your experiment automation, make these changes in `main_exp.py`.
+
+1. Specify your constant settings in `sys_config.py` and `algo_config.py`.
+2. Import the sys_config and algo_config setting objects you want to use for your experiments.
+``` python
+from configs.algo_config import traditional_fl
+from configs.sys_config import grpc_system_config
+```
+
+3. Write the experiment object like the example `exp_dict`, mapping each new experiment ID to the set of keys that you want to change per experiment. Specify the `algo_config` and its specific customizations in `algo`, and likewise for `sys_config` and `sys`. *Note: every experiment must have a unique experiment path, and we recommend guaranteeing this by giving every experiment a unique experiment ID.*
+``` python
+exp_dict = {
+    "test_automation_1": {
+        "algo_config": traditional_fl,
+        "sys_config": grpc_system_config,
+        "algo": {
+            "rounds": 3,
+        },
+        "sys": {
+            "seed": 3,
+            "num_users": 3,
+        },
+    },
+    "test_automation_2": {
+        "algo_config": traditional_fl,
+        "sys_config": grpc_system_config,
+        "algo": {
+            "rounds": 4,
+        },
+        "sys": {
+            "seed": 4,
+            "num_users": 4,
+        },
+    },
+}
+```
+
+4. (Optional) Specify whether or not to run post hoc metrics and plots by setting the boolean at the top of the file.
+``` python
+post_hoc_plot: bool = True
+```
+
+5. Start the experiments with the following command:
+``` bash
+python main_exp.py -host randomhost42.mit.edu
+```
\ No newline at end of file
diff --git a/docs/getting-started/grpc.md b/docs/getting-started/grpc.md
index b4dc8d0..1a7fd9a 100644
--- a/docs/getting-started/grpc.md
+++ b/docs/getting-started/grpc.md
@@ -8,16 +8,16 @@ In this tutorial, we will discuss how to use gRPC for training models across mul
 
 The main advantage of our abstract communication layer is that the same code runs regardless of the fact you are using MPI or gRPC underneath. As long as the communication layer is implemented correctly, the rest of the code remains the same. This is a huge advantage for the framework as it allows us to switch between different communication layers without changing the code.
 
 ## Running the code
-Let's say you want to run the decentralized training with 80 users on 4 machines. Our implementation currently requires a coordinating node to manage the orchestration. Therefore, there will be 81 nodes in total. Make sure `sys_config.py` has `num_users: 80` in the config. You should run the following command on all 4 machines:
+Let's say you want to run the decentralized training with 80 users on 4 machines. Our implementation currently requires a coordinating node to manage the orchestration. Therefore, there will be 81 nodes in total. In `sys_config.py`, specify the hostname and port on which you want to run the coordinator node (e.g. `"comm": { "type": "GRPC", "peer_ids": ["randomhost41.mit.edu:5003"] # the coordinator port will be specified here }`), and set `num_users: 80`.
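For reference, here is a minimal sketch of how the inline `comm` fragment above might look inside `sys_config.py`; it assumes `grpc_system_config` is a plain Python dict (as it is treated in `main_exp.py`), and every key not quoted in this guide is elided:

``` python
# Hypothetical excerpt of src/configs/sys_config.py -- only "comm" and
# "num_users" are documented above; all other keys are omitted here.
grpc_system_config = {
    "comm": {
        "type": "GRPC",
        # hostname:port where the coordinator (supernode) listens
        "peer_ids": ["randomhost41.mit.edu:5003"],
    },
    "num_users": 80,  # 4 machines x 20 user nodes each; the coordinator is an extra node
}
```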
+On the machine that you want to run the coordinator node on, start the coordinator by running the following command:
 ``` bash
-python main_grpc.py -n 20 -host randomhost42.mit.edu
+python main.py -super true
 ```
 
-On **one** of the machines that you want to use as a coordinator node (let's say it is `randomhost43.mit.edu`), change the `peer_ids` with the hostname and the port you want to run the coordinator node and then run the following command:
-
+Then, start the user threads by running the following command on all 4 machines (change the host name for each machine you are using, and note that you may need to open a new terminal if you are using the same machine as the supernode):
 ``` bash
-python main.py -super true
+python main_grpc.py -n 20 -host randomhost42.mit.edu
 ```
 
 > **_NOTE:_** Most of the algorithms right now do not use the new communication protocol, hence you can only use the old MPI version with them. We are working on updating the algorithms to use the new communication protocol.
diff --git a/mkdocs.yml b/mkdocs.yml
index 09d6dc7..9004dd4 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -13,6 +13,7 @@ nav:
       - Config File: getting-started/config.md
       - Customizability: getting-started/customize.md
       - Using GRPC: getting-started/grpc.md
+      - Automating Experiments: getting-started/experiments.md
   - CollaBench:
       - Main: collabench.md
       - Feature Comparison: feature.md
diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py
index 44ae73a..4a60e20 100644
--- a/src/configs/sys_config.py
+++ b/src/configs/sys_config.py
@@ -44,6 +44,7 @@ def get_algo_configs(
     assignment_method: Literal[
         "sequential", "random", "mapping", "distribution"
     ] = "sequential",
+    seed: Optional[int] = 1,
     mapping: Optional[List[int]] = None,
     distribution: Optional[Dict[int, int]] = None,
 ) -> Dict[str, ConfigType]:
@@ -75,10 +76,20 @@ def get_algo_configs(
         )
         total_users = sum(distribution.values())
         assert total_users == num_users
-        current_index = 1
+
+        # List of node indices to assign
+        node_indices = list(range(1, total_users + 1))
+        # Seed for reproducibility
+        random.seed(seed)
+        # Shuffle the node indices based on the seed
+        random.shuffle(node_indices)
+
+        # Assign nodes based on the shuffled indices
+        current_index = 0
         for algo_index, num_nodes in distribution.items():
             for i in range(num_nodes):
-                algo_config_map[f"node_{current_index}"] = algo_configs[algo_index]
+                node_id = node_indices[current_index]
+                algo_config_map[f"node_{node_id}"] = algo_configs[algo_index]
                 current_index += 1
     else:
         raise ValueError(f"Invalid assignment method: {assignment_method}")
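To make the seeded `"distribution"` assignment above concrete, here is a usage sketch of `get_algo_configs`; the two algo configs are placeholder dicts rather than real SONAR configs, and the call uses only keyword arguments visible in this patch:

``` python
# Hypothetical call to get_algo_configs from src/configs/sys_config.py.
from configs.sys_config import get_algo_configs

fed_avg = {"algo": "fedavg", "rounds": 10}  # placeholder algo config
fed_iso = {"algo": "fediso", "rounds": 10}  # placeholder algo config

# Assign 6 of 10 nodes to algo_configs[0] and 4 to algo_configs[1].
# With the new `seed` parameter, node ids are shuffled reproducibly first,
# so the same seed always yields the same node-to-algorithm mapping.
algo_map = get_algo_configs(
    num_users=10,
    algo_configs=[fed_avg, fed_iso],
    assignment_method="distribution",
    distribution={0: 6, 1: 4},
    seed=42,
)
```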
"sys": { + "seed": 4, + "num_users": 4, + }, + }, +} + # parse the arguments parser = argparse.ArgumentParser(description="host address of the nodes") parser.add_argument( @@ -21,20 +56,28 @@ args = parser.parse_args() -# for each experiment key -# write the new config file -exp_ids = ["test_automation_1", "test_automation_2", "test_automation_3"] +for exp_id, exp_config in exp_dict.items(): + # update the algo config with config settings + base_algo_config = exp_config["algo_config"].copy() + base_algo_config.update(exp_config["algo"]) + + # update the sys config with config settings + base_sys_config = exp_config["sys_config"].copy() + base_sys_config.update(exp_config["sys"]) -for e, exp_id in enumerate(exp_ids): - current_config = grpc_system_config - current_config["exp_id"] = exp_id + # set up the full config file by combining the algo and sys config + n: int = base_sys_config["num_users"] + seed: int = base_sys_config["seed"] + base_sys_config["algos"] = get_algo_configs(num_users=n, algo_configs=[base_algo_config], seed=seed) + base_sys_config["device_ids"] = get_device_ids(n, gpu_ids) + + full_config = base_sys_config.copy() + full_config["exp_id"] = exp_id # write the config file as python file configs/temp_config.py with open("./configs/temp_config.py", "w") as f: f.write("current_config = ") - f.write(str(current_config)) - - n: int = current_config["num_users"] + f.write(str(full_config)) # start the supernode supernode_command: List[str] = ["python", "main.py", "-host", args.host, "-super", "true", "-s", "./configs/temp_config.py"] @@ -51,5 +94,15 @@ # Wait for the supernode process to finish process.wait() + # run the post-hoc analysis + if post_hoc_plot: + full_config = process_config(full_config) # this populates the results path + logs_dir = full_config["results_path"] + '/logs/' + + # aggregate metrics across all users + aggregate_metrics_across_users(logs_dir) + # plot all metrics + plot_all_metrics(logs_dir) + # Continue with the next set of commands after supernode finishes print(f"Supernode process {exp_id} finished. 
Proceeding to next set of commands.") \ No newline at end of file diff --git a/src/utils/post_hoc_plot_utils.py b/src/utils/post_hoc_plot_utils.py index 352c7b1..25b6c29 100644 --- a/src/utils/post_hoc_plot_utils.py +++ b/src/utils/post_hoc_plot_utils.py @@ -140,10 +140,13 @@ def plot_metric_per_round(metric_df: pd.DataFrame, rounds: np.ndarray, metric_na # Plot per-user data for col in metric_df.columns: plt.plot(rounds, metric_df[col], alpha=0.6, label=f'User {col+1}') - - # Compute mean and std - mean_metric = metric_df.mean(axis=1) - std_metric = metric_df.std(axis=1) + + # Select only numeric columns before calculating mean and std + numeric_columns = df_metrics.select_dtypes(include=[np.number]) + + # Calculate average and standard deviation + avg_metrics = numeric_columns.mean() + std_metrics = numeric_columns.std() # Save the mean and std if not os.path.exists(output_dir): From eb8c4b7562e320c94be939c28f576c84d3527446 Mon Sep 17 00:00:00 2001 From: Rishi Sharma Date: Sat, 26 Oct 2024 03:13:39 +0200 Subject: [PATCH 4/4] Arbirtary messages as GRPC Model (#125) --- src/algos/base_class.py | 34 +++++++++++++++------- src/algos/fl.py | 28 +++++++++++++----- src/utils/communication/grpc/grpc_utils.py | 16 +++++++++- src/utils/communication/grpc/main.py | 6 ++-- 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 523ed15..7fcd129 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -250,7 +250,13 @@ def get_model_weights(self) -> Dict[str, Tensor]: """ Share the model weights """ - return self.model.state_dict() + message = {"sender": self.node_id, "round": self.round, "model": self.model.state_dict()} + + # Move to CPU before sending + for key in message["model"].keys(): + message["model"][key] = message["model"][key].to("cpu") + + return message def get_local_rounds(self) -> int: return self.round @@ -307,7 +313,7 @@ def receive_and_aggregate(self): raise NotImplementedError - def strip_empty_models(self, models_wts: List[OrderedDict[str, Tensor]], + def strip_empty_models(self, models_wts: List[OrderedDict[str, Any]], collab_weights: Optional[List[float]] = None) -> Any: repr_list = [] if collab_weights is not None: @@ -606,7 +612,12 @@ def receive_and_aggregate(self): """ if self.is_working: repr = self.comm_utils.receive([self.server_node])[0] - self.set_model_weights(repr) + if "round" in repr: + round = repr["round"] + if "sender" in repr: + sender = repr["sender"] + assert "model" in repr, "Model not found in the received message" + self.set_model_weights(repr["model"]) def run_protocol(self) -> None: raise NotImplementedError @@ -673,7 +684,7 @@ def set_data_parameters(self, config: Dict[str, Any]) -> None: self._test_loader = DataLoader(test_dset, batch_size=batch_size) def aggregate( - self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any + self, representation_list: List[OrderedDict[str, Any]], **kwargs: Any ) -> OrderedDict[str, Tensor]: """ Aggregate the knowledge from the users @@ -745,15 +756,9 @@ def local_test(self, **kwargs: Any) -> Tuple[float, float]: self.model_utils.save_model(self.model, self.model_save_path) return test_loss, acc - def get_model_weights(self) -> OrderedDict[str, Tensor]: - """ - Share the model weights (on the cpu) - """ - return OrderedDict({k: v.cpu() for k, v in self.model.state_dict().items()}) - def aggregate( self, - models_wts: List[OrderedDict[str, Tensor]], + models_wts: List[OrderedDict[str, Any]], collab_weights: Optional[List[float]] 
         collab_weights: Optional[List[float]] = None,
         keys_to_ignore: List[str] = [],
     ) -> None:
         """
@@ -768,6 +773,7 @@ def aggregate(
         Returns:
             None
         """
+        models_coeffs: List[Tuple[OrderedDict[str, Tensor], float]] = []
         # insert the current model weights at the position self.node_id
         models_wts.insert(self.node_id - 1, self.get_model_weights())
 
@@ -778,6 +784,12 @@
         models_wts, collab_weights = self.strip_empty_models(models_wts, collab_weights)
         collab_weights = [w / sum(collab_weights) for w in collab_weights]
 
+        senders = [model["sender"] for model in models_wts if "sender" in model]
+        rounds = [model["round"] for model in models_wts if "round" in model]
+        for i in range(len(models_wts)):
+            assert "model" in models_wts[i], "Model not found in the received message"
+            models_wts[i] = models_wts[i]["model"]
+
         for idx, model_wts in enumerate(models_wts):
             models_coeffs.append((model_wts, collab_weights[idx]))
 
diff --git a/src/algos/fl.py b/src/algos/fl.py
index db80549..a45a65a 100644
--- a/src/algos/fl.py
+++ b/src/algos/fl.py
@@ -31,35 +31,42 @@ def local_test(self, **kwargs: Any):
         return [test_loss, test_acc, time_taken]
 
-    def get_model_weights(self, **kwargs: Any) -> Dict[str, Tensor]:
+    def get_model_weights(self, **kwargs: Any) -> Dict[str, Any]:
         """
         Overwrite the get_model_weights method of the BaseNode
         to add malicious attacks
         TODO: this should be moved to BaseClient
         """
+        message = {"sender": self.node_id, "round": self.round}
+
         malicious_type = self.config.get("malicious_type", "normal")
 
         if malicious_type == "normal":
-            return self.model.state_dict()  # type: ignore
+            message["model"] = self.model.state_dict()  # type: ignore
         elif malicious_type == "bad_weights":
             # Corrupt the weights
-            return BadWeightsAttack(
+            message["model"] = BadWeightsAttack(
                 self.config, self.model.state_dict()
             ).get_representation()
         elif malicious_type == "sign_flip":
             # Flip the sign of the weights, also TODO: consider label flipping
-            return SignFlipAttack(
+            message["model"] = SignFlipAttack(
                 self.config, self.model.state_dict()
             ).get_representation()
        elif malicious_type == "add_noise":
             # Add noise to the weights
-            return AddNoiseAttack(
+            message["model"] = AddNoiseAttack(
                 self.config, self.model.state_dict()
             ).get_representation()
         else:
-            return self.model.state_dict()  # type: ignore
+            message["model"] = self.model.state_dict()  # type: ignore
-        return self.model.state_dict()  # type: ignore
+
+        # move the model to cpu before sending
+        for key in message["model"].keys():
+            message["model"][key] = message["model"][key].to("cpu")
+
+        return message  # type: ignore
 
     def run_protocol(self):
         stats: Dict[str, Any] = {}
@@ -106,13 +113,18 @@ def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]):
         return avgd_wts
 
     def aggregate(
-        self, representation_list: List[OrderedDict[str, Tensor]], **kwargs: Any
+        self, representation_list: List[OrderedDict[str, Any]], **kwargs: Any
     ) -> OrderedDict[str, Tensor]:
         """
         Aggregate the model weights
         """
         representation_list, _ = self.strip_empty_models(representation_list)
         if len(representation_list) > 0:
+            senders = [rep["sender"] for rep in representation_list if "sender" in rep]
+            rounds = [rep["round"] for rep in representation_list if "round" in rep]
+            for i in range(len(representation_list)):
+                representation_list[i] = representation_list[i]["model"]
+
             avg_wts = self.fed_avg(representation_list)
             return avg_wts
         else:
diff --git a/src/utils/communication/grpc/grpc_utils.py b/src/utils/communication/grpc/grpc_utils.py
index 067e4f3..da7bd4a 100644
--- a/src/utils/communication/grpc/grpc_utils.py
+++ b/src/utils/communication/grpc/grpc_utils.py
@@ -1,6 +1,6 @@
 from collections import OrderedDict
 import io
-from typing import Dict
+from typing import Dict, Any
 
 import torch
@@ -19,3 +19,17 @@ def deserialize_model(model_bytes: bytes) -> OrderedDict[str, torch.Tensor]:
     buffer.seek(0)
     model_wts = torch.load(buffer)  # type: ignore
     return model_wts
+
+def serialize_message(message: Dict[str, Any]) -> bytes:
+    # assumes all tensors are on cpu
+    buffer = io.BytesIO()
+    torch.save(message, buffer)  # type: ignore
+    buffer.seek(0)
+    return buffer.read()
+
+
+def deserialize_message(model_bytes: bytes) -> Dict[str, Any]:
+    buffer = io.BytesIO(model_bytes)
+    buffer.seek(0)
+    message = torch.load(buffer)  # type: ignore
+    return message
\ No newline at end of file
diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py
index 5af56b0..47a04e4 100644
--- a/src/utils/communication/grpc/main.py
+++ b/src/utils/communication/grpc/main.py
@@ -11,7 +11,7 @@
 from urllib.parse import unquote
 import grpc  # type: ignore
 from torch import Tensor
-from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model
+from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model, serialize_message, deserialize_message
 import os
 import sys
@@ -148,7 +148,7 @@ def get_model(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> c
             raise Exception("Base node not registered")
         with self.lock:
             if self.is_working:
-                model = comm_pb2.Model(buffer=serialize_model(self.base_node.get_model_weights()))
+                model = comm_pb2.Model(buffer=serialize_message(self.base_node.get_model_weights()))
             else:
                 assert self.base_node.dropout.dropout_enabled, "Empty models are only supported when Dropout is enabled."
                 model = comm_pb2.Model(buffer=EMPTY_MODEL_TAG)
@@ -414,7 +414,7 @@ def callback_fn(stub: comm_pb2_grpc.CommunicationServerStub) -> OrderedDict[str,
                 return OrderedDict()
             with self.servicer.lock:
                 self.servicer.communication_cost_received += model.ByteSize()
-            return deserialize_model(model.buffer)  # type: ignore
+            return deserialize_message(model.buffer)  # type: ignore
 
         for id in node_ids:
             rank = self.get_host_from_rank(id)
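To illustrate the new message format end to end, here is a small round-trip sketch using `serialize_message`/`deserialize_message` from `grpc_utils.py`; the toy linear model is an assumption, while the `sender`/`round`/`model` keys and the CPU-tensor convention mirror `get_model_weights` above:

``` python
# Round-trip sketch for the message format introduced in PATCH 4/4.
import torch
from utils.communication.grpc.grpc_utils import serialize_message, deserialize_message

model = torch.nn.Linear(4, 2)  # toy stand-in for a real SONAR model
message = {
    "sender": 1,
    "round": 0,
    # tensors are moved to CPU before sending, matching get_model_weights()
    "model": {k: v.to("cpu") for k, v in model.state_dict().items()},
}

payload = serialize_message(message)     # torch.save into an in-memory buffer
received = deserialize_message(payload)  # torch.load back into a dict

assert received["sender"] == 1 and received["round"] == 0
assert torch.equal(received["model"]["weight"], message["model"]["weight"])
```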