# NOTE(review): SOURCE was a whitespace-collapsed unified diff of
# src/utils/communication/mpi.py.  This is the post-patch side of that diff,
# re-formatted, with the defects described in the class docstring fixed.
# The hunk starts at file line 3, so the original imports of ``MPI`` and the
# ``typing`` names were not visible; they are restated here so the block is
# self-contained -- confirm against the full file.
import queue
import threading
import time
from typing import Any, Dict, List

from mpi4py import MPI

from utils.communication.interface import CommunicationInterface


class MPICommUtils(CommunicationInterface):
    """Pull-based data exchange between MPI ranks.

    Every rank runs two background threads:

    * the *listener* polls for request messages (tag ``REQUEST_TAG``) and
      queues the rank that asked;
    * the *sender* pops queued ranks and replies with ``data`` (tag
      ``DATA_TAG``).

    A rank obtains another rank's data by calling :meth:`receive`.

    Fixes relative to the patched original:

    * the sender thread was created with ``args=(data)`` (not a tuple);
    * both threads were started before the state they use was created;
    * requests were tracked in a single slot guarded by an event, so a
      request arriving while another was being served could be lost -- a
      queue now holds every pending request;
    * the listener posted ``irecv`` for the request without ever completing
      it -- the request is now consumed with a blocking ``recv``;
    * :meth:`receive` accepted any tag, so it could consume a peer's
      *request* instead of its *reply* -- it now matches ``DATA_TAG`` only;
    * the sender thread was non-daemon and never returns, which kept the
      process alive forever -- it is now a daemon like the listener.
    """

    # Tag of the (empty) "please send me your data" message.
    REQUEST_TAG = 1
    # Tag of the reply.  0 is the default tag of ``send``/``isend``, so the
    # wire format is the same as before.
    DATA_TAG = 0
    # Seconds the listener sleeps when no request is pending.
    POLL_INTERVAL = 1.0

    def __init__(self, config: Dict[str, Dict[str, Any]], data: Any):
        """Set up the communicator and start the listener/sender threads.

        Args:
            config: Communication configuration (not read by this class).
            data: Object the sender thread hands to every requesting rank.

        Raises:
            RuntimeError: If MPI was not initialised with
                ``MPI.THREAD_MULTIPLE`` support.
        """
        self.comm = MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()

        # Several threads issue MPI calls concurrently, which is only legal
        # at the MPI_THREAD_MULTIPLE level.
        self.required_threading_level = MPI.THREAD_MULTIPLE
        self.threading_level = MPI.Query_thread()
        if self.required_threading_level > self.threading_level:
            raise RuntimeError(
                "Insufficient thread support. "
                f"Required: {self.required_threading_level}, "
                f"Current: {self.threading_level}"
            )

        # Ranks whose requests have been received but not yet answered.
        # Must exist before either thread starts, since both use it.
        self._pending_requests: "queue.Queue[int]" = queue.Queue()

        listener_thread = threading.Thread(target=self.listener, daemon=True)
        # ``args`` must be a tuple: ``(data)`` is just ``data``.
        send_thread = threading.Thread(
            target=self.send, args=(data,), daemon=True
        )
        listener_thread.start()
        send_thread.start()

    def initialize(self):
        """No-op; all set-up happens in ``__init__``."""
        pass

    def listener(self):
        """Receive send requests and queue them for the sender thread.

        Runs forever on the listener thread.  Each request is consumed with
        a blocking ``recv`` (the probe guarantees it is already available)
        and its source rank is queued, so no request is dropped even when
        several arrive back to back.
        """
        status = MPI.Status()
        while True:
            request_waiting = self.comm.Iprobe(
                source=MPI.ANY_SOURCE, tag=self.REQUEST_TAG, status=status
            )
            if not request_waiting:
                # Only back off when idle, so queued requests drain promptly.
                time.sleep(self.POLL_INTERVAL)
                continue

            requester = status.Get_source()
            # Consume the request so the next probe does not see it again.
            self.comm.recv(source=requester, tag=self.REQUEST_TAG)
            self._pending_requests.put(requester)

    def send(self, data: Any):
        """Answer every queued request with ``data``.

        Runs forever on the sender thread: blocks until the listener queues
        a requesting rank, then sends ``data`` to that rank.

        Args:
            data: Object to send to each requesting rank.
        """
        while True:
            dest = self._pending_requests.get()
            self.comm.send(data, dest=dest, tag=self.DATA_TAG)

    def receive(self, node_ids: str | int) -> Any:
        """Request data from one rank and block until it arrives.

        Args:
            node_ids: Rank to request data from (anything ``int()`` accepts).

        Returns:
            The object sent back by the target rank's sender thread.
        """
        source = int(node_ids)
        self.comm.send("", dest=source, tag=self.REQUEST_TAG)
        # Match the reply tag explicitly; with ANY_TAG this call could pick
        # up a *request* that ``source`` sent to us at the same time.
        # NOTE(review): blocking recv is also expected to size its buffer
        # from the incoming message, unlike irecv's fixed default buffer --
        # confirm against the installed mpi4py version.
        return self.comm.recv(source=source, tag=self.DATA_TAG)

    # deprecated broadcast function
    # def broadcast(self, data: Any):
    #     for i in range(1, self.size):
    #         if i != self.rank:
    #             self.send(i, data)

    def all_gather(self) -> List[Any]:
        """Collect the data of ranks ``1 .. size - 1``, in rank order.

        Returns:
            One received object per rank, starting at rank 1.
        """
        return [self.receive(node) for node in range(1, self.size)]

    def finalize(self):
        """No-op; the background threads are daemons and end with the process."""
        pass