From cd614bbafbc9743070e00ac33b4bec7ded22a993 Mon Sep 17 00:00:00 2001
From: Kathryn L <58600921+kathrynle20@users.noreply.github.com>
Date: Sun, 1 Dec 2024 13:30:42 -0500
Subject: [PATCH] Automated GitHub Actions Test for gRPC Training (#148)

* added MPI Communication class
* added send thread, merged 2 classes
* improved comments
* testing mpi, model weights not acquired
* mpi works, occasional deadlock issue
* merged send and listener threads
* first draft of test
* using python3.10
* made testing sys and algo configs
* testing workflow
* predict next move ish
* moved quorum send
* moved quorum send
* using traditional fl algo
* run test only during push to main
* new dump_dir
* remove send_status from proto
* changed dump_dir
* small changes
---
 .github/workflows/train.yml           |  56 ++++++++
 .vscode/settings.json                 |   4 +-
 src/algos/fl.py                       |   1 -
 src/configs/algo_config.py            |   3 +-
 src/configs/algo_config_test.py       |  26 ++++
 src/configs/sys_config.py             |  15 ++-
 src/configs/sys_config_test.py        | 126 ++++++++++++++++++
 src/main.py                           |   1 -
 src/scheduler.py                      |   2 +-
 src/utils/communication/comm_utils.py |   6 +-
 src/utils/communication/grpc/main.py  |  12 +-
 src/utils/communication/interface.py  |   4 +
 src/utils/communication/mpi.py        | 182 +++++++++++++++++++++++++-
 13 files changed, 417 insertions(+), 21 deletions(-)
 create mode 100644 .github/workflows/train.yml
 create mode 100644 src/configs/algo_config_test.py
 create mode 100644 src/configs/sys_config_test.py

diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml
new file mode 100644
index 00000000..3d7e86fc
--- /dev/null
+++ b/.github/workflows/train.yml
@@ -0,0 +1,56 @@
+name: Test Training Code with gRPC
+
+on:
+  workflow_dispatch:
+  push:
+    branches:
+      # - main
+      - "*"
+  pull_request:
+    branches:
+      - main
+
+env:
+  ACTIONS_STEP_DEBUG: true
+
+jobs:
+  train-check:
+    runs-on: ubuntu-latest
+
+    steps:
+      # Step 1: Checkout the code
+      - name: Checkout repository
+        uses: actions/checkout@v3
+
+      # Step 2: Set up Python
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"  # Specify the Python version you're using
+
+      # Step 3: Install dependencies
+      - name: Install dependencies
+        run: |
+          sudo apt update
+          sudo apt install -y libopenmpi-dev openmpi-bin
+          sudo apt-get install -y libgl1 libglib2.0-0
+
+          pip install -r requirements.txt
+
+      # Step 4: Run gRPC server and client
+      - name: Run test
+        run: |
+          cd src
+          # chmod +x ./configs/algo_config_test.py
+
+          echo "starting main grpc"
+          python main_grpc.py -n 4 -host localhost
+          echo "starting main"
+          python main.py -super true -s "./configs/sys_config_test.py"
+          echo "done"
+
+      # further checks:
+      # only 5 rounds
+      # gRPC only? or also MPI?
+      # num of samples
+      # num users and nodes
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 9e6483a3..d6e26387 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,3 +1,3 @@
 {
-    "python.analysis.typeCheckingMode": "strict"
-}
\ No newline at end of file
+    "python.analysis.typeCheckingMode": "strict"
+}
diff --git a/src/algos/fl.py b/src/algos/fl.py
index 320bb413..569213e3 100644
--- a/src/algos/fl.py
+++ b/src/algos/fl.py
@@ -119,7 +119,6 @@ def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]):
         num_users = len(model_wts)
         coeff = 1 / num_users
         avgd_wts: OrderedDict[str, Tensor] = OrderedDict()
-
         for key in model_wts[0].keys():
             avgd_wts[key] = sum(coeff * m[key] for m in model_wts)  # type: ignore
 
diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py
index 3d859009..974ca1b3 100644
--- a/src/configs/algo_config.py
+++ b/src/configs/algo_config.py
@@ -204,7 +204,6 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st
     "algo": "fedstatic",
     "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2},  # type: ignore
     "rounds": 3,
-
     # Model parameters
     "optimizer": "sgd",  # TODO comment out for real training
     "model": "resnet10",
@@ -362,4 +361,6 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st
     malicious_traditional_model_update_attack,
 ]
+
 default_config_list: List[ConfigType] = [traditional_fl]
+# default_config_list: List[ConfigType] = [fedstatic, fedstatic, fedstatic, fedstatic]
diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py
new file mode 100644
index 00000000..2f2c7fcf
--- /dev/null
+++ b/src/configs/algo_config_test.py
@@ -0,0 +1,26 @@
+from utils.types import ConfigType
+
+# fedstatic: ConfigType = {
+#     # Collaboration setup
+#     "algo": "fedstatic",
+#     "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2},  # type: ignore
+#     "rounds": 1,
+
+#     # Model parameters
+#     "model": "resnet10",
+#     "model_lr": 3e-4,
+#     "batch_size": 256,
+# }
+
+traditional_fl: ConfigType = {
+    # Collaboration setup
+    "algo": "fedavg",
+    "rounds": 1,
+
+    # Model parameters
+    "model": "resnet10",
+    "model_lr": 3e-4,
+    "batch_size": 256,
+}
+
+# default_config_list: List[ConfigType] = [fedstatic, fedstatic, fedstatic, fedstatic]
\ No newline at end of file
diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py
index 99a33fea..2e7e0438 100644
--- a/src/configs/sys_config.py
+++ b/src/configs/sys_config.py
@@ -158,11 +158,14 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
 CIAR10_DPATH = "./datasets/imgs/cifar10/"
 
 NUM_COLLABORATORS = 1
-DUMP_DIR = "/mas/camera/Experiments/SONAR/abhi/"
+# DUMP_DIR = "../../../../../../../home/"
+DUMP_DIR = "/tmp/"
+num_users = 3
 
 mpi_system_config: ConfigType = {
     "exp_id": "",
     "comm": {"type": "MPI"},
+    "num_users": num_users,
     "num_collaborators": NUM_COLLABORATORS,
     "dset": CIFAR10_DSET,
     "dump_dir": DUMP_DIR,
     "dpath": CIAR10_DPATH,
@@ -177,11 +180,9 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
     # "algo": get_algo_configs(num_users=3, algo_configs=algo_configs_list),
     "algos": get_algo_configs(
         num_users=3,
-        algo_configs=malicious_algo_config_list,
-        assignment_method="distribution",
-        distribution={0: 1, 1: 1, 2: 1},
+        algo_configs=default_config_list
     ),  # type: ignore
     "samples_per_user": 5555,  # TODO: To model scenarios where different users have different number of samples
     # we need to make this a dictionary with user_id as key and number of samples as value
     "train_label_distribution": "iid",  # Either "iid", "non_iid" "support"
     "test_label_distribution": "iid",  # Either "iid", "non_iid" "support"
@@ -348,7 +349,8 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
     "device_ids": get_device_ids(num_users, gpu_ids),
     # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list),  # type: ignore
     "algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]),  # type: ignore
-    "samples_per_user": 50000 // num_users,  # distributed equally
+    # "samples_per_user": 50000 // num_users,  # distributed equally
+    "samples_per_user": 100,
     "train_label_distribution": "non_iid",
     "test_label_distribution": "iid",
     "alpha_data": 1.0,
@@ -389,3 +391,4 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
 
 current_config = grpc_system_config
 # current_config = mpi_system_config
+
diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py
new file mode 100644
index 00000000..f3575419
--- /dev/null
+++ b/src/configs/sys_config_test.py
@@ -0,0 +1,126 @@
+from typing import Dict, List, Literal, Optional
+import random
+from utils.types import ConfigType
+
+from .algo_config_test import (
+    traditional_fl
+)
+
+def get_device_ids(num_users: int, gpus_available: List[int | Literal["cpu"]]) -> Dict[str, List[int | Literal["cpu"]]]:
+    """
+    Get the GPU device IDs for the users.
+    """
+    # TODO: Make it multi-host
+    device_ids: Dict[str, List[int | Literal["cpu"]]] = {}
+    for i in range(num_users + 1):  # +1 for the super-node
+        index = i % len(gpus_available)
+        gpu_id = gpus_available[index]
+        device_ids[f"node_{i}"] = [gpu_id]
+    return device_ids
+
+
+def get_algo_configs(
+    num_users: int,
+    algo_configs: List[ConfigType],
+    assignment_method: Literal[
+        "sequential", "random", "mapping", "distribution"
+    ] = "sequential",
+    seed: Optional[int] = 1,
+    mapping: Optional[List[int]] = None,
+    distribution: Optional[Dict[int, int]] = None,
+) -> Dict[str, ConfigType]:
+    """
+    Assign an algorithm configuration to each node, allowing for repetition.
+
+    sequential: Assigns the algo_configs sequentially to the nodes
+    random: Assigns the algo_configs randomly to the nodes
+    mapping: Assigns the algo_configs based on the mapping of node index to algo index provided
+    distribution: Assigns the algo_configs based on the distribution of algo index to number of nodes provided
+    """
+    algo_config_map: Dict[str, ConfigType] = {}
+    algo_config_map["node_0"] = algo_configs[0]  # Super-node
+    if assignment_method == "sequential":
+        for i in range(1, num_users + 1):
+            algo_config_map[f"node_{i}"] = algo_configs[i % len(algo_configs)]
+    elif assignment_method == "random":
+        for i in range(1, num_users + 1):
+            algo_config_map[f"node_{i}"] = random.choice(algo_configs)
+    elif assignment_method == "mapping":
+        if not mapping:
+            raise ValueError("Mapping must be provided for assignment method 'mapping'")
+        assert len(mapping) == num_users
+        for i in range(1, num_users + 1):
+            algo_config_map[f"node_{i}"] = algo_configs[mapping[i - 1]]
+    elif assignment_method == "distribution":
+        if not distribution:
+            raise ValueError(
+                "Distribution must be provided for assignment method 'distribution'"
+            )
+        total_users = sum(distribution.values())
+        assert total_users == num_users
+
+        # List of node indices to assign
+        node_indices = list(range(1, total_users + 1))
+        # Seed for reproducibility
+        random.seed(seed)
+        # Shuffle the node indices based on the seed
+        random.shuffle(node_indices)
+
+        # Assign nodes based on the shuffled indices
+        current_index = 0
+        for algo_index, num_nodes in distribution.items():
+            for i in range(num_nodes):
+                node_id = node_indices[current_index]
+                algo_config_map[f"node_{node_id}"] = algo_configs[algo_index]
+                current_index += 1
+    else:
+        raise ValueError(f"Invalid assignment method: {assignment_method}")
+    # print("algo config mapping is: ", algo_config_map)
+    return algo_config_map
+
+CIFAR10_DSET = "cifar10"
+CIAR10_DPATH = "./datasets/imgs/cifar10/"
+
+# DUMP_DIR = "../../../../../../../home/"
+DUMP_DIR = "/tmp/"
+
+NUM_COLLABORATORS = 1
+num_users = 4
+
+dropout_dict = {
+    "distribution_dict": {  # leave dict empty to disable dropout
+        "method": "uniform",  # "uniform", "normal"
+        "parameters": {}  # "mean": 0.5, "std": 0.1 in case of normal distribution
+    },
+    "dropout_rate": 0.0,  # cutoff for dropout: [0,1]
+    "dropout_correlation": 0.0,  # correlation between dropouts of successive rounds: [0,1]
+}
+
+dropout_dicts = {"node_0": {}}
+for i in range(1, num_users + 1):
+    dropout_dicts[f"node_{i}"] = dropout_dict
+
+gpu_ids = [2, 3, 5, 6]
+
+grpc_system_config: ConfigType = {
+    "exp_id": "static",
+    "num_users": num_users,
+    "num_collaborators": NUM_COLLABORATORS,
+    "comm": {"type": "GRPC", "synchronous": True, "peer_ids": ["localhost:50048"]},  # The super-node
+    "dset": CIFAR10_DSET,
+    "dump_dir": DUMP_DIR,
+    "dpath": CIAR10_DPATH,
+    "seed": 2,
+    "device_ids": get_device_ids(num_users, gpu_ids),
+    # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list),  # type: ignore
+    "algos": get_algo_configs(num_users=num_users, algo_configs=[traditional_fl]),  # type: ignore
+    # "samples_per_user": 50000 // num_users,  # distributed equally
+    "samples_per_user": 100,
+    "train_label_distribution": "non_iid",
+    "test_label_distribution": "iid",
+    "alpha_data": 1.0,
+    "exp_keys": [],
+    "dropout_dicts": dropout_dicts,
+    "test_samples_per_user": 200,
+}
+
+current_config = grpc_system_config
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index 384c76fd..fe4aaa7b 100644
--- a/src/main.py
+++ b/src/main.py
@@ -70,6 +70,5 @@
     # Start the scheduler
     scheduler.install_config()
     scheduler.initialize()
-
     # Run the job
     scheduler.run_job()
diff --git a/src/scheduler.py b/src/scheduler.py
index fa363c6a..55da449d 100644
--- a/src/scheduler.py
+++ b/src/scheduler.py
@@ -108,7 +108,6 @@ def initialize(self, copy_souce_code: bool = True) -> None:
         random.seed(seed)
         numpy.random.seed(seed)
         self.merge_configs()
-
         if self.communication.get_rank() == 0:
             if copy_souce_code:
                 copy_source_code(self.config)
@@ -130,6 +129,7 @@ def initialize(self, copy_souce_code: bool = True) -> None:
             rank=self.communication.get_rank(),
             comm_utils=self.communication,
         )
+        self.communication.send_quorum()
 
     def run_job(self) -> None:
         self.node.run_protocol()
diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py
index c57d02d7..e2155de5 100644
--- a/src/utils/communication/comm_utils.py
+++ b/src/utils/communication/comm_utils.py
@@ -2,6 +2,7 @@
 from utils.communication.grpc.main import GRPCCommunication
 from typing import Any, Dict, List, Tuple, TYPE_CHECKING
-# from utils.communication.mpi import MPICommUtils
+from utils.communication.mpi import MPICommUtils
+# from mpi4py import MPI
 
 if TYPE_CHECKING:
     from algos.base_class import BaseNode
@@ -21,7 +22,7 @@ def create_communication(
 ):
     comm_type = comm_type
     if comm_type == CommunicationType.MPI:
-        raise NotImplementedError("MPI's new version not yet implemented. Please use GRPC. See https://github.com/aidecentralized/sonar/issues/96 for more details.")
+        return MPICommUtils(config)
     elif comm_type == CommunicationType.GRPC:
         return GRPCCommunication(config)
     elif comm_type == CommunicationType.HTTP:
@@ -71,6 +72,9 @@ def receive(self, node_ids: List[int]) -> Any:
     def broadcast(self, data: Any, tag: int = 0):
         self.comm.broadcast(data)
 
+    def send_quorum(self):
+        self.comm.send_quorum()
+
     def all_gather(self, tag: int = 0):
         return self.comm.all_gather()
diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py
index 7fb687ac..b850f0e3 100644
--- a/src/utils/communication/grpc/main.py
+++ b/src/utils/communication/grpc/main.py
@@ -343,7 +343,18 @@ def initialize(self):
                 peer_ids=self.peer_ids_to_proto(self.servicer.peer_ids)
             )
             stub.send_peer_ids(proto_msg)  # type: ignore
+            # stub.send_quorum(comm_pb2.Quorum(quorum=True))  # type: ignore
+
+    def send_quorum(self):
+        """ Send the quorum status to all nodes after peer IDs are sent. """
+        if self.rank == 0:
+            for peer_id in self.servicer.peer_ids:
+                if not self.is_own_id(peer_id):
+                    host = self.get_host_from_rank(peer_id)
+                    with grpc.insecure_channel(host) as channel:  # type: ignore
+                        stub = comm_pb2_grpc.CommunicationServerStub(channel)
+                        stub.send_quorum(comm_pb2.Quorum(quorum=True))  # type: ignore
+            print("Quorum status sent to all nodes.")
 
     def get_host_from_rank(self, rank: int) -> str:
         for peer_id in self.servicer.peer_ids:
@@ -370,7 +381,6 @@ def send_with_retries(self, dest_host: str, buffer: Any) -> Any:
         raise Exception("Failed to send data. Receiver unreachable.")
 
-
     def send(self, dest: str | int, data: OrderedDict[str, Any]):
         """
         data should be a python dictionary
diff --git a/src/utils/communication/interface.py b/src/utils/communication/interface.py
index b1daac27..8df3c067 100644
--- a/src/utils/communication/interface.py
+++ b/src/utils/communication/interface.py
@@ -18,6 +18,10 @@ def send(self, dest: str | int, data: Any):
     def receive(self, node_ids: List[int]) -> Any:
         pass
 
+    @abstractmethod
+    def send_quorum(self) -> Any:
+        pass
+
     @abstractmethod
     def broadcast(self, data: Any):
         pass
diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py
index 8f78f154..e9b20042 100644
--- a/src/utils/communication/mpi.py
+++ b/src/utils/communication/mpi.py
@@ -1,7 +1,15 @@
-from typing import Dict, Any, List
+from collections import OrderedDict
+from typing import Dict, Any, List, TYPE_CHECKING
 from mpi4py import MPI
+from torch import Tensor
 from utils.communication.interface import CommunicationInterface
+import threading
+import time
+import random
+import numpy as np
 
+if TYPE_CHECKING:
+    from algos.base_class import BaseNode
 
 class MPICommUtils(CommunicationInterface):
     def __init__(self, config: Dict[str, Dict[str, Any]]):
@@ -9,19 +17,135 @@ def __init__(self, config: Dict[str, Dict[str, Any]]):
         self.rank = self.comm.Get_rank()
         self.size = self.comm.Get_size()
+        self.num_users: int = int(config["num_users"])  # type: ignore
+        self.finished = False
+
+        # Ensure that we are using thread safe threading level
+        self.required_threading_level = MPI.THREAD_MULTIPLE
+        self.threading_level = MPI.Query_thread()
+        # Make sure to check for MPI_THREAD_MULTIPLE threading level to support
+        # thread safe calls to send and recv
+        if self.required_threading_level > self.threading_level:
+            raise RuntimeError(f"Insufficient thread support. Required: {self.required_threading_level}, Current: {self.threading_level}")
+
+        self.send_event = threading.Event()
+        # Ensures that the listener thread and send thread are not using self.request_source at the same time
+        self.lock = threading.Lock()
+        self.request_source: int | None = None
+
+        self.is_working = True
+        self.communication_cost_received: int = 0
+        self.communication_cost_sent: int = 0
+
+        self.base_node: BaseNode | None = None
+
+        self.listener_thread = threading.Thread(target=self.listener)
+        self.listener_thread.start()
+
     def initialize(self):
         pass
 
-    def send(self, dest: str | int, data: Any):
-        self.comm.send(data, dest=int(dest))
+    def send_quorum(self) -> Any:
+        # return super().send_quorum(node_ids)
+        pass
+
+    def register_self(self, obj: "BaseNode"):
+        self.base_node = obj
+
+    def get_comm_cost(self):
+        with self.lock:
+            return self.communication_cost_received, self.communication_cost_sent
+
+    def listener(self):
+        """
+        Runs on a listener thread on each node to receive a send request.
+        Once a send request is received, the listener thread informs the send
+        thread to send the data to the requesting node.
+ """ + while not self.finished: + status = MPI.Status() + # look for message with tag 1 (represents send request) + if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): + with self.lock: + # self.request_source = status.Get_source() + dest = status.Get_source() - def receive(self, node_ids: str | int) -> Any: - return self.comm.recv(source=int(node_ids)) + print(f"Node {self.rank} received request from {self.request_source}") + # receive_request = self.comm.irecv(source=self.request_source, tag=1) + # receive_request.wait() + self.comm.recv(source=dest, tag=1) + self.send(dest) + print(f"Node {self.rank} listener thread ended") + def get_model(self) -> List[OrderedDict[str, Tensor]] | None: + print(f"getting model from {self.rank}, {self.base_node}") + if not self.base_node: + raise Exception("Base node not registered") + with self.lock: + if self.is_working: + model = self.base_node.get_model_weights() + model = [model] + print(f"Model from {self.rank} acquired") + else: + assert self.base_node.dropout.dropout_enabled, "Empty models are only supported when Dropout is enabled." + model = None + return model + + def send(self, dest: int): + """ + Node will wait for a request to send data and then send the + data to requesting node. + """ + if self.finished: + return + + data = self.get_model() + print(f"Node {self.rank} is sending data to {dest}") + # req = self.comm.Isend(data, dest=int(dest)) + # req.wait() + self.comm.send(data, dest=int(dest)) + + def receive(self, node_ids: List[int]) -> Any: + """ + Node will send a request for data and wait to receive data. + """ + max_tries = 10 + assert len(node_ids) == 1, "Too many node_ids to unpack" + node = node_ids[0] + while max_tries > 0: + try: + print(f"Node {self.rank} receiving from {node}") + self.comm.send("", dest=node, tag=1) + # recv_req = self.comm.Irecv([], source=node) + # received_data = recv_req.wait() + received_data = self.comm.recv(source=node) + print(f"Node {self.rank} received data from {node}: {bool(received_data)}") + if not received_data: + raise Exception("Received empty data") + return received_data + except MPI.Exception as e: + print(f"MPI failed {10 - max_tries} times: MPI ERROR: {e}", "Retrying...") + import traceback + print(f"Traceback: {traceback.print_exc()}") + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 + except Exception as e: + print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") + import traceback + print(f"Traceback: {traceback.print_exc()}") + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 + print(f"Node {self.rank} received") + + # deprecated broadcast function def broadcast(self, data: Any): for i in range(1, self.size): if i != self.rank: - self.send(i, data) + self.comm.send(data, dest=i) def all_gather(self): """ @@ -29,8 +153,52 @@ def all_gather(self): """ items: List[Any] = [] for i in range(1, self.size): + print(f"receiving this data: {self.receive(i)}") items.append(self.receive(i)) return items + + def send_finished(self): + self.comm.send("Finished", dest=0, tag=2) def finalize(self): - pass + # 1. All nodes send finished to the super node + # 2. super node will wait for all nodes to send finished + # 3. super node will then send bye to all nodes + # 4. 
all nodes will wait for the bye and then exit + # this is to ensure that all nodes have finished + # and no one leaves early + if self.rank == 0: + quorum_threshold = self.num_users - 1 # No +1 for the super node because it doesn't send finished + num_finished: set[int] = set() + status = MPI.Status() + while len(num_finished) < quorum_threshold: + print( + f"Waiting for {quorum_threshold} users to finish, {num_finished} have finished so far" + ) + # get finished nodes + self.comm.recv(source=MPI.ANY_SOURCE, tag=2, status=status) + print(f"received finish message from {status.Get_source()}") + num_finished.add(status.Get_source()) + + else: + # send finished to the super node + print(f"Node {self.rank} sent finish message") + self.send_finished() + + message = self.comm.bcast("Done", root=0) + self.finished = True + self.send_event.set() + print(f"Node {self.rank} received {message}, finished") + self.comm.Barrier() + self.listener_thread.join() + print(f"Node {self.rank} listener thread done") + print(f"Node {self.rank} listener thread is {self.listener_thread.is_alive()}") + print(f"Node {self.rank} {threading.enumerate()}") + self.comm.Barrier() + print(f"Node {self.rank}: all nodes synchronized") + MPI.Finalize() + + def set_is_working(self, is_working: bool): + with self.lock: + self.is_working = is_working +
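
For reviewers, a standalone sketch of how the round-robin device assignment in `src/configs/sys_config_test.py` plays out for the test config's values (`num_users = 4`, `gpu_ids = [2, 3, 5, 6]`). The helper name `round_robin_device_ids` is illustrative, not from the patch; it mirrors `get_device_ids` above.

```python
# Sketch: round-robin node -> device mapping, as in get_device_ids.
from typing import Dict, List

def round_robin_device_ids(num_users: int, gpus: List[int]) -> Dict[str, List[int]]:
    # node_0 is the super-node, so num_users + 1 entries are produced
    return {f"node_{i}": [gpus[i % len(gpus)]] for i in range(num_users + 1)}

if __name__ == "__main__":
    mapping = round_robin_device_ids(4, [2, 3, 5, 6])
    # {'node_0': [2], 'node_1': [3], 'node_2': [5], 'node_3': [6], 'node_4': [2]}
    print(mapping)
```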
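The `"distribution"` branch of `get_algo_configs` shuffles node indices under a fixed seed before handing out configs. A reduced, standalone sketch of that branch, assuming two hypothetical configs `cfg_a`/`cfg_b`; it mirrors the logic in `src/configs/sys_config_test.py` but is not the patch's code:

```python
# Sketch: "distribution" assignment from get_algo_configs, standalone.
import random
from typing import Dict, List

def assign_by_distribution(num_users: int, configs: List[dict],
                           distribution: Dict[int, int], seed: int = 1) -> Dict[str, dict]:
    assert sum(distribution.values()) == num_users
    nodes = list(range(1, num_users + 1))
    random.seed(seed)           # seed makes the shuffle reproducible
    random.shuffle(nodes)       # which node gets which algo depends only on the seed
    mapping, cursor = {"node_0": configs[0]}, 0  # node_0 is the super-node
    for algo_index, count in distribution.items():
        for _ in range(count):
            mapping[f"node_{nodes[cursor]}"] = configs[algo_index]
            cursor += 1
    return mapping

cfg_a, cfg_b = {"algo": "fedavg"}, {"algo": "fedstatic"}
print(assign_by_distribution(4, [cfg_a, cfg_b], {0: 2, 1: 2}))
```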
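The MPI changes implement a request/reply pattern: a background listener probes for tag-1 requests, and `receive` sends an empty tag-1 message then blocks on the reply carrying the model weights. A minimal two-rank sketch of that handshake, assuming mpi4py is installed; `REQUEST_TAG` and the dict payload are stand-ins for the patch's tag 1 and model weights:

```python
# Sketch of the tag-based request/reply used by MPICommUtils, reduced to
# two ranks. Run with: mpirun -n 2 python request_reply_demo.py
from mpi4py import MPI

REQUEST_TAG = 1  # same tag the listener probes for

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    # "listener" side: wait for a request, then send the payload back
    status = MPI.Status()
    comm.probe(source=MPI.ANY_SOURCE, tag=REQUEST_TAG, status=status)
    requester = status.Get_source()
    comm.recv(source=requester, tag=REQUEST_TAG)  # consume the request message
    comm.send({"weights": [0.1, 0.2]}, dest=requester)
else:
    # "receive" side: request data from rank 0 and block until it arrives
    comm.send("", dest=0, tag=REQUEST_TAG)
    data = comm.recv(source=0)
    print(f"rank {rank} got {data}")
```

The production code uses a non-blocking `Iprobe` in a loop instead of `probe` so the listener thread can also check the `finished` flag between polls.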
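Finally, `finalize()` implements the shutdown quorum described in its comments: workers send a tag-2 "Finished" message to rank 0, which waits for all of them before everyone unblocks on a broadcast. A minimal sketch under the assumption that rank 0 is the super-node and every other rank is a worker (`FINISHED_TAG` mirrors the patch's tag 2):

```python
# Sketch of the shutdown handshake in MPICommUtils.finalize().
# Run with: mpirun -n 4 python finalize_demo.py
from mpi4py import MPI

FINISHED_TAG = 2

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

if rank == 0:
    finished: set[int] = set()
    status = MPI.Status()
    while len(finished) < size - 1:  # wait for every worker to report in
        comm.recv(source=MPI.ANY_SOURCE, tag=FINISHED_TAG, status=status)
        finished.add(status.Get_source())
else:
    comm.send("Finished", dest=0, tag=FINISHED_TAG)

# everyone unblocks together once the super-node broadcasts
message = comm.bcast("Done", root=0)
print(f"rank {rank}: {message}")
```

Tracking sources in a set (rather than a counter) makes the quorum robust to a worker accidentally sending "Finished" twice, which matches the patch's use of `num_finished: set[int]`.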