Skip to content

Commit

Permalink
added send thread, merged 2 classes
Browse files Browse the repository at this point in the history
  • Loading branch information
kathrynle20 committed Oct 22, 2024
1 parent f0192d4 commit 755fc07
Showing 1 changed file with 65 additions and 79 deletions.
144 changes: 65 additions & 79 deletions src/utils/communication/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,110 +3,96 @@
from utils.communication.interface import CommunicationInterface
import threading
import time
from enum import Enum


class MPICommUtils(CommunicationInterface):
def __init__(self, config: Dict[str, Dict[str, Any]]):
def __init__(self, config: Dict[str, Dict[str, Any]], data: Any):
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()

def initialize(self):
pass

# def send(self, dest: str | int, data: Any):
# self.comm.send(data, dest=int(dest))

# def receive(self, node_ids: str | int) -> Any:
# return self.comm.recv(source=int(node_ids))

def broadcast(self, data: Any):
for i in range(1, self.size):
if i != self.rank:
self.send(i, data)

def all_gather(self):
"""
This function is used to gather data from all the nodes.
"""
items: List[Any] = []
for i in range(1, self.size):
items.append(self.receive(i))
return items

def finalize(self):
pass


class MPICommunication(MPICommUtils):
def __init__(self, config: Dict[str, Dict[str, Any]]):
super().__init__(config)
# Ensure that we are using thread safe threading level
self.required_threading_level = MPI.THREAD_MULTIPLE
self.threading_level = MPI.Query_thread()
# Make sure to check for MPI_THREAD_MULTIPLE threading level to support
# thread safe calls to send and recv
if self.required_threading_level > self.threading_level:
raise RuntimeError(f"Insufficient thread support. Required: {self.required_threading_level}, Current: {self.threading_level}")

listener_thread = threading.Thread(target=self.listener, daemon=True)
listener_thread.start()
send_thread = threading.Thread(target=self.send, args=(data))
send_thread.start()

self.send_event = threading.Event()
# Ensures that the listener thread and send thread are not using self.request_source at the same time
self.source_node_lock = threading.Lock()
self.request_source: int | None = None

def initialize(self):
pass

def listener(self):
"""
Runs on listener thread on each node to receive a send request
Once send request is received, the listener thread informs the main
thread to send the data to the requesting node.
"""
while True:
status = MPI.Status()
if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status):
source = status.Get_source()
tag = status.Get_tag()
count = status.Get_count(MPI.BYTE) # Get the number of bytes in the message
# If a message is available, receive it
data_to_recv = bytearray(count)
req = self.comm.irecv([data_to_recv, MPI.BYTE], source=source, tag=tag)
req.wait()
# Convert the byte array back to a string
received_message = data_to_recv.decode('utf-8')

if received_message == "Requesting Information":
self.send_event.set()
# look for message with tag 1 (represents send request)
if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status):
with self.source_node_lock:
self.request_source = status.Get_source()

self.send_event.clear()
break
self.comm.irecv(source=self.request_source, tag=1)
self.send_event.set()
time.sleep(1) # Simulate waiting time

def send(self, dest: str | int, data: Any, tag: int):
def send(self, data: Any):
"""
Node will wait until request is received and then send
data to requesting node.
"""
while True:
# Wait until the listener thread detects a request
self.send_event.wait()
req = self.comm.isend(data, dest=int(dest), tag=tag)
req.wait()
with self.source_node_lock:
dest = self.request_source

def receive(self, node_ids: str | int, tag: int) -> Any:
if dest is not None:
req = self.comm.isend(data, dest=int(dest))
req.wait()

with self.source_node_lock:
self.request_source = None

self.send_event.clear()

def receive(self, node_ids: str | int) -> Any:
"""
Node will send a request and wait to receive data.
"""
node_ids = int(node_ids)
message = "Requesting Information"
message_bytes = bytearray(message, 'utf-8')
send_req = self.comm.isend([message_bytes, MPI.BYTE], dest=node_ids, tag=tag)
send_req = self.comm.isend("", dest=node_ids, tag=1)
send_req.wait()
recv_req = self.comm.irecv(source=node_ids, tag=tag)
recv_req = self.comm.irecv(source=node_ids)
return recv_req.wait()

# MPI Server
"""
initialization():
node spins up a listener thread via threading (an extra thread might not be needed since iprobe exists).
call listen?
listen():
listener thread starts listening for send requests (use iprobe and irecv for message)
when send request is received, call the send() function
send():
gather and send info to requesting node using comm.isend
comm.wait

"""
# deprecated broadcast function
# def broadcast(self, data: Any):
# for i in range(1, self.size):
# if i != self.rank:
# self.send(i, data)

# MPI Client
"""
initialization():
node is initialized
def all_gather(self):
"""
This function is used to gather data from all the nodes.
"""
items: List[Any] = []
for i in range(1, self.size):
items.append(self.receive(i))
return items

receive():
node sends request to sending node using isend()
node calls irecv and waits for response
"""
def finalize(self):
pass

0 comments on commit 755fc07

Please sign in to comment.