diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index b15ef4c..73b320f 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -2,7 +2,6 @@ # TODO: Set up multiple non-iid configurations here. The goal of a separate system config # is to simulate different real-world scenarios without changing the algorithm configuration. from typing import Dict, List -import socket mpi_system_config = { @@ -27,44 +26,11 @@ "folder_deletion_signal_path":"./expt_dump/folder_deletion.signal" } -def is_port_available(port: int) -> bool: - """ - Check if a port is available for use. - """ - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: # type: ignore - return s.connect_ex(('localhost', port)) != 0 # type: ignore - - -def generate_ports(num_users: int) -> List[int]: - """ - Generate a list of ports that are available for use. - """ - ports: List[int] = [] - i = 0 - while len(ports) < num_users: - port = 50051 + i - # check if the port is available - if is_port_available(port): - ports.append(port) - else: - print(f"Port {port} is not available, skipping...") - i += 1 - return ports - -def generate_peer_ids(num_users: int) -> List[str]: - """ - Generate a list of peer IDs for the users. - """ - peer_ids: List[str] = [] - ports = generate_ports(num_users) - for i in range(num_users): - peer_ids.append(f"localhost:{ports[i]}") - return peer_ids - def get_device_ids(num_users: int, gpus_available: List[int]) -> Dict[str, List[int]]: """ Get the GPU device IDs for the users. 
""" + # TODO: Make it multi-host device_ids: Dict[str, List[int]] = {} for i in range(num_users): index = i % len(gpus_available) @@ -72,13 +38,13 @@ def get_device_ids(num_users: int, gpus_available: List[int]) -> Dict[str, List[ device_ids[f"node_{i}"] = [gpu_id] return device_ids -num_users = 50 +num_users = 10 gpu_ids = [0, 1, 2, 3, 4, 5, 6, 7] grpc_system_config = { "num_users": num_users, "comm": { "type": "GRPC", - "all_peer_ids": generate_peer_ids(num_users + 1) # +1 for the super-node, + "peer_ids": ["localhost:50050"] # The super-node }, "dset": "cifar10", "dump_dir": "./expt_dump/", diff --git a/src/main_grpc.py b/src/main_grpc.py index c0bcdc0..546e6c0 100644 --- a/src/main_grpc.py +++ b/src/main_grpc.py @@ -30,7 +30,7 @@ # 1. find the number of users in the system configuration # 2. start separate processes by running python main.py for each user -num_users = sys_config["num_users"] + 1 # +1 for the server +num_users = sys_config["num_users"] + 1 # +1 for the super-node for i in range(num_users): print(f"Starting process for user {i}") # start a Popen process diff --git a/src/utils/communication/grpc/comm.proto b/src/utils/communication/grpc/comm.proto index 6bdd38d..f3d35d1 100644 --- a/src/utils/communication/grpc/comm.proto +++ b/src/utils/communication/grpc/comm.proto @@ -1,7 +1,11 @@ syntax = "proto3"; service CommunicationServer { - rpc SendData (Data) returns (Empty) {} + rpc send_data (Data) returns (Empty) {} + rpc get_rank (Empty) returns (Rank) {} + rpc update_port (PeerId) returns (Empty) {} + rpc send_peer_ids (PeerIds) returns (Empty) {} + rpc send_quorum (Quorum) returns (Empty) {} } message Empty {} @@ -14,3 +18,25 @@ message Data { string id = 1; Model model = 2; } + +message Rank { + int32 rank = 1; +} + +message Port { + int32 port = 1; +} + +message PeerId { + Rank rank = 1; + Port port = 2; + string ip = 3; +} + +message PeerIds { + map<int32, PeerId> peer_ids = 1; +} + +message Quorum { + bool quorum = 1; +} \ No newline at end of file diff
--git a/src/utils/communication/grpc/comm_pb2.py b/src/utils/communication/grpc/comm_pb2.py index aef780c..5c1623f 100644 --- a/src/utils/communication/grpc/comm_pb2.py +++ b/src/utils/communication/grpc/comm_pb2.py @@ -14,19 +14,33 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model22\n\x13\x43ommunicationServer\x12\x1b\n\x08SendData\x12\x05.Data\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xb9\x01\n\x13\x43ommunicationServer\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'comm_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None + _globals['_PEERIDS_PEERIDSENTRY']._loaded_options = None + 
_globals['_PEERIDS_PEERIDSENTRY']._serialized_options = b'8\001' _globals['_EMPTY']._serialized_start=14 _globals['_EMPTY']._serialized_end=21 _globals['_MODEL']._serialized_start=23 _globals['_MODEL']._serialized_end=46 _globals['_DATA']._serialized_start=48 _globals['_DATA']._serialized_end=89 - _globals['_COMMUNICATIONSERVER']._serialized_start=91 - _globals['_COMMUNICATIONSERVER']._serialized_end=141 + _globals['_RANK']._serialized_start=91 + _globals['_RANK']._serialized_end=111 + _globals['_PORT']._serialized_start=113 + _globals['_PORT']._serialized_end=133 + _globals['_PEERID']._serialized_start=135 + _globals['_PEERID']._serialized_end=197 + _globals['_PEERIDS']._serialized_start=199 + _globals['_PEERIDS']._serialized_end=306 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=251 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=306 + _globals['_QUORUM']._serialized_start=308 + _globals['_QUORUM']._serialized_end=332 + _globals['_COMMUNICATIONSERVER']._serialized_start=335 + _globals['_COMMUNICATIONSERVER']._serialized_end=520 # @@protoc_insertion_point(module_scope) diff --git a/src/utils/communication/grpc/comm_pb2.pyi b/src/utils/communication/grpc/comm_pb2.pyi new file mode 100644 index 0000000..e81bc4e --- /dev/null +++ b/src/utils/communication/grpc/comm_pb2.pyi @@ -0,0 +1,65 @@ +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Empty(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class Model(_message.Message): + __slots__ = ("buffer",) + BUFFER_FIELD_NUMBER: _ClassVar[int] + buffer: bytes + def __init__(self, buffer: _Optional[bytes] = ...) -> None: ... 
+ +class Data(_message.Message): + __slots__ = ("id", "model") + ID_FIELD_NUMBER: _ClassVar[int] + MODEL_FIELD_NUMBER: _ClassVar[int] + id: str + model: Model + def __init__(self, id: _Optional[str] = ..., model: _Optional[_Union[Model, _Mapping]] = ...) -> None: ... + +class Rank(_message.Message): + __slots__ = ("rank",) + RANK_FIELD_NUMBER: _ClassVar[int] + rank: int + def __init__(self, rank: _Optional[int] = ...) -> None: ... + +class Port(_message.Message): + __slots__ = ("port",) + PORT_FIELD_NUMBER: _ClassVar[int] + port: int + def __init__(self, port: _Optional[int] = ...) -> None: ... + +class PeerId(_message.Message): + __slots__ = ("rank", "port", "ip") + RANK_FIELD_NUMBER: _ClassVar[int] + PORT_FIELD_NUMBER: _ClassVar[int] + IP_FIELD_NUMBER: _ClassVar[int] + rank: Rank + port: Port + ip: str + def __init__(self, rank: _Optional[_Union[Rank, _Mapping]] = ..., port: _Optional[_Union[Port, _Mapping]] = ..., ip: _Optional[str] = ...) -> None: ... + +class PeerIds(_message.Message): + __slots__ = ("peer_ids",) + class PeerIdsEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: int + value: PeerId + def __init__(self, key: _Optional[int] = ..., value: _Optional[_Union[PeerId, _Mapping]] = ...) -> None: ... + PEER_IDS_FIELD_NUMBER: _ClassVar[int] + peer_ids: _containers.MessageMap[int, PeerId] + def __init__(self, peer_ids: _Optional[_Mapping[int, PeerId]] = ...) -> None: ... + +class Quorum(_message.Message): + __slots__ = ("quorum",) + QUORUM_FIELD_NUMBER: _ClassVar[int] + quorum: bool + def __init__(self, quorum: bool = ...) -> None: ... diff --git a/src/utils/communication/grpc/comm_pb2_grpc.py b/src/utils/communication/grpc/comm_pb2_grpc.py index aa28b04..feb7eb8 100644 --- a/src/utils/communication/grpc/comm_pb2_grpc.py +++ b/src/utils/communication/grpc/comm_pb2_grpc.py @@ -39,17 +39,61 @@ def __init__(self, channel): Args: channel: A grpc.Channel. 
""" - self.SendData = channel.unary_unary( - '/CommunicationServer/SendData', + self.send_data = channel.unary_unary( + '/CommunicationServer/send_data', request_serializer=comm__pb2.Data.SerializeToString, response_deserializer=comm__pb2.Empty.FromString, _registered_method=True) + self.get_rank = channel.unary_unary( + '/CommunicationServer/get_rank', + request_serializer=comm__pb2.Empty.SerializeToString, + response_deserializer=comm__pb2.Rank.FromString, + _registered_method=True) + self.update_port = channel.unary_unary( + '/CommunicationServer/update_port', + request_serializer=comm__pb2.PeerId.SerializeToString, + response_deserializer=comm__pb2.Empty.FromString, + _registered_method=True) + self.send_peer_ids = channel.unary_unary( + '/CommunicationServer/send_peer_ids', + request_serializer=comm__pb2.PeerIds.SerializeToString, + response_deserializer=comm__pb2.Empty.FromString, + _registered_method=True) + self.send_quorum = channel.unary_unary( + '/CommunicationServer/send_quorum', + request_serializer=comm__pb2.Quorum.SerializeToString, + response_deserializer=comm__pb2.Empty.FromString, + _registered_method=True) class CommunicationServerServicer(object): """Missing associated documentation comment in .proto file.""" - def SendData(self, request, context): + def send_data(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def get_rank(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def update_port(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method 
not implemented!') + raise NotImplementedError('Method not implemented!') + + def send_peer_ids(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def send_quorum(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -58,11 +102,31 @@ def SendData(self, request, context): def add_CommunicationServerServicer_to_server(servicer, server): rpc_method_handlers = { - 'SendData': grpc.unary_unary_rpc_method_handler( - servicer.SendData, + 'send_data': grpc.unary_unary_rpc_method_handler( + servicer.send_data, request_deserializer=comm__pb2.Data.FromString, response_serializer=comm__pb2.Empty.SerializeToString, ), + 'get_rank': grpc.unary_unary_rpc_method_handler( + servicer.get_rank, + request_deserializer=comm__pb2.Empty.FromString, + response_serializer=comm__pb2.Rank.SerializeToString, + ), + 'update_port': grpc.unary_unary_rpc_method_handler( + servicer.update_port, + request_deserializer=comm__pb2.PeerId.FromString, + response_serializer=comm__pb2.Empty.SerializeToString, + ), + 'send_peer_ids': grpc.unary_unary_rpc_method_handler( + servicer.send_peer_ids, + request_deserializer=comm__pb2.PeerIds.FromString, + response_serializer=comm__pb2.Empty.SerializeToString, + ), + 'send_quorum': grpc.unary_unary_rpc_method_handler( + servicer.send_quorum, + request_deserializer=comm__pb2.Quorum.FromString, + response_serializer=comm__pb2.Empty.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'CommunicationServer', rpc_method_handlers) @@ -75,7 +139,7 @@ class CommunicationServer(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def SendData(request, + def send_data(request, target, 
options=(), channel_credentials=None, @@ -88,7 +152,7 @@ def SendData(request, return grpc.experimental.unary_unary( request, target, - '/CommunicationServer/SendData', + '/CommunicationServer/send_data', comm__pb2.Data.SerializeToString, comm__pb2.Empty.FromString, options, @@ -100,3 +164,111 @@ def SendData(request, timeout, metadata, _registered_method=True) + + @staticmethod + def get_rank(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/get_rank', + comm__pb2.Empty.SerializeToString, + comm__pb2.Rank.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def update_port(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/update_port', + comm__pb2.PeerId.SerializeToString, + comm__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def send_peer_ids(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/send_peer_ids', + comm__pb2.PeerIds.SerializeToString, + comm__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def 
send_quorum(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/send_quorum', + comm__pb2.Quorum.SerializeToString, + comm__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py index 9ad9ecc..20e7915 100644 --- a/src/utils/communication/grpc/main.py +++ b/src/utils/communication/grpc/main.py @@ -1,6 +1,12 @@ from concurrent import futures from queue import Queue -from typing import Any, Dict, List, OrderedDict +import random +import re +import threading +import time +import socket +from typing import Any, Dict, List, OrderedDict, Union +from urllib.parse import unquote import grpc # type: ignore from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model import os @@ -14,40 +20,206 @@ import comm_pb2_grpc as comm_pb2_grpc from utils.communication.interface import CommunicationInterface + +def is_port_available(port: int) -> bool: + """ + Check if a port is available for use. + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: # type: ignore + return s.connect_ex(('localhost', port)) != 0 # type: ignore + +def get_port(rank: int, num_users: int) -> int: + """ + Get the port number for the given rank. 
+ """ + start = 50051 + while True: + port = start + rank + if is_port_available(port): + return port + # if we increment by 1 then it's likely that + # the next node will also have the same port number + start += num_users + if start > 65535: + raise Exception(f"No available ports for node {rank}") + +def parse_peer_address(peer: str) -> str: + # Remove 'ipv4:' or 'ipv6:' prefix + if peer.startswith(('ipv4:', 'ipv6:')): + peer = peer.split(':', 1)[1] + + # Handle IPv6 address + if peer.startswith('['): + # Extract IPv6 address + match = re.match(r'\[([^\]]+)\]', peer) + if match: + return unquote(match.group(1)) # Decode URL-encoded characters + else: + # Handle IPv4 address or hostname + return peer.rsplit(':', 1)[0] # Remove port number + return "" + class Servicer(comm_pb2_grpc.CommunicationServerServicer): - def __init__(self): + def __init__(self, super_node_host: str): + self.lock = threading.Lock() + self.condition = threading.Condition(self.lock) self.received_data: Queue[Any] = Queue() + self.quorum: Queue[bool] = Queue() + port = int(super_node_host.split(":")[1]) + ip = super_node_host.split(":")[0] + self.peer_ids: OrderedDict[int, Dict[str, int|str]] = OrderedDict({ + 0: {"rank": 0, "port": port, "ip": ip} + }) - def SendData(self, request, context) -> comm_pb2.Empty: # type: ignore + def send_data(self, request, context) -> comm_pb2.Empty: # type: ignore self.received_data.put(deserialize_model(request.model.buffer)) # type: ignore return comm_pb2.Empty() # type: ignore + def get_rank(self, request: comm_pb2.Empty, context: grpc.ServicerContext) -> comm_pb2.Rank: # type: ignore + try: + with self.lock: + peer = context.peer() # type: ignore + # parse the hostname from peer + peer_str = parse_peer_address(peer) # type: ignore + rank = len(self.peer_ids) + # TODO: index the peer_ids by a unique identifier + self.peer_ids[rank] = {"rank": rank, "port": 0, "ip": peer_str} + rank = self.peer_ids[rank].get("rank", -1) # Default to -1 if not found + return 
comm_pb2.Rank(rank=rank) # type: ignore + except Exception as e: + context.abort(grpc.StatusCode.INTERNAL, f"Error in get_rank: {str(e)}") # type: ignore + + def update_port(self, request: comm_pb2.PeerIds, context: grpc.ServicerContext) -> comm_pb2.Empty: + with self.lock: + self.peer_ids[request.rank.rank]["port"] = request.port.port # type: ignore + return comm_pb2.Empty() # type: ignore + + def send_peer_ids(self, request, context) -> comm_pb2.Empty: # type: ignore + """ + Used by the super node to update all peers with the peer_ids + after achieving quorum. + """ + peer_ids: comm_pb2.PeerIds = request.peer_ids # type: ignore + for rank in peer_ids: # type: ignore + peer_id_proto = peer_ids[rank] # type: ignore + peer_id_dict: Dict[str, Union[int, str]] = { + "rank": peer_id_proto.rank.rank, # type: ignore + "port": peer_id_proto.port.port, # type: ignore + "ip": peer_id_proto.ip # type: ignore + } + self.peer_ids[rank] = peer_id_dict + return comm_pb2.Empty() + + def send_quorum(self, request, context) -> comm_pb2.Empty: # type: ignore + self.quorum.put(request.quorum) # type: ignore + return comm_pb2.Empty() # type: ignore + + class GRPCCommunication(CommunicationInterface): def __init__(self, config: Dict[str, Dict[str, Any]]): # TODO: Implement this differently later by creating a super node # that maintains a list of all peers # all peers will send their IDs to this super node # when they start. 
- self.all_peer_ids: List[str] = config["comm"]["all_peer_ids"] - self.rank: int = config["comm"]["rank"] - address: str = str(self.all_peer_ids[self.rank]) - self.port: int = int(address.split(":")[1]) - self.host: str = address.split(":")[0] - self.server = None - self.servicer = Servicer() - - # remove self from the list of all peer IDs - self.all_peer_ids.remove(self.host + ":" + str(self.port)) - - def initialize(self): - self.listener: Any = grpc.server(futures.ThreadPoolExecutor(max_workers=10), options=[ # type: ignore + # The implementation will have + # 1. Registration phase where every peer registers itself and gets a rank + # 2. Once a threshold number of peers have registered, the super node sets quorum to True + # 3. The super node broadcasts the peer_ids to all peers + # 4. The nodes will execute rest of the protocol in the same way as before + self.num_users: int = int(config["num_users"]) # type: ignore + self.rank: int|None = config["comm"]["rank"] + self.super_node_host: str = config["comm"]["peer_ids"][0] + if self.rank == 0: + node_id: List[str] = self.super_node_host.split(":") + self.host: str = node_id[0] + self.port: int = int(node_id[1]) + self.listener: Any = None + self.servicer = Servicer(self.super_node_host) + + @staticmethod + def get_registered_users(peer_ids: OrderedDict[int, Dict[str, int|str]]) -> int: + # count the number of entries that have a non-zero port + return len([peer_id for peer_id, values in peer_ids.items() if values.get("port") != 0]) + + def register(self): + with grpc.insecure_channel(self.super_node_host) as channel: # type: ignore + stub = comm_pb2_grpc.CommunicationServerStub(channel) + max_tries = 5 + while max_tries > 0: + try: + self.rank = stub.get_rank(comm_pb2.Empty()).rank # type: ignore + break + except grpc.RpcError as e: + print(f"RPC failed: {e}", "Retrying...") + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 + 
self.port = get_port(self.rank, self.num_users) # type: ignore because we are setting it in the register method + rank = comm_pb2.Rank(rank=self.rank) # type: ignore + port = comm_pb2.Port(port=self.port) + peer_id = comm_pb2.PeerId(rank=rank, port=port) + stub.update_port(peer_id) # type: ignore + + def start_listener(self): + self.listener = grpc.server(futures.ThreadPoolExecutor(max_workers=4), options=[ # type: ignore ('grpc.max_send_message_length', 100 * 1024 * 1024), # 100MB ('grpc.max_receive_message_length', 100 * 1024 * 1024), # 100MB ]) comm_pb2_grpc.add_CommunicationServerServicer_to_server(self.servicer, self.listener) # type: ignore self.listener.add_insecure_port(f'[::]:{self.port}') self.listener.start() - print(f'Started server on port {self.port}') + print(f'Started listener on port {self.port}') + + def peer_ids_to_proto(self, peer_ids: OrderedDict[int, Dict[str, int|str]]) -> Dict[int, comm_pb2.PeerId]: + peer_ids_proto: Dict[int, comm_pb2.PeerId] = {} + for peer_id in peer_ids: + rank = comm_pb2.Rank(rank=peer_ids[peer_id].get("rank")) # type: ignore + port = comm_pb2.Port(port=peer_ids[peer_id].get("port")) # type: ignore + ip = str(peer_ids[peer_id].get("ip")) + peer_ids_proto[peer_id] = comm_pb2.PeerId(rank=rank, port=port, ip=ip) + return peer_ids_proto + + def initialize(self): + if self.rank != 0: + self.register() + + self.start_listener() + + # wait for the quorum to be set + if self.rank != 0: + status = self.servicer.quorum.get() + if not status: + print("Quorum became false!") + sys.exit(1) + else: + quorum_threshold = self.num_users + 1 # +1 for the super node + while self.get_registered_users(self.servicer.peer_ids) < quorum_threshold: + # sleep for 5 seconds + print(f"Waiting for {quorum_threshold} users to register") + time.sleep(5) + # TODO: Implement a timeout here and if the timeout is reached + # then set quorum to False for all registered users + # and exit the program + + print("All users have registered", 
self.servicer.peer_ids) + for peer_id in self.servicer.peer_ids: + host_ip = self.servicer.peer_ids[peer_id].get("ip") + own_ip = self.servicer.peer_ids[0].get("ip") + if host_ip != own_ip: + port = self.servicer.peer_ids[peer_id].get("port") + address = f"{host_ip}:{port}" + with grpc.insecure_channel(address) as channel: # type: ignore + stub = comm_pb2_grpc.CommunicationServerStub(channel) + proto_msg = comm_pb2.PeerIds(peer_ids=self.peer_ids_to_proto(self.servicer.peer_ids)) + stub.send_peer_ids(proto_msg) # type: ignore + stub.send_quorum(comm_pb2.Quorum(quorum=True)) # type: ignore + + def get_host_from_rank(self, rank: int) -> str: + for peer_id in self.servicer.peer_ids: + if self.servicer.peer_ids[peer_id].get("rank") == rank: + return self.servicer.peer_ids[peer_id].get("ip") + ":" + str(self.servicer.peer_ids[peer_id].get("port")) # type: ignore + raise Exception(f"Rank {rank} not found in peer_ids") def send(self, dest: str|int, data: OrderedDict[str, Any]): """ @@ -55,7 +227,7 @@ def send(self, dest: str|int, data: OrderedDict[str, Any]): """ dest_host: str = "" if type(dest) == int: - dest_host = self.all_peer_ids[int(dest)] + dest_host = self.get_host_from_rank(dest) else: dest_host = str(dest) try: @@ -63,7 +235,7 @@ def send(self, dest: str|int, data: OrderedDict[str, Any]): with grpc.insecure_channel(dest_host) as channel: # type: ignore stub = comm_pb2_grpc.CommunicationServerStub(channel) # type: ignore model = comm_pb2.Model(buffer=buffer) # type: ignore - stub.SendData(comm_pb2.Data(model=model, id="tempID")) # type: ignore + stub.send_data(comm_pb2.Data(model=model, id=str(self.rank))) # type: ignore except grpc.RpcError as e: print(f"RPC failed: {e}") sys.exit(1) @@ -73,16 +245,24 @@ def receive(self, node_ids: str|int) -> Any: # at least 1 item is received in the queue return self.servicer.received_data.get() + def is_own_id(self, peer_id: int) -> bool: + rank = self.servicer.peer_ids[peer_id].get("rank") + if rank != self.rank: + return 
False + return True + def broadcast(self, data: Any): - for peer_id in self.all_peer_ids: - self.send(peer_id, data) + for peer_id in self.servicer.peer_ids: + if not self.is_own_id(peer_id): + self.send(peer_id, data) def all_gather(self) -> Any: # this will block until all items are received # from all peers items: List[Any] = [] - for peer_id in self.all_peer_ids: - items.append(self.receive(peer_id)) + for peer_id in self.servicer.peer_ids: + if not self.is_own_id(peer_id): + items.append(self.receive(peer_id)) return items def finalize(self):