Skip to content

Commit

Permalink
mpi works, occassional deadlock issue
Browse files Browse the repository at this point in the history
  • Loading branch information
kathrynle20 committed Nov 3, 2024
1 parent 2f087e1 commit 464a674
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 32 deletions.
1 change: 1 addition & 0 deletions src/utils/communication/comm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from utils.communication.grpc.main import GRPCCommunication
from typing import Any, Dict, List, TYPE_CHECKING
from utils.communication.mpi import MPICommUtils
from mpi4py import MPI

if TYPE_CHECKING:
from algos.base_class import BaseNode
Expand Down
144 changes: 112 additions & 32 deletions src/utils/communication/mpi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections import OrderedDict
from typing import Dict, Any, List, TYPE_CHECKING
from mpi4py import MPI
from torch import Tensor
from utils.communication.interface import CommunicationInterface
import threading
import time
from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model
import random
import numpy as np

if TYPE_CHECKING:
from algos.base_class import BaseNode
Expand All @@ -15,6 +17,9 @@ def __init__(self, config: Dict[str, Dict[str, Any]]):
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()

self.num_users: int = int(config["num_users"]) # type: ignore
self.finished = False

# Ensure that we are using thread safe threading level
self.required_threading_level = MPI.THREAD_MULTIPLE
self.threading_level = MPI.Query_thread()
Expand All @@ -34,16 +39,17 @@ def __init__(self, config: Dict[str, Dict[str, Any]]):

self.base_node: BaseNode | None = None

listener_thread = threading.Thread(target=self.listener, daemon=True)
listener_thread.start()
self.listener_thread = threading.Thread(target=self.listener)
self.listener_thread.start()

self.send_thread = threading.Thread(target=self.send)

def initialize(self):
pass

def register_self(self, obj: "BaseNode"):
self.base_node = obj
send_thread = threading.Thread(target=self.send)
send_thread.start()
self.send_thread.start()

def get_comm_cost(self):
with self.lock:
Expand All @@ -55,26 +61,30 @@ def listener(self):
Once send request is received, the listener thread informs the send
thread to send the data to the requesting node.
"""
while True:
while not self.finished:
status = MPI.Status()
# look for message with tag 1 (represents send request)
if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status):
with self.lock:
self.request_source = status.Get_source()

self.comm.irecv(source=self.request_source, tag=1)
print(f"Node {self.rank} received request from {self.request_source}")
# receive_request = self.comm.irecv(source=self.request_source, tag=1)
# receive_request.wait()
self.comm.recv(source=self.request_source, tag=1)
self.send_event.set()
time.sleep(1) # Simulate waiting time
# time.sleep(1)
print(f"Node {self.rank} listener thread ended")

def get_model(self) -> bytes | None:
def get_model(self) -> List[OrderedDict[str, Tensor]] | None:
print(f"getting model from {self.rank}, {self.base_node}")
if not self.base_node:
raise Exception("Base node not registered")
with self.lock:
if self.is_working:
print("model is working")
model = serialize_model(self.base_node.get_model_weights())
print(f"model data to be sent: {model}")
model = self.base_node.get_model_weights()
model = [model]
print(f"Model from {self.rank} acquired")
else:
assert self.base_node.dropout.dropout_enabled, "Empty models are only supported when Dropout is enabled."
model = None
Expand All @@ -85,43 +95,62 @@ def send(self):
Node will wait for a request to send data and then send the
data to requesting node.
"""
while True:
while not self.finished:
# Wait until the listener thread detects a request
self.send_event.wait()
if self.finished:
break
with self.lock:
dest = self.request_source

if dest is not None:
data = self.get_model()
req = self.comm.isend(data, dest=int(dest))
req.wait()
print(f"Node {self.rank} is sending data to {dest}")
# req = self.comm.Isend(data, dest=int(dest))
# req.wait()
self.comm.send(data, dest=int(dest))

with self.lock:
self.request_source = None

self.send_event.clear()
print(f"Node {self.rank} send thread ended")

def receive(self, node_ids: List[int]) -> Any:
"""
Node will send a request for data and wait to receive data.
"""
max_tries = 10
for node in node_ids:
while max_tries > 0:
try:
self.comm.send("", dest=node, tag=1)
recv_req = self.comm.irecv(source=node)
received_data = recv_req.wait()
print(f"received data: {received_data}")
return deserialize_model(received_data)
except Exception as e:
print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...")
import traceback
print(traceback.print_exc())
# sleep for a random time between 1 and 10 seconds
random_time = random.randint(1, 10)
time.sleep(random_time)
max_tries -= 1
assert len(node_ids) == 1, "Too many node_ids to unpack"
node = node_ids[0]
while max_tries > 0:
try:
print(f"Node {self.rank} receiving from {node}")
self.comm.send("", dest=node, tag=1)
# recv_req = self.comm.Irecv([], source=node)
# received_data = recv_req.wait()
received_data = self.comm.recv(source=node)
print(f"Node {self.rank} received data from {node}: {bool(received_data)}")
if not received_data:
raise Exception("Received empty data")
return received_data
except MPI.Exception as e:
print(f"MPI failed {10 - max_tries} times: MPI ERROR: {e}", "Retrying...")
import traceback
print(f"Traceback: {traceback.print_exc()}")
# sleep for a random time between 1 and 10 seconds
random_time = random.randint(1, 10)
time.sleep(random_time)
max_tries -= 1
except Exception as e:
print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...")
import traceback
print(f"Traceback: {traceback.print_exc()}")
# sleep for a random time between 1 and 10 seconds
random_time = random.randint(1, 10)
time.sleep(random_time)
max_tries -= 1
print(f"Node {self.rank} received")

# deprecated broadcast function
def broadcast(self, data: Any):
Expand All @@ -138,9 +167,60 @@ def all_gather(self):
print(f"receiving this data: {self.receive(i)}")
items.append(self.receive(i))
return items

def send_finished(self):
self.comm.send("Finished", dest=0, tag=2)

def finalize(self):
pass
# 1. All nodes send finished to the super node
# 2. super node will wait for all nodes to send finished
# 3. super node will then send bye to all nodes
# 4. all nodes will wait for the bye and then exit
# this is to ensure that all nodes have finished
# and no one leaves early
if self.rank == 0:
quorum_threshold = self.num_users - 1 # No +1 for the super node because it doesn't send finished
num_finished: set[int] = set()
status = MPI.Status()
while len(num_finished) < quorum_threshold:
# sleep for 5 seconds
print(
f"Waiting for {quorum_threshold} users to finish, {num_finished} have finished so far"
)
# time.sleep(5)
# get finished nodes
self.comm.recv(source=MPI.ANY_SOURCE, tag=2, status=status)
print(f"received finish message from {status.Get_source()}")
num_finished.add(status.Get_source())

else:
# send finished to the super node
print(f"Node {self.rank} sent finish message")
self.send_finished()

# problem: do the other nodes wait for super node to receive finish messages?
message = self.comm.bcast("Done", root=0)
self.finished = True
self.send_event.set()
print(f"Node {self.rank} received {message}, finished")
self.comm.Barrier()
self.listener_thread.join()
print(f"Node {self.rank} listener thread done")
if self.send_thread.is_alive():
self.send_thread.join()
print(f"Node {self.rank} send thread done")
print(f"Node {self.rank} active threads: {threading.active_count()}")
print(f"Node {self.rank} listener thread is {self.listener_thread.is_alive()}")
print(f"Node {self.rank} {threading.enumerate()}")
# for thread in threading.enumerate():
# if thread != threading.main_thread():
# thread.join()
print(f"Node {self.rank} send thread is {self.send_thread.is_alive()}")
self.comm.Barrier()
print(f"Node {self.rank}: all nodes synchronized")
MPI.Finalize()

print("Finalized")

def set_is_working(self, is_working: bool):
with self.lock:
Expand Down

0 comments on commit 464a674

Please sign in to comment.