From f0192d48df0019b6711f2cfe5b60e09e0d5d2466 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sat, 19 Oct 2024 10:15:11 -0400 Subject: [PATCH 01/19] added MPI Communication class --- src/utils/communication/mpi.py | 84 ++++++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 8f78f154..0bc7a9d6 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -1,6 +1,9 @@ from typing import Dict, Any, List from mpi4py import MPI from utils.communication.interface import CommunicationInterface +import threading +import time +from enum import Enum class MPICommUtils(CommunicationInterface): @@ -12,11 +15,11 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): def initialize(self): pass - def send(self, dest: str | int, data: Any): - self.comm.send(data, dest=int(dest)) + # def send(self, dest: str | int, data: Any): + # self.comm.send(data, dest=int(dest)) - def receive(self, node_ids: str | int) -> Any: - return self.comm.recv(source=int(node_ids)) + # def receive(self, node_ids: str | int) -> Any: + # return self.comm.recv(source=int(node_ids)) def broadcast(self, data: Any): for i in range(1, self.size): @@ -34,3 +37,76 @@ def all_gather(self): def finalize(self): pass + + +class MPICommunication(MPICommUtils): + def __init__(self, config: Dict[str, Dict[str, Any]]): + super().__init__(config) + listener_thread = threading.Thread(target=self.listener, daemon=True) + listener_thread.start() + self.send_event = threading.Event() + self.request_source: int | None = None + + def listener(self): + while True: + status = MPI.Status() + if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status): + source = status.Get_source() + tag = status.Get_tag() + count = status.Get_count(MPI.BYTE) # Get the number of bytes in the message + # If a message is available, receive it + data_to_recv = bytearray(count) + req = self.comm.irecv([data_to_recv, MPI.BYTE], source=source, tag=tag) + req.wait() + # Convert the byte array back to a string + received_message = data_to_recv.decode('utf-8') + + if received_message == "Requesting Information": + self.send_event.set() + + self.send_event.clear() + break + time.sleep(1) # Simulate waiting time + + def send(self, dest: str | int, data: Any, tag: int): + while True: + # Wait until the listener thread detects a request + self.send_event.wait() + req = self.comm.isend(data, dest=int(dest), tag=tag) + req.wait() + + def receive(self, node_ids: str | int, tag: int) -> Any: + node_ids = int(node_ids) + message = "Requesting Information" + message_bytes = bytearray(message, 'utf-8') + send_req = self.comm.isend([message_bytes, MPI.BYTE], dest=node_ids, tag=tag) + send_req.wait() + recv_req = self.comm.irecv(source=node_ids, tag=tag) + return recv_req.wait() + +# MPI Server +""" +initialization(): + node spins up listener thread, threading (an extra thread might not be needed since iprobe exists). + call listen? + +listen(): + listener thread starts listening for send requests (use iprobe and irecv for message) + when send request is received, call the send() function + +send(): + gather and send info to requesting node using comm.isend + comm.wait + +""" + +# MPI Client +""" +initialization(): + node is initialized + +receive(): + node sends request to sending node using isend() + node calls irecv and waits for response +""" + From 755fc073f76713f9a0ef6cecd62fa21b764444a7 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 21 Oct 2024 21:17:52 -0400 Subject: [PATCH 02/19] added send thread, merged 2 classes --- src/utils/communication/mpi.py | 144 +++++++++++++++------------------ 1 file changed, 65 insertions(+), 79 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 0bc7a9d6..89bd21cd 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -3,110 +3,96 @@ from utils.communication.interface import CommunicationInterface import threading import time -from enum import Enum - class MPICommUtils(CommunicationInterface): - def __init__(self, config: Dict[str, Dict[str, Any]]): + def __init__(self, config: Dict[str, Dict[str, Any]], data: Any): self.comm = MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() - def initialize(self): - pass - - # def send(self, dest: str | int, data: Any): - # self.comm.send(data, dest=int(dest)) - - # def receive(self, node_ids: str | int) -> Any: - # return self.comm.recv(source=int(node_ids)) - - def broadcast(self, data: Any): - for i in range(1, self.size): - if i != self.rank: - self.send(i, data) - - def all_gather(self): - """ - This function is used to gather data from all the nodes. - """ - items: List[Any] = [] - for i in range(1, self.size): - items.append(self.receive(i)) - return items - - def finalize(self): - pass - - -class MPICommunication(MPICommUtils): - def __init__(self, config: Dict[str, Dict[str, Any]]): - super().__init__(config) + # Ensure that we are using thread safe threading level + self.required_threading_level = MPI.THREAD_MULTIPLE + self.threading_level = MPI.Query_thread() + # Make sure to check for MPI_THREAD_MULTIPLE threading level to support + # thread safe calls to send and recv + if self.required_threading_level > self.threading_level: + raise RuntimeError(f"Insufficient thread support. Required: {self.required_threading_level}, Current: {self.threading_level}") + listener_thread = threading.Thread(target=self.listener, daemon=True) listener_thread.start() + send_thread = threading.Thread(target=self.send, args=(data)) + send_thread.start() + self.send_event = threading.Event() + # Ensures that the listener thread and send thread are not using self.request_source at the same time + self.source_node_lock = threading.Lock() self.request_source: int | None = None + def initialize(self): + pass + def listener(self): + """ + Runs on listener thread on each node to receive a send request + Once send request is received, the listener thread informs the main + thread to send the data to the requesting node. + """ while True: status = MPI.Status() - if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status): - source = status.Get_source() - tag = status.Get_tag() - count = status.Get_count(MPI.BYTE) # Get the number of bytes in the message - # If a message is available, receive it - data_to_recv = bytearray(count) - req = self.comm.irecv([data_to_recv, MPI.BYTE], source=source, tag=tag) - req.wait() - # Convert the byte array back to a string - received_message = data_to_recv.decode('utf-8') - - if received_message == "Requesting Information": - self.send_event.set() + # look for message with tag 1 (represents send request) + if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): + with self.source_node_lock: + self.request_source = status.Get_source() - self.send_event.clear() - break + self.comm.irecv(source=self.request_source, tag=1) + self.send_event.set() time.sleep(1) # Simulate waiting time - def send(self, dest: str | int, data: Any, tag: int): + def send(self, data: Any): + """ + Node will wait until request is received and then send + data to requesting node. + """ while True: # Wait until the listener thread detects a request self.send_event.wait() - req = self.comm.isend(data, dest=int(dest), tag=tag) - req.wait() + with self.source_node_lock: + dest = self.request_source - def receive(self, node_ids: str | int, tag: int) -> Any: + if dest is not None: + req = self.comm.isend(data, dest=int(dest)) + req.wait() + + with self.source_node_lock: + self.request_source = None + + self.send_event.clear() + + def receive(self, node_ids: str | int) -> Any: + """ + Node will send a request and wait to receive data. + """ node_ids = int(node_ids) - message = "Requesting Information" - message_bytes = bytearray(message, 'utf-8') - send_req = self.comm.isend([message_bytes, MPI.BYTE], dest=node_ids, tag=tag) + send_req = self.comm.isend("", dest=node_ids, tag=1) send_req.wait() - recv_req = self.comm.irecv(source=node_ids, tag=tag) + recv_req = self.comm.irecv(source=node_ids) return recv_req.wait() - -# MPI Server -""" -initialization(): - node spins up listener thread, threading (an extra thread might not be needed since iprobe exists). - call listen? - -listen(): - listener thread starts listening for send requests (use iprobe and irecv for message) - when send request is received, call the send() function - -send(): - gather and send info to requesting node using comm.isend - comm.wait -""" + # depreciated broadcast function + # def broadcast(self, data: Any): + # for i in range(1, self.size): + # if i != self.rank: + # self.send(i, data) -# MPI Client -""" -initialization(): - node is initialized + def all_gather(self): + """ + This function is used to gather data from all the nodes. + """ + items: List[Any] = [] + for i in range(1, self.size): + items.append(self.receive(i)) + return items -receive(): - node sends request to sending node using isend() - node calls irecv and waits for response -""" + def finalize(self): + pass From d37c35bd40c6817476a076265a9044e1a957131a Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Tue, 22 Oct 2024 10:30:21 -0400 Subject: [PATCH 03/19] improved comments --- src/utils/communication/mpi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 89bd21cd..3ea3e334 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -34,7 +34,7 @@ def initialize(self): def listener(self): """ Runs on listener thread on each node to receive a send request - Once send request is received, the listener thread informs the main + Once send request is received, the listener thread informs the send thread to send the data to the requesting node. """ while True: @@ -50,7 +50,7 @@ def listener(self): def send(self, data: Any): """ - Node will wait until request is received and then send + Node will wait for a request to send data and then send the data to requesting node. """ while True: @@ -70,7 +70,7 @@ def send(self, data: Any): def receive(self, node_ids: str | int) -> Any: """ - Node will send a request and wait to receive data. + Node will send a request for data and wait to receive data. """ node_ids = int(node_ids) send_req = self.comm.isend("", dest=node_ids, tag=1) @@ -78,7 +78,7 @@ def receive(self, node_ids: str | int) -> Any: recv_req = self.comm.irecv(source=node_ids) return recv_req.wait() - # depreciated broadcast function + # deprecated broadcast function # def broadcast(self, data: Any): # for i in range(1, self.size): # if i != self.rank: From 2f087e155129e1a868904f4df3f28ac6a2570e24 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 28 Oct 2024 17:16:10 -0400 Subject: [PATCH 04/19] testing mpi, model weights not acquired --- src/algos/fl.py | 2 +- src/configs/algo_config.py | 3 +- src/configs/sys_config.py | 16 ++--- src/main.py | 1 - src/scheduler.py | 2 - src/utils/communication/comm_utils.py | 4 +- src/utils/communication/mpi.py | 94 ++++++++++++++++++++------- 7 files changed, 85 insertions(+), 37 deletions(-) diff --git a/src/algos/fl.py b/src/algos/fl.py index db805490..98a09a47 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -96,7 +96,7 @@ def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]): num_users = len(model_wts) coeff = 1 / num_users avgd_wts: OrderedDict[str, Tensor] = OrderedDict() - + print(f"model weights: {model_wts}") for key in model_wts[0].keys(): avgd_wts[key] = sum(coeff * m[key] for m in model_wts) # type: ignore diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index 9a4aa764..60662d18 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -328,4 +328,5 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st ] -default_config_list: List[ConfigType] = [traditional_fl] +# default_config_list: List[ConfigType] = [traditional_fl] +default_config_list: List[ConfigType] = [fedstatic, fedstatic, fedstatic, fedstatic] diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 44ae73a0..fb88171f 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -145,11 +145,13 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): CIAR10_DPATH = "./datasets/imgs/cifar10/" NUM_COLLABORATORS = 1 -DUMP_DIR = "/mas/camera/Experiments/SONAR/abhi/" +DUMP_DIR = "/Users/kathryn/MIT/UROP/Media Lab/sonar_experiments/" +num_users = 4 mpi_system_config: ConfigType = { "exp_id": "", "comm": {"type": "MPI"}, + "num_users": num_users, "num_collaborators": NUM_COLLABORATORS, "dset": CIFAR10_DSET, "dump_dir": DUMP_DIR, @@ -159,14 +161,12 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): # The device_ids dictionary depicts the GPUs on which the nodes reside. # For a single-GPU environment, the config will look as follows (as it follows a 0-based indexing): # "device_ids": {"node_0": [0], "node_1": [0], "node_2": [0], "node_3": [0]}, - "device_ids": get_device_ids(num_users=3, gpus_available=[1, 2]), + "device_ids": get_device_ids(num_users=4, gpus_available=[1, 2]), # use this when the list needs to be imported from the algo_config # "algo": get_algo_configs(num_users=3, algo_configs=algo_configs_list), "algos": get_algo_configs( - num_users=3, - algo_configs=malicious_algo_config_list, - assignment_method="distribution", - distribution={0: 1, 1: 1, 2: 1}, + num_users=4, + algo_configs=default_config_list ), # type: ignore "samples_per_user": 1000, # TODO: To model scenarios where different users have different number of samples # we need to make this a dictionary with user_id as key and number of samples as value @@ -342,5 +342,5 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dropout_dicts": dropout_dicts, } -current_config = grpc_system_config -# current_config = mpi_system_config +# current_config = grpc_system_config +current_config = mpi_system_config diff --git a/src/main.py b/src/main.py index 655ac65f..d3a7c11d 100644 --- a/src/main.py +++ b/src/main.py @@ -66,6 +66,5 @@ scheduler.install_config() scheduler.initialize() - # Run the job scheduler.run_job() diff --git a/src/scheduler.py b/src/scheduler.py index 0aec0945..b1d0f7d6 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -104,7 +104,6 @@ def initialize(self, copy_souce_code: bool = True) -> None: random.seed(seed) numpy.random.seed(seed) self.merge_configs() - if self.communication.get_rank() == 0: if copy_souce_code: copy_source_code(self.config) @@ -120,7 +119,6 @@ def initialize(self, copy_souce_code: bool = True) -> None: # from a different machine print("Waiting for 10 seconds for the super node to create directories") time.sleep(10) - self.node = get_node( self.config, rank=self.communication.get_rank(), diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py index fc4bc5df..788b0709 100644 --- a/src/utils/communication/comm_utils.py +++ b/src/utils/communication/comm_utils.py @@ -1,7 +1,7 @@ from enum import Enum from utils.communication.grpc.main import GRPCCommunication from typing import Any, Dict, List, TYPE_CHECKING -# from utils.communication.mpi import MPICommUtils +from utils.communication.mpi import MPICommUtils if TYPE_CHECKING: from algos.base_class import BaseNode @@ -20,7 +20,7 @@ def create_communication( ): comm_type = comm_type if comm_type == CommunicationType.MPI: - raise NotImplementedError("MPI's new version not yet implemented. Please use GRPC. See https://github.com/aidecentralized/sonar/issues/96 for more details.") + return MPICommUtils(config) elif comm_type == CommunicationType.GRPC: return GRPCCommunication(config) elif comm_type == CommunicationType.HTTP: diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 3ea3e334..70a96bf0 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -1,11 +1,16 @@ -from typing import Dict, Any, List +from typing import Dict, Any, List, TYPE_CHECKING from mpi4py import MPI from utils.communication.interface import CommunicationInterface import threading import time +from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model +import random + +if TYPE_CHECKING: + from algos.base_class import BaseNode class MPICommUtils(CommunicationInterface): - def __init__(self, config: Dict[str, Dict[str, Any]], data: Any): + def __init__(self, config: Dict[str, Dict[str, Any]]): self.comm = MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() @@ -18,19 +23,32 @@ def __init__(self, config: Dict[str, Dict[str, Any]], data: Any): if self.required_threading_level > self.threading_level: raise RuntimeError(f"Insufficient thread support. Required: {self.required_threading_level}, Current: {self.threading_level}") - listener_thread = threading.Thread(target=self.listener, daemon=True) - listener_thread.start() - send_thread = threading.Thread(target=self.send, args=(data)) - send_thread.start() - self.send_event = threading.Event() # Ensures that the listener thread and send thread are not using self.request_source at the same time - self.source_node_lock = threading.Lock() + self.lock = threading.Lock() self.request_source: int | None = None + self.is_working = True + self.communication_cost_received: int = 0 + self.communication_cost_sent: int = 0 + + self.base_node: BaseNode | None = None + + listener_thread = threading.Thread(target=self.listener, daemon=True) + listener_thread.start() + def initialize(self): pass + def register_self(self, obj: "BaseNode"): + self.base_node = obj + send_thread = threading.Thread(target=self.send) + send_thread.start() + + def get_comm_cost(self): + with self.lock: + return self.communication_cost_received, self.communication_cost_sent + def listener(self): """ Runs on listener thread on each node to receive a send request @@ -41,14 +59,28 @@ def listener(self): status = MPI.Status() # look for message with tag 1 (represents send request) if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): - with self.source_node_lock: + with self.lock: self.request_source = status.Get_source() self.comm.irecv(source=self.request_source, tag=1) self.send_event.set() time.sleep(1) # Simulate waiting time - def send(self, data: Any): + def get_model(self) -> bytes | None: + print(f"getting model from {self.rank}, {self.base_node}") + if not self.base_node: + raise Exception("Base node not registered") + with self.lock: + if self.is_working: + print("model is working") + model = serialize_model(self.base_node.get_model_weights()) + print(f"model data to be sent: {model}") + else: + assert self.base_node.dropout.dropout_enabled, "Empty models are only supported when Dropout is enabled." + model = None + return model + + def send(self): """ Node will wait for a request to send data and then send the data to requesting node. @@ -56,33 +88,46 @@ def send(self, data: Any): while True: # Wait until the listener thread detects a request self.send_event.wait() - with self.source_node_lock: + with self.lock: dest = self.request_source if dest is not None: + data = self.get_model() req = self.comm.isend(data, dest=int(dest)) req.wait() - with self.source_node_lock: + with self.lock: self.request_source = None self.send_event.clear() - def receive(self, node_ids: str | int) -> Any: + def receive(self, node_ids: List[int]) -> Any: """ Node will send a request for data and wait to receive data. """ - node_ids = int(node_ids) - send_req = self.comm.isend("", dest=node_ids, tag=1) - send_req.wait() - recv_req = self.comm.irecv(source=node_ids) - return recv_req.wait() + max_tries = 10 + for node in node_ids: + while max_tries > 0: + try: + self.comm.send("", dest=node, tag=1) + recv_req = self.comm.irecv(source=node) + received_data = recv_req.wait() + print(f"received data: {received_data}") + return deserialize_model(received_data) + except Exception as e: + print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") + import traceback + print(traceback.print_exc()) + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 # deprecated broadcast function - # def broadcast(self, data: Any): - # for i in range(1, self.size): - # if i != self.rank: - # self.send(i, data) + def broadcast(self, data: Any): + for i in range(1, self.size): + if i != self.rank: + self.comm.send(data, dest=i) def all_gather(self): """ @@ -90,9 +135,14 @@ def all_gather(self): """ items: List[Any] = [] for i in range(1, self.size): + print(f"receiving this data: {self.receive(i)}") items.append(self.receive(i)) return items def finalize(self): pass + def set_is_working(self, is_working: bool): + with self.lock: + self.is_working = is_working + From 464a6748c28e8bd0e2fca6b7a66dea55bfc3b47e Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 3 Nov 2024 14:12:09 -0500 Subject: [PATCH 05/19] mpi works, occassional deadlock issue --- src/utils/communication/comm_utils.py | 1 + src/utils/communication/mpi.py | 144 ++++++++++++++++++++------ 2 files changed, 113 insertions(+), 32 deletions(-) diff --git a/src/utils/communication/comm_utils.py b/src/utils/communication/comm_utils.py index 788b0709..622da42e 100644 --- a/src/utils/communication/comm_utils.py +++ b/src/utils/communication/comm_utils.py @@ -2,6 +2,7 @@ from utils.communication.grpc.main import GRPCCommunication from typing import Any, Dict, List, TYPE_CHECKING from utils.communication.mpi import MPICommUtils +from mpi4py import MPI if TYPE_CHECKING: from algos.base_class import BaseNode diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 70a96bf0..ec2ec4a4 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -1,10 +1,12 @@ +from collections import OrderedDict from typing import Dict, Any, List, TYPE_CHECKING from mpi4py import MPI +from torch import Tensor from utils.communication.interface import CommunicationInterface import threading import time -from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model import random +import numpy as np if TYPE_CHECKING: from algos.base_class import BaseNode @@ -15,6 +17,9 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() + self.num_users: int = int(config["num_users"]) # type: ignore + self.finished = False + # Ensure that we are using thread safe threading level self.required_threading_level = MPI.THREAD_MULTIPLE self.threading_level = MPI.Query_thread() @@ -34,16 +39,17 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): self.base_node: BaseNode | None = None - listener_thread = threading.Thread(target=self.listener, daemon=True) - listener_thread.start() + self.listener_thread = threading.Thread(target=self.listener) + self.listener_thread.start() + + self.send_thread = threading.Thread(target=self.send) def initialize(self): pass def register_self(self, obj: "BaseNode"): self.base_node = obj - send_thread = threading.Thread(target=self.send) - send_thread.start() + self.send_thread.start() def get_comm_cost(self): with self.lock: @@ -55,26 +61,30 @@ def listener(self): Once send request is received, the listener thread informs the send thread to send the data to the requesting node. """ - while True: + while not self.finished: status = MPI.Status() # look for message with tag 1 (represents send request) if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): with self.lock: self.request_source = status.Get_source() - self.comm.irecv(source=self.request_source, tag=1) + print(f"Node {self.rank} received request from {self.request_source}") + # receive_request = self.comm.irecv(source=self.request_source, tag=1) + # receive_request.wait() + self.comm.recv(source=self.request_source, tag=1) self.send_event.set() - time.sleep(1) # Simulate waiting time + # time.sleep(1) + print(f"Node {self.rank} listener thread ended") - def get_model(self) -> bytes | None: + def get_model(self) -> List[OrderedDict[str, Tensor]] | None: print(f"getting model from {self.rank}, {self.base_node}") if not self.base_node: raise Exception("Base node not registered") with self.lock: if self.is_working: - print("model is working") - model = serialize_model(self.base_node.get_model_weights()) - print(f"model data to be sent: {model}") + model = self.base_node.get_model_weights() + model = [model] + print(f"Model from {self.rank} acquired") else: assert self.base_node.dropout.dropout_enabled, "Empty models are only supported when Dropout is enabled." model = None @@ -85,43 +95,62 @@ def send(self): Node will wait for a request to send data and then send the data to requesting node. """ - while True: + while not self.finished: # Wait until the listener thread detects a request self.send_event.wait() + if self.finished: + break with self.lock: dest = self.request_source if dest is not None: data = self.get_model() - req = self.comm.isend(data, dest=int(dest)) - req.wait() + print(f"Node {self.rank} is sending data to {dest}") + # req = self.comm.Isend(data, dest=int(dest)) + # req.wait() + self.comm.send(data, dest=int(dest)) with self.lock: self.request_source = None self.send_event.clear() + print(f"Node {self.rank} send thread ended") def receive(self, node_ids: List[int]) -> Any: """ Node will send a request for data and wait to receive data. """ max_tries = 10 - for node in node_ids: - while max_tries > 0: - try: - self.comm.send("", dest=node, tag=1) - recv_req = self.comm.irecv(source=node) - received_data = recv_req.wait() - print(f"received data: {received_data}") - return deserialize_model(received_data) - except Exception as e: - print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") - import traceback - print(traceback.print_exc()) - # sleep for a random time between 1 and 10 seconds - random_time = random.randint(1, 10) - time.sleep(random_time) - max_tries -= 1 + assert len(node_ids) == 1, "Too many node_ids to unpack" + node = node_ids[0] + while max_tries > 0: + try: + print(f"Node {self.rank} receiving from {node}") + self.comm.send("", dest=node, tag=1) + # recv_req = self.comm.Irecv([], source=node) + # received_data = recv_req.wait() + received_data = self.comm.recv(source=node) + print(f"Node {self.rank} received data from {node}: {bool(received_data)}") + if not received_data: + raise Exception("Received empty data") + return received_data + except MPI.Exception as e: + print(f"MPI failed {10 - max_tries} times: MPI ERROR: {e}", "Retrying...") + import traceback + print(f"Traceback: {traceback.print_exc()}") + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 + except Exception as e: + print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") + import traceback + print(f"Traceback: {traceback.print_exc()}") + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 + print(f"Node {self.rank} received") # deprecated broadcast function def broadcast(self, data: Any): @@ -138,9 +167,60 @@ def all_gather(self): print(f"receiving this data: {self.receive(i)}") items.append(self.receive(i)) return items + + def send_finished(self): + self.comm.send("Finished", dest=0, tag=2) def finalize(self): - pass + # 1. All nodes send finished to the super node + # 2. super node will wait for all nodes to send finished + # 3. super node will then send bye to all nodes + # 4. all nodes will wait for the bye and then exit + # this is to ensure that all nodes have finished + # and no one leaves early + if self.rank == 0: + quorum_threshold = self.num_users - 1 # No +1 for the super node because it doesn't send finished + num_finished: set[int] = set() + status = MPI.Status() + while len(num_finished) < quorum_threshold: + # sleep for 5 seconds + print( + f"Waiting for {quorum_threshold} users to finish, {num_finished} have finished so far" + ) + # time.sleep(5) + # get finished nodes + self.comm.recv(source=MPI.ANY_SOURCE, tag=2, status=status) + print(f"received finish message from {status.Get_source()}") + num_finished.add(status.Get_source()) + + else: + # send finished to the super node + print(f"Node {self.rank} sent finish message") + self.send_finished() + + # problem: do the other nodes wait for super node to receive finish messages? + message = self.comm.bcast("Done", root=0) + self.finished = True + self.send_event.set() + print(f"Node {self.rank} received {message}, finished") + self.comm.Barrier() + self.listener_thread.join() + print(f"Node {self.rank} listener thread done") + if self.send_thread.is_alive(): + self.send_thread.join() + print(f"Node {self.rank} send thread done") + print(f"Node {self.rank} active threads: {threading.active_count()}") + print(f"Node {self.rank} listener thread is {self.listener_thread.is_alive()}") + print(f"Node {self.rank} {threading.enumerate()}") + # for thread in threading.enumerate(): + # if thread != threading.main_thread(): + # thread.join() + print(f"Node {self.rank} send thread is {self.send_thread.is_alive()}") + self.comm.Barrier() + print(f"Node {self.rank}: all nodes synchronized") + MPI.Finalize() + + print("Finalized") def set_is_working(self, is_working: bool): with self.lock: From 71dd9e86e3f87d0dcff40eddb62c08a9880ad613 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Tue, 5 Nov 2024 22:46:26 -0500 Subject: [PATCH 06/19] merged send and listener threads --- src/utils/communication/mpi.py | 54 ++++++++-------------------------- 1 file changed, 13 insertions(+), 41 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index ec2ec4a4..026d6451 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -42,14 +42,11 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): self.listener_thread = threading.Thread(target=self.listener) self.listener_thread.start() - self.send_thread = threading.Thread(target=self.send) - def initialize(self): pass def register_self(self, obj: "BaseNode"): self.base_node = obj - self.send_thread.start() def get_comm_cost(self): with self.lock: @@ -66,14 +63,14 @@ def listener(self): # look for message with tag 1 (represents send request) if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): with self.lock: - self.request_source = status.Get_source() + # self.request_source = status.Get_source() + dest = status.Get_source() print(f"Node {self.rank} received request from {self.request_source}") # receive_request = self.comm.irecv(source=self.request_source, tag=1) # receive_request.wait() - self.comm.recv(source=self.request_source, tag=1) - self.send_event.set() - # time.sleep(1) + self.comm.recv(source=dest, tag=1) + self.send(dest) print(f"Node {self.rank} listener thread ended") def get_model(self) -> List[OrderedDict[str, Tensor]] | None: @@ -90,31 +87,19 @@ def get_model(self) -> List[OrderedDict[str, Tensor]] | None: model = None return model - def send(self): + def send(self, dest: int): """ Node will wait for a request to send data and then send the data to requesting node. """ - while not self.finished: - # Wait until the listener thread detects a request - self.send_event.wait() - if self.finished: - break - with self.lock: - dest = self.request_source - - if dest is not None: - data = self.get_model() - print(f"Node {self.rank} is sending data to {dest}") - # req = self.comm.Isend(data, dest=int(dest)) - # req.wait() - self.comm.send(data, dest=int(dest)) - - with self.lock: - self.request_source = None - - self.send_event.clear() - print(f"Node {self.rank} send thread ended") + if self.finished: + return + + data = self.get_model() + print(f"Node {self.rank} is sending data to {dest}") + # req = self.comm.Isend(data, dest=int(dest)) + # req.wait() + self.comm.send(data, dest=int(dest)) def receive(self, node_ids: List[int]) -> Any: """ @@ -183,11 +168,9 @@ def finalize(self): num_finished: set[int] = set() status = MPI.Status() while len(num_finished) < quorum_threshold: - # sleep for 5 seconds print( f"Waiting for {quorum_threshold} users to finish, {num_finished} have finished so far" ) - # time.sleep(5) # get finished nodes self.comm.recv(source=MPI.ANY_SOURCE, tag=2, status=status) print(f"received finish message from {status.Get_source()}") @@ -198,7 +181,6 @@ def finalize(self): print(f"Node {self.rank} sent finish message") self.send_finished() - # problem: do the other nodes wait for super node to receive finish messages? message = self.comm.bcast("Done", root=0) self.finished = True self.send_event.set() @@ -206,22 +188,12 @@ def finalize(self): self.comm.Barrier() self.listener_thread.join() print(f"Node {self.rank} listener thread done") - if self.send_thread.is_alive(): - self.send_thread.join() - print(f"Node {self.rank} send thread done") - print(f"Node {self.rank} active threads: {threading.active_count()}") print(f"Node {self.rank} listener thread is {self.listener_thread.is_alive()}") print(f"Node {self.rank} {threading.enumerate()}") - # for thread in threading.enumerate(): - # if thread != threading.main_thread(): - # thread.join() - print(f"Node {self.rank} send thread is {self.send_thread.is_alive()}") self.comm.Barrier() print(f"Node {self.rank}: all nodes synchronized") MPI.Finalize() - print("Finalized") - def set_is_working(self, is_working: bool): with self.lock: self.is_working = is_working From 302a0f086a3b2a912f5b7bac060232242a685771 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 2 Dec 2024 12:58:03 -0500 Subject: [PATCH 07/19] added super init to fl_static server --- src/algos/base_class.py | 4 ++-- src/algos/fl_static.py | 2 +- src/configs/algo_config.py | 2 +- src/configs/sys_config.py | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 026ed0db..3fbf936b 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -257,8 +257,8 @@ def set_shared_exp_parameters(self, config: Dict[str, ConfigType]) -> None: ) else: raise ValueError(f"Unknown community type: {community_type}.") - if self.node_id == 0: - self.log_utils.log_console(f"Communities: {self.communities}") + # if self.node_id == 0: + # self.log_utils.log_console(f"Communities: {self.communities}") def local_round_done(self) -> None: self.round += 1 diff --git a/src/algos/fl_static.py b/src/algos/fl_static.py index 59d7baa7..6cb5808f 100644 --- a/src/algos/fl_static.py +++ b/src/algos/fl_static.py @@ -71,7 +71,7 @@ class FedStaticServer(BaseFedAvgClient): def __init__( self, config: Dict[str, Any], comm_utils: CommunicationManager ) -> None: - pass + super().__init__(config, comm_utils) def run_protocol(self) -> None: pass diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index b31928fb..48227336 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -192,7 +192,7 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st # Collaboration setup "algo": "fedstatic", "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore - "rounds": 20, + "rounds": 1, # Model parameters "model": "resnet10", diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 305109e6..f4327461 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -316,7 +316,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "exp_keys": [], } -num_users = 9 +num_users = 3 dropout_dict = { "distribution_dict": { # leave dict empty to disable dropout @@ -346,8 +346,8 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "seed": 2, "device_ids": get_device_ids(num_users, gpu_ids), # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore - "algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore - "samples_per_user": 50000 // num_users, # distributed equally + "algos": get_algo_configs(num_users=num_users, algo_configs=[traditional_fl]), # type: ignore + "samples_per_user": 100, # distributed equally "train_label_distribution": "non_iid", "test_label_distribution": "iid", "alpha_data": 1.0, From 627a60ee4bd59c8c12651a0d775fa6ca14c61d7f Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 2 Dec 2024 17:52:40 -0500 Subject: [PATCH 08/19] logging dataset loading --- src/algos/base_class.py | 1 + src/configs/sys_config.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 3fbf936b..0b9ba2c8 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -113,6 +113,7 @@ def __init__( self.setup_cuda(config) self.model_utils = ModelUtils(self.device, config) + print("getting dataset!!") self.dset_obj = get_dataset(self.dset, dpath=config["dpath"]) dropout_seed = 1 * config.get("num_users", 9) + self.node_id * config.get("num_users", 9) + config.get("seed", 20) # arbitrarily chosen diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index f4327461..e9df421f 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -346,7 +346,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "seed": 2, "device_ids": get_device_ids(num_users, gpu_ids), # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore - "algos": get_algo_configs(num_users=num_users, algo_configs=[traditional_fl]), # type: ignore + "algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore "samples_per_user": 100, # distributed equally "train_label_distribution": "non_iid", "test_label_distribution": "iid", From a7b2adc25051bd804a573002f4227fb8772b2f26 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 4 Dec 2024 16:59:59 -0500 Subject: [PATCH 09/19] reduced test size during workflow testing --- .github/workflows/train.yml | 8 +++----- src/algos/base_class.py | 8 ++++++++ src/configs/algo_config.py | 1 + src/configs/algo_config_test.py | 1 + src/configs/sys_config.py | 3 ++- src/configs/sys_config_test.py | 5 +++-- src/main_grpc.py | 9 +++++++++ 7 files changed, 27 insertions(+), 8 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 706211e8..b40d4cb3 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -39,18 +39,16 @@ jobs: sudo apt install -y libopenmpi-dev openmpi-bin sudo apt-get install -y libgl1 libglib2.0-0 - pip install -r requirements_cpu.txt + pip install -r requirements.txt # Step 4: Run gRPC server and client - name: Run test run: | cd src - # chmod +x ./configs/algo_config_test.py - echo "starting main grpc" - python main_grpc.py -n 4 -host localhost + python main_grpc.py -n 4 -host localhost -dev True echo "starting main" - python main.py -super true -s "./configs/sys_config_test.py" + python main.py -b "./configs/algo_config_test.py" -s "./configs/sys_config_test.py" -super true echo "done" # further checks: diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 2dfdf1da..c00be6f5 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -532,6 +532,13 @@ def set_data_parameters(self, config: ConfigType) -> None: if config.get("test_samples_per_class", None) is not None: test_dset, _ = balanced_subset(test_dset, config["test_samples_per_class"]) + #reduce test_dset size + if config.get("workflow_test", False): + print("Workflow testing: Reducing test size...") + reduced_test_size = 1000 + indices = np.random.choice(len(test_dset), reduced_test_size, replace=False) + test_dset = Subset(test_dset, indices) + samples_per_user = config["samples_per_user"] batch_size: int = config["batch_size"] # type: ignore print(f"samples per user: {samples_per_user}, batch size: {batch_size}") @@ -686,6 +693,7 @@ def is_same_dest(dset): if self.dset.startswith("domainnet"): test_dset = CacheDataset(test_dset) + print(f"test_dset size: {len(test_dset)}") self._test_loader = DataLoader(test_dset, batch_size=batch_size) # TODO: fix print_data_summary # self.print_data_summary(train_dset, test_dset, val_dset=val_dset) diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index a2242c62..c33de567 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -37,6 +37,7 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st "model": "resnet10", "model_lr": 3e-4, "batch_size": 256, + "workflow_test": False, } test_fl_inversion: ConfigType = { diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py index 2f2c7fcf..c1614d24 100644 --- a/src/configs/algo_config_test.py +++ b/src/configs/algo_config_test.py @@ -21,6 +21,7 @@ "model": "resnet10", "model_lr": 3e-4, "batch_size": 256, + "workflow_test": True, } # default_config_list: List[ConfigType] = [fedstatic, fedstatic, fedstatic, fedstatic] \ No newline at end of file diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 2e7e0438..636ecf50 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -365,7 +365,8 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "matlaber12": [0, 1, 2, 3], "matlaber3": [0, 1, 2, 3], "matlaber4": [0, 2, 3, 4, 5, 6, 7], - } + }, + "workflow_test": False, } grpc_system_config_gia: ConfigType = { diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index 4737d16c..b17af96a 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -120,7 +120,7 @@ def get_algo_configs( "alpha_data": 1.0, "exp_keys": [], "dropout_dicts": dropout_dicts, - "test_samples_per_user": 200, + "test_samples_per_user": 100, "log_memory": True, # "streaming_aggregation": True, # Make it true for fedstatic "assign_based_on_host": True, @@ -129,6 +129,7 @@ def get_algo_configs( "matlaber12": [0, 1, 2, 3], "matlaber3": [0, 1, 2, 3], "matlaber4": [0, 2, 3, 4, 5, 6, 7], - } + }, + "workflow_test": True, } current_config = grpc_system_config \ No newline at end of file diff --git a/src/main_grpc.py b/src/main_grpc.py index 22ab1277..e3678431 100644 --- a/src/main_grpc.py +++ b/src/main_grpc.py @@ -23,10 +23,19 @@ help=f"host address of the nodes", ) +parser.add_argument( + "-dev", + nargs="?", + type=bool, + help=f"whether or not development testing", +) + args : argparse.Namespace = parser.parse_args() # Command for opening each process command_list: List[str] = ["python", "main.py", "-host", args.host] +if args.dev == True: + command_list: List[str] = ["python", "main.py", "-b", "./configs/algo_config_test.py", "-s", "./configs/sys_config_test.py", "-host", args.host] # Start process for each user for i in range(args.n): From 06c9d430a5f8c096379c743f4c655ab4647c95a9 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 4 Dec 2024 17:00:28 -0500 Subject: [PATCH 10/19] workflow debugging --- .github/workflows/train.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index b40d4cb3..14f3d5ec 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -2,7 +2,7 @@ name: Test Training Code with gRPC on: # used for debugging purposes - # workflow_dispatch: + workflow_dispatch: push: branches: # run test on push to main only From 3784b69b40ca1e11234c80fbf9ecded56d6e353a Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 4 Dec 2024 17:02:14 -0500 Subject: [PATCH 11/19] workflow debugging --- .github/workflows/train.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 14f3d5ec..be8cfbce 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -6,8 +6,8 @@ on: push: branches: # run test on push to main only - - main - # - "*" + # - main + - "*" pull_request: branches: - main From 002b1f0c5ab1663d2c614dbffab2f0424d3863d5 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 4 Dec 2024 17:22:38 -0500 Subject: [PATCH 12/19] workflow run on push to main only --- .github/workflows/train.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index be8cfbce..b40d4cb3 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -2,12 +2,12 @@ name: Test Training Code with gRPC on: # used for debugging purposes - workflow_dispatch: + # workflow_dispatch: push: branches: # run test on push to main only - # - main - - "*" + - main + # - "*" pull_request: branches: - main From f9ad68a153aadcb6e2f51fe198ada84c655aed6b Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 4 Dec 2024 17:23:07 -0500 Subject: [PATCH 13/19] using requirements cpu --- .github/workflows/train.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index b40d4cb3..d5364182 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -39,7 +39,7 @@ jobs: sudo apt install -y libopenmpi-dev openmpi-bin sudo apt-get install -y libgl1 libglib2.0-0 - pip install -r requirements.txt + pip install -r requirements_cpu.txt # Step 4: Run gRPC server and client - name: Run test From e24a2f864aa487b30c062539cac4a91674a4244c Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Thu, 5 Dec 2024 13:25:53 -0500 Subject: [PATCH 14/19] using test_samples_per_user to reduce test set --- src/algos/base_class.py | 14 +++++++------- src/configs/algo_config.py | 1 - src/configs/algo_config_test.py | 1 - src/configs/sys_config.py | 1 - src/configs/sys_config_test.py | 3 +-- 5 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/algos/base_class.py b/src/algos/base_class.py index c00be6f5..1df8b8d0 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -532,13 +532,6 @@ def set_data_parameters(self, config: ConfigType) -> None: if config.get("test_samples_per_class", None) is not None: test_dset, _ = balanced_subset(test_dset, config["test_samples_per_class"]) - #reduce test_dset size - if config.get("workflow_test", False): - print("Workflow testing: Reducing test size...") - reduced_test_size = 1000 - indices = np.random.choice(len(test_dset), reduced_test_size, replace=False) - test_dset = Subset(test_dset, indices) - samples_per_user = config["samples_per_user"] batch_size: int = config["batch_size"] # type: ignore print(f"samples per user: {samples_per_user}, batch size: {batch_size}") @@ -693,7 +686,14 @@ def is_same_dest(dset): if self.dset.startswith("domainnet"): test_dset = CacheDataset(test_dset) + # reduce test_dset size + if config.get("test_samples_per_user", 0) != 0: + print(f"Reducing test size to {config.get('test_samples_per_user', 0)}") + reduced_test_size = config.get("test_samples_per_user", 0) + indices = np.random.choice(len(test_dset), reduced_test_size, replace=False) + test_dset = Subset(test_dset, indices) print(f"test_dset size: {len(test_dset)}") + self._test_loader = DataLoader(test_dset, batch_size=batch_size) # TODO: fix print_data_summary # self.print_data_summary(train_dset, test_dset, val_dset=val_dset) diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index c33de567..a2242c62 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -37,7 +37,6 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st "model": "resnet10", "model_lr": 3e-4, "batch_size": 256, - "workflow_test": False, } test_fl_inversion: ConfigType = { diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py index c1614d24..2f2c7fcf 100644 --- a/src/configs/algo_config_test.py +++ b/src/configs/algo_config_test.py @@ -21,7 +21,6 @@ "model": "resnet10", "model_lr": 3e-4, "batch_size": 256, - "workflow_test": True, } # default_config_list: List[ConfigType] = [fedstatic, fedstatic, fedstatic, fedstatic] \ No newline at end of file diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 636ecf50..0b0a484b 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -366,7 +366,6 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "matlaber3": [0, 1, 2, 3], "matlaber4": [0, 2, 3, 4, 5, 6, 7], }, - "workflow_test": False, } grpc_system_config_gia: ConfigType = { diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index b17af96a..b443ff19 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -120,7 +120,7 @@ def get_algo_configs( "alpha_data": 1.0, "exp_keys": [], "dropout_dicts": dropout_dicts, - "test_samples_per_user": 100, + "test_samples_per_user": 1000, "log_memory": True, # "streaming_aggregation": True, # Make it true for fedstatic "assign_based_on_host": True, @@ -130,6 +130,5 @@ def get_algo_configs( "matlaber3": [0, 1, 2, 3], "matlaber4": [0, 2, 3, 4, 5, 6, 7], }, - "workflow_test": True, } current_config = grpc_system_config \ No newline at end of file From c05efb885e3b3dbbbe65ba8c0cbea573ecbba1a3 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Thu, 5 Dec 2024 13:54:31 -0500 Subject: [PATCH 15/19] download data in server --- src/algos/base_class.py | 1 - src/algos/fl_static.py | 15 ++++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/algos/base_class.py b/src/algos/base_class.py index 0b9ba2c8..3fbf936b 100644 --- a/src/algos/base_class.py +++ b/src/algos/base_class.py @@ -113,7 +113,6 @@ def __init__( self.setup_cuda(config) self.model_utils = ModelUtils(self.device, config) - print("getting dataset!!") self.dset_obj = get_dataset(self.dset, dpath=config["dpath"]) dropout_seed = 1 * config.get("num_users", 9) + self.node_id * config.get("num_users", 9) + config.get("seed", 20) # arbitrarily chosen diff --git a/src/algos/fl_static.py b/src/algos/fl_static.py index 6cb5808f..e2d4812d 100644 --- a/src/algos/fl_static.py +++ b/src/algos/fl_static.py @@ -7,6 +7,7 @@ from algos.base_class import BaseFedAvgClient from algos.topologies.collections import select_topology +from utils.data_utils import get_dataset class FedStaticNode(BaseFedAvgClient): @@ -71,7 +72,19 @@ class FedStaticServer(BaseFedAvgClient): def __init__( self, config: Dict[str, Any], comm_utils: CommunicationManager ) -> None: - super().__init__(config, comm_utils) + self.comm_utils = comm_utils + self.node_id = self.comm_utils.get_rank() + self.comm_utils.register_node(self) + self.is_working = True + if isinstance(config["dset"], dict): + if self.node_id != 0: + config["dset"].pop("0") # type: ignore + self.dset = str(config["dset"][str(self.node_id)]) # type: ignore + config["dpath"] = config["dpath"][self.dset] + else: + self.dset = config["dset"] + print(f"Node {self.node_id} getting dset at {self.dset}") + self.dset_obj = get_dataset(self.dset, dpath=config["dpath"]) def run_protocol(self) -> None: pass From 668bb42244a5b95f26b29c5eeb9dad79a621efa4 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Thu, 5 Dec 2024 14:07:40 -0500 Subject: [PATCH 16/19] fedstatic works with testing --- src/configs/algo_config_test.py | 20 ++++++++++---------- src/configs/sys_config_test.py | 5 +++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py index 2f2c7fcf..bce2b882 100644 --- a/src/configs/algo_config_test.py +++ b/src/configs/algo_config_test.py @@ -1,16 +1,16 @@ from utils.types import ConfigType -# fedstatic: ConfigType = { -# # Collaboration setup -# "algo": "fedstatic", -# "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore -# "rounds": 1, +fedstatic: ConfigType = { + # Collaboration setup + "algo": "fedstatic", + "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore + "rounds": 1, -# # Model parameters -# "model": "resnet10", -# "model_lr": 3e-4, -# "batch_size": 256, -# } + # Model parameters + "model": "resnet10", + "model_lr": 3e-4, + "batch_size": 256, +} traditional_fl: ConfigType = { # Collaboration setup diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index b443ff19..271497f1 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -3,7 +3,8 @@ from utils.types import ConfigType from .algo_config_test import ( - traditional_fl + traditional_fl, + fedstatic ) def get_device_ids(num_users: int, gpus_available: List[int | Literal["cpu"]]) -> Dict[str, List[int | Literal["cpu"]]]: @@ -112,7 +113,7 @@ def get_algo_configs( "seed": 2, "device_ids": get_device_ids(num_users, gpu_ids), # "algos": get_algo_configs(num_users=num_users, algo_configs=default_config_list), # type: ignore - "algos": get_algo_configs(num_users=num_users, algo_configs=[traditional_fl]), # type: ignore + "algos": get_algo_configs(num_users=num_users, algo_configs=[fedstatic]), # type: ignore # "samples_per_user": 50000 // num_users, # distributed equally "samples_per_user": 100, "train_label_distribution": "non_iid", From 0720cda526aff249ef21999243e63b53d78bfe1f Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sat, 7 Dec 2024 14:16:39 -0500 Subject: [PATCH 17/19] change dump_dir --- src/configs/sys_config.py | 2 +- src/configs/sys_config_test.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index d34269e2..2f448509 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -158,7 +158,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): CIAR10_DPATH = "./datasets/imgs/cifar10/" NUM_COLLABORATORS = 1 -DUMP_DIR = "/Users/kathryn/MIT/UROP/Media Lab/sonar_experiments/" +DUMP_DIR = "/tmp/" num_users = 4 diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index 271497f1..b3504360 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -81,7 +81,6 @@ def get_algo_configs( CIFAR10_DSET = "cifar10" CIAR10_DPATH = "./datasets/imgs/cifar10/" -# DUMP_DIR = "../../../../../../../home/" DUMP_DIR = "/tmp/" NUM_COLLABORATORS = 1 From d4376d89418f7754a92509b84c814b972bf06728 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 9 Dec 2024 11:43:46 -0500 Subject: [PATCH 18/19] code cleanup --- src/configs/algo_config.py | 2 +- src/configs/sys_config.py | 10 ++++------ src/configs/sys_config_test.py | 4 ++-- src/utils/communication/mpi.py | 3 --- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index 8aaba8da..a2242c62 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -204,7 +204,7 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st # Collaboration setup "algo": "fedstatic", "topology": {"name": "watts_strogatz", "k": 3, "p": 0.2}, # type: ignore - "rounds": 1, + "rounds": 3, # Model parameters "optimizer": "sgd", # TODO comment out for real training "model": "resnet10", diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 2f448509..19a9873c 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -160,8 +160,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): NUM_COLLABORATORS = 1 DUMP_DIR = "/tmp/" -num_users = 4 - +num_users = 3 mpi_system_config: ConfigType = { "exp_id": "", "comm": {"type": "MPI"}, @@ -175,11 +174,11 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): # The device_ids dictionary depicts the GPUs on which the nodes reside. # For a single-GPU environment, the config will look as follows (as it follows a 0-based indexing): # "device_ids": {"node_0": [0], "node_1": [0], "node_2": [0], "node_3": [0]}, - "device_ids": get_device_ids(num_users=4, gpus_available=[1, 2]), + "device_ids": get_device_ids(num_users=3, gpus_available=[1, 2]), # use this when the list needs to be imported from the algo_config # "algo": get_algo_configs(num_users=3, algo_configs=algo_configs_list), "algos": get_algo_configs( - num_users=4, + num_users=3, algo_configs=default_config_list ), # type: ignore "samples_per_user": 5555, # TODO: To model scenarios where different users have different number of samples @@ -365,7 +364,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "matlaber12": [0, 1, 2, 3], "matlaber3": [0, 1, 2, 3], "matlaber4": [0, 2, 3, 4, 5, 6, 7], - }, + } } grpc_system_config_gia: ConfigType = { @@ -388,6 +387,5 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "gia":True, "gia_attackers":[1] } - current_config = grpc_system_config # current_config = mpi_system_config diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index b3504360..6e692146 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -120,7 +120,7 @@ def get_algo_configs( "alpha_data": 1.0, "exp_keys": [], "dropout_dicts": dropout_dicts, - "test_samples_per_user": 1000, + "test_samples_per_user": 200, "log_memory": True, # "streaming_aggregation": True, # Make it true for fedstatic "assign_based_on_host": True, @@ -129,6 +129,6 @@ def get_algo_configs( "matlaber12": [0, 1, 2, 3], "matlaber3": [0, 1, 2, 3], "matlaber4": [0, 2, 3, 4, 5, 6, 7], - }, + } } current_config = grpc_system_config \ No newline at end of file diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 833f9b04..05c28d60 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -48,7 +48,6 @@ def initialize(self): def send_quorum(self) -> Any: # return super().send_quorum(node_ids) pass - def register_self(self, obj: "BaseNode"): self.base_node = obj @@ -76,7 +75,6 @@ def listener(self): self.comm.recv(source=dest, tag=1) self.send(dest) print(f"Node {self.rank} listener thread ended") - def get_model(self) -> List[OrderedDict[str, Tensor]] | None: print(f"getting model from {self.rank}, {self.base_node}") if not self.base_node: @@ -153,7 +151,6 @@ def all_gather(self): """ items: List[Any] = [] for i in range(1, self.size): - print(f"receiving this data: {self.receive(i)}") items.append(self.receive(i)) return items From 76758f738c32a576f5ad18965e0a35f2b613ac18 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 9 Dec 2024 11:45:11 -0500 Subject: [PATCH 19/19] code cleanup --- src/configs/sys_config.py | 1 + src/utils/communication/mpi.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 19a9873c..738f4c93 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -387,5 +387,6 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "gia":True, "gia_attackers":[1] } + current_config = grpc_system_config # current_config = mpi_system_config diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 05c28d60..fb729073 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -48,6 +48,7 @@ def initialize(self): def send_quorum(self) -> Any: # return super().send_quorum(node_ids) pass + def register_self(self, obj: "BaseNode"): self.base_node = obj @@ -75,6 +76,7 @@ def listener(self): self.comm.recv(source=dest, tag=1) self.send(dest) print(f"Node {self.rank} listener thread ended") + def get_model(self) -> List[OrderedDict[str, Tensor]] | None: print(f"getting model from {self.rank}, {self.base_node}") if not self.base_node: