syntax = "proto3";

// Service each node exposes so that peers can exchange serialized models.
service CommunicationServer {
    // Present in the checked-in generated stubs (comm_pb2.py /
    // comm_pb2_grpc.py); declared here so the schema and stubs agree.
    rpc GetModel (ID) returns (Data) {}
    // Delivers a Data payload to the receiving node.
    rpc SendData (Data) returns (Empty) {}
}

message Empty {}

message Data {
    string id = 1;
    Model model = 2;
}

message ID {
    int32 num = 1;
    string id = 2;
}

// Raw bytes of a serialized model. `Data.model` refers to this type, so it
// must be defined for the file to compile.
message Model {
    bytes buffer = 1;
}
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: comm.proto
# Protobuf Python Version: 5.26.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




# Serialized file descriptor for comm.proto, registered in the default
# descriptor pool. It declares the messages Empty, Data, ID and Model and the
# CommunicationServer service with the GetModel and SendData RPCs.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x1d\n\x02ID\x12\x0b\n\x03num\x18\x01 \x01(\x05\x12\n\n\x02id\x18\x02 \x01(\t\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x32L\n\x13\x43ommunicationServer\x12\x18\n\x08GetModel\x12\x03.ID\x1a\x05.Data\"\x00\x12\x1b\n\x08SendData\x12\x05.Data\x1a\x06.Empty\"\x00\x62\x06proto3')

# The builder calls receive this module's globals() dict; the entries read
# back below ('_EMPTY', '_DATA', ...) are the ones they add to it.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'comm_pb2', _globals)
# Only when the C descriptor implementation is not in use: record the byte
# range each definition occupies inside the serialized file above.
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_EMPTY']._serialized_start=14
  _globals['_EMPTY']._serialized_end=21
  _globals['_DATA']._serialized_start=23
  _globals['_DATA']._serialized_end=64
  _globals['_ID']._serialized_start=66
  _globals['_ID']._serialized_end=95
  _globals['_MODEL']._serialized_start=97
  _globals['_MODEL']._serialized_end=120
  _globals['_COMMUNICATIONSERVER']._serialized_start=122
  _globals['_COMMUNICATIONSERVER']._serialized_end=198
# @@protoc_insertion_point(module_scope)
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

# NOTE: hand-patched after generation; re-apply if this file is regenerated.
# The generated bare `import comm_pb2` only resolves when this directory is
# itself on sys.path. When the module is loaded as part of its package the
# sibling module must be imported relatively, so try that first and keep the
# generated form as the fallback for top-level (non-package) use.
try:
    from . import comm_pb2 as comm__pb2
except ImportError:
    import comm_pb2 as comm__pb2

GRPC_GENERATED_VERSION = '1.64.0'
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = '1.65.0'
SCHEDULED_RELEASE_DATE = 'June 25, 2024'
_version_not_supported = False

# Compare the installed grpc runtime against the version the stubs were
# generated with; a runtime without the comparison helper counts as
# unsupported.
try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    warnings.warn(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in comm_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
        RuntimeWarning
    )
class CommunicationServerServicer(object):
    """Base servicer for the CommunicationServer service.

    Each handler marks the call as UNIMPLEMENTED on the context and raises
    NotImplementedError; subclasses override the RPCs they support.
    """

    def GetModel(self, request, context):
        """Reject a GetModel call as unimplemented."""
        detail = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(detail)
        raise NotImplementedError(detail)

    def SendData(self, request, context):
        """Reject a SendData call as unimplemented."""
        detail = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(detail)
        raise NotImplementedError(detail)
def deserialize_model(model_bytes: bytes, map_location=None) -> OrderedDict:
    """Rebuild a model ``state_dict`` from bytes written by ``torch.save``.

    Args:
        model_bytes: The serialized payload.
        map_location: Forwarded unchanged to ``torch.load``. The default
            ``None`` keeps ``torch.load``'s own behavior; pass e.g. ``"cpu"``
            to load tensors that were saved from a device this process does
            not have.

    Returns:
        The object stored in ``model_bytes``.
    """
    # SECURITY: torch.load is pickle-based and these bytes come from a remote
    # peer (the gRPC servicer feeds request.model.buffer straight in). Only
    # use with trusted peers, or restrict loading with `weights_only=True`
    # where the installed torch version supports it.
    # A freshly constructed BytesIO is already positioned at offset 0.
    return torch.load(io.BytesIO(model_bytes), map_location=map_location)
class GRPCCommunication(CommunicationInterface):
    """CommunicationInterface implementation that exchanges models over gRPC.

    Each instance runs its own gRPC server (started by ``initialize``) whose
    servicer queues incoming models, and opens a short-lived insecure channel
    for every outgoing ``send``.
    """

    def __init__(self, port):
        """Store the listening port and create the servicer.

        Args:
            port: Port the server binds to when ``initialize`` is called.
        """
        super().__init__()
        self.port = port
        # Retained for backward compatibility; the running server is kept in
        # `self.listener`.
        self.server = None
        # Defined up front so finalize() is safe even if initialize() was
        # never called.
        self.listener = None
        self.servicer = Servicer()
        # TODO: Implement this later by creating a super node
        # that maintains a list of all peers
        # all peers will send their IDs to this super node
        # when they start.
        self.all_peer_ids = []

    def initialize(self):
        """Start the gRPC server on ``self.port`` on all interfaces."""
        self.listener = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        comm_pb2_grpc.add_CommunicationServerServicer_to_server(self.servicer, self.listener)
        self.listener.add_insecure_port(f'[::]:{self.port}')
        self.listener.start()
        print(f'Started server on port {self.port}')

    def send(self, dest, data):
        """Send a model's ``state_dict`` to the peer at address ``dest``.

        Args:
            dest: Target passed to ``grpc.insecure_channel``.
            data: Object exposing ``state_dict()`` (a torch model).

        RPC failures are printed rather than raised.
        """
        buffer = serialize_model(data.state_dict())
        # SendData takes a Data message; the receiving servicer reads
        # request.model.buffer, so the bytes must be wrapped Data -> Model.
        request = comm_pb2.Data(model=comm_pb2.Model(buffer=buffer))
        try:
            with grpc.insecure_channel(dest) as channel:
                stub = comm_pb2_grpc.CommunicationServerStub(channel)
                stub.SendData(request)
        except grpc.RpcError as e:
            print(f"RPC failed: {e}")

    def receive(self, node_ids, data):
        """Return the next received model, blocking until one is available.

        ``node_ids`` and ``data`` are accepted for interface compatibility and
        are not used.
        """
        # this .get() will block until
        # at least 1 item is received in the queue
        return self.servicer.received_data.get()

    def broadcast(self, data):
        """Send ``data`` to every peer listed in ``self.all_peer_ids``."""
        for peer_id in self.all_peer_ids:
            self.send(peer_id, data)

    def finalize(self):
        """Stop the server if it was started; otherwise do nothing."""
        if self.listener is not None:
            self.listener.stop(0)
            self.listener = None
            print(f'Stopped server on port {self.port}')
class MPICommUtils(CommunicationInterface):
    """CommunicationInterface implementation on top of ``MPI.COMM_WORLD``."""

    def __init__(self):
        """Capture the world communicator together with this rank and size."""
        world = MPI.COMM_WORLD
        self.comm = world
        self.rank = world.Get_rank()
        self.size = world.Get_size()

    def initialize(self):
        """No setup is performed here."""
        return None

    def send(self, dest, data):
        """Send ``data`` to rank ``dest`` through the communicator."""
        self.comm.send(data, dest=dest)

    def receive(self, node_ids, data):
        """Receive one message from source ``node_ids``; ``data`` is unused."""
        message = self.comm.recv(source=node_ids)
        return message

    def broadcast(self, data):
        """Send ``data`` to every rank in ``1..size-1`` other than this one."""
        # NOTE(review): rank 0 is never a target here - confirm intentional.
        targets = [peer for peer in range(1, self.size) if peer != self.rank]
        for peer in targets:
            self.send(peer, data)

    def finalize(self):
        """No teardown is performed here."""
        return None