-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement the new communication library
- Loading branch information
Showing
9 changed files
with
328 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"python.analysis.extraPaths": [ | ||
"./src" | ||
], | ||
"python.analysis.autoImportCompletions": true, | ||
"python.analysis.typeCheckingMode": "basic" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
class CommunicationInterface(ABC):
    """Abstract contract shared by the communication backends.

    A subclass has to override every abstract method below before it can be
    instantiated. The base implementations do nothing and return ``None``.
    """

    def __init__(self):
        """Create the interface; the base class stores no state."""

    @abstractmethod
    def initialize(self):
        """Set up the backend. Must be overridden."""

    @abstractmethod
    def send(self, dest, data):
        """Send ``data`` to ``dest``. Must be overridden."""

    @abstractmethod
    def receive(self, node_ids, data):
        """Receive from ``node_ids``. Must be overridden."""

    @abstractmethod
    def broadcast(self, data):
        """Send ``data`` to several peers. Must be overridden."""

    @abstractmethod
    def finalize(self):
        """Tear down the backend. Must be overridden."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
syntax = "proto3";

// RPC surface used by the gRPC communication backend.
service CommunicationServer {
  // Declared to match the generated stubs, which register
  // /CommunicationServer/GetModel as ID -> Data.
  rpc GetModel (ID) returns (Data) {}
  // Push a Data message to the receiving node.
  rpc SendData (Data) returns (Empty) {}
}

// Empty response for RPCs that return nothing.
message Empty {}

// Serialized model payload. The Python sender fills `buffer` with the bytes
// returned by serialize_model(); the receiver reads `request.model.buffer`.
message Model {
  bytes buffer = 1;
}

// Envelope carried by SendData / returned by GetModel.
message Data {
  string id = 1;
  Model model = 2;
}

// Request message for GetModel.
message ID {
  int32 num = 1;
  string id = 2;
}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! | ||
"""Client and server classes corresponding to protobuf-defined services.""" | ||
import grpc | ||
import warnings | ||
|
||
import comm_pb2 as comm__pb2 | ||
|
||
# Version of grpcio-tools that produced this module, and the installed runtime.
GRPC_GENERATED_VERSION = '1.64.0'
GRPC_VERSION = grpc.__version__
# Release in which the mismatch warning below is announced to become an error.
EXPECTED_ERROR_RELEASE = '1.65.0'
SCHEDULED_RELEASE_DATE = 'June 25, 2024'
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # The helper cannot be imported, so treat the runtime as unsupported.
    _version_not_supported = True

if _version_not_supported:
    warnings.warn(
        ''.join((
            f'The grpc package installed is at version {GRPC_VERSION},',
            f' but the generated code in comm_pb2_grpc.py depends on',
            f' grpcio>={GRPC_GENERATED_VERSION}.',
            f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}',
            f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.',
            f' This warning will become an error in {EXPECTED_ERROR_RELEASE},',
            f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
        )),
        RuntimeWarning
    )
|
||
|
||
class CommunicationServerStub(object):
    """Client stub: one unary-unary callable per CommunicationServer RPC."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        get_model_path = '/CommunicationServer/GetModel'
        send_data_path = '/CommunicationServer/SendData'

        # GetModel: ID request -> Data response.
        self.GetModel = channel.unary_unary(
            get_model_path,
            request_serializer=comm__pb2.ID.SerializeToString,
            response_deserializer=comm__pb2.Data.FromString,
            _registered_method=True,
        )
        # SendData: Data request -> Empty response.
        self.SendData = channel.unary_unary(
            send_data_path,
            request_serializer=comm__pb2.Data.SerializeToString,
            response_deserializer=comm__pb2.Empty.FromString,
            _registered_method=True,
        )
|
||
|
||
class CommunicationServerServicer(object):
    """Server-side base class; each RPC reports UNIMPLEMENTED until overridden."""

    def GetModel(self, request, context):
        """Default handler: mark the call UNIMPLEMENTED and raise."""
        detail = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(detail)
        raise NotImplementedError(detail)

    def SendData(self, request, context):
        """Default handler: mark the call UNIMPLEMENTED and raise."""
        detail = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(detail)
        raise NotImplementedError(detail)
|
||
|
||
def add_CommunicationServerServicer_to_server(servicer, server):
    """Register the servicer's GetModel and SendData handlers on ``server``."""
    service_name = 'CommunicationServer'

    get_model_handler = grpc.unary_unary_rpc_method_handler(
        servicer.GetModel,
        request_deserializer=comm__pb2.ID.FromString,
        response_serializer=comm__pb2.Data.SerializeToString,
    )
    send_data_handler = grpc.unary_unary_rpc_method_handler(
        servicer.SendData,
        request_deserializer=comm__pb2.Data.FromString,
        response_serializer=comm__pb2.Empty.SerializeToString,
    )
    rpc_method_handlers = {
        'GetModel': get_model_handler,
        'SendData': send_data_handler,
    }

    # Register through both the generic-handler and registered-method paths.
    generic_handler = grpc.method_handlers_generic_handler(
        service_name, rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers(service_name, rpc_method_handlers)
|
||
|
||
# This class is part of an EXPERIMENTAL API.
class CommunicationServer(object):
    """Static one-shot callers for each RPC, built on grpc.experimental."""

    @staticmethod
    def GetModel(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke /CommunicationServer/GetModel (ID -> Data) against ``target``."""
        method_path = '/CommunicationServer/GetModel'
        serializer = comm__pb2.ID.SerializeToString
        deserializer = comm__pb2.Data.FromString
        # Note the positional order: ``insecure`` precedes ``call_credentials``.
        return grpc.experimental.unary_unary(
            request,
            target,
            method_path,
            serializer,
            deserializer,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendData(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke /CommunicationServer/SendData (Data -> Empty) against ``target``."""
        method_path = '/CommunicationServer/SendData'
        serializer = comm__pb2.Data.SerializeToString
        deserializer = comm__pb2.Empty.FromString
        # Note the positional order: ``insecure`` precedes ``call_credentials``.
        return grpc.experimental.unary_unary(
            request,
            target,
            method_path,
            serializer,
            deserializer,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from collections import OrderedDict | ||
import io | ||
import torch | ||
|
||
def serialize_model(state_dict: OrderedDict) -> bytes:
    """Return ``state_dict`` written with ``torch.save`` as raw bytes."""
    with io.BytesIO() as stream:
        torch.save(state_dict, stream)
        # getvalue() returns the whole buffer regardless of the cursor position.
        return stream.getvalue()
|
||
def deserialize_model(model_bytes: bytes) -> OrderedDict:
    """Rebuild a model ``state_dict`` from bytes written by ``torch.save``.

    Args:
        model_bytes: Serialized payload, e.g. the output of ``serialize_model``.

    Returns:
        The deserialized ``state_dict``.

    Raises:
        Exception: Whatever ``torch.load`` raises for a corrupt payload, or
            for one containing objects the weights-only unpickler refuses.
    """
    # SECURITY: torch.load is pickle-based, and in this project the bytes come
    # from remote peers. weights_only=True limits unpickling to tensors and
    # plain containers so a crafted payload cannot run arbitrary code.
    # A freshly built BytesIO already starts at offset 0, so no seek is needed.
    return torch.load(io.BytesIO(model_bytes), weights_only=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from concurrent import futures | ||
import queue | ||
import grpc | ||
from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model | ||
import utils.communication.grpc.comm_pb2 as comm_pb2 | ||
import utils.communication.grpc.comm_pb2_grpc as comm_pb2_grpc | ||
from utils.communication.comm_utils import CommunicationInterface | ||
|
||
class Servicer(comm_pb2_grpc.CommunicationServerServicer):
    """Servicer that queues every model delivered through SendData."""

    def __init__(self):
        # Deserialized models wait here until a consumer takes them.
        self.received_data = queue.Queue()

    def SendData(self, request, context):
        """Deserialize the model carried by ``request``, enqueue it, reply Empty."""
        payload = request.model.buffer
        state = deserialize_model(payload)
        self.received_data.put(state)
        return comm_pb2.Empty()
|
||
class GRPCCommunication(CommunicationInterface):
    """gRPC implementation of ``CommunicationInterface``.

    Runs a local gRPC server whose ``Servicer`` queues incoming models, and
    opens a short-lived insecure channel for every outgoing ``send``.
    """

    def __init__(self, port):
        """Store the listening port; the server is started by ``initialize``.

        Args:
            port: Port the local gRPC server binds to.
        """
        self.port = port
        # Kept for backward compatibility; the running server is ``listener``.
        self.server = None
        # Defined up front so finalize() is safe before initialize() has run.
        self.listener = None
        self.servicer = Servicer()
        # TODO: Implement this later by creating a super node
        # that maintains a list of all peers
        # all peers will send their IDs to this super node
        # when they start.
        self.all_peer_ids = []

    def initialize(self):
        """Create the gRPC server, register the servicer and start listening."""
        self.listener = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        comm_pb2_grpc.add_CommunicationServerServicer_to_server(self.servicer, self.listener)
        self.listener.add_insecure_port(f'[::]:{self.port}')
        self.listener.start()
        print(f'Started server on port {self.port}')

    def send(self, dest, data):
        """Send a torch model's ``state_dict`` to ``dest``.

        Args:
            dest: Target passed to ``grpc.insecure_channel``.
            data: A torch model; its ``state_dict()`` is serialized and sent.

        RPC failures are printed rather than raised (best effort).
        """
        buffer = serialize_model(data.state_dict())
        # SendData takes a Data message and the receiving servicer reads
        # request.model.buffer, so the Model must be wrapped in Data.
        request = comm_pb2.Data(model=comm_pb2.Model(buffer=buffer))
        try:
            with grpc.insecure_channel(dest) as channel:
                stub = comm_pb2_grpc.CommunicationServerStub(channel)
                stub.SendData(request)
        except grpc.RpcError as e:
            print(f"RPC failed: {e}")

    def receive(self, node_ids, data):
        """Return the next received model; ``node_ids`` and ``data`` are unused."""
        # this .get() will block until
        # at least 1 item is received in the queue
        return self.servicer.received_data.get()

    def broadcast(self, data):
        """Send ``data`` to every peer listed in ``all_peer_ids``."""
        for peer_id in self.all_peer_ids:
            self.send(peer_id, data)

    def finalize(self):
        """Stop the server immediately if it was started; otherwise do nothing."""
        if self.listener is not None:
            self.listener.stop(0)
            print(f'Stopped server on port {self.port}')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from mpi4py import MPI | ||
from utils.communication.comm_utils import CommunicationInterface | ||
|
||
class MPICommUtils(CommunicationInterface):
    """MPI implementation of ``CommunicationInterface`` on ``MPI.COMM_WORLD``."""

    def __init__(self):
        world = MPI.COMM_WORLD
        self.comm = world
        self.rank = world.Get_rank()
        self.size = world.Get_size()

    def initialize(self):
        """Nothing to set up for this backend."""

    def send(self, dest, data):
        """Send ``data`` to rank ``dest``."""
        self.comm.send(data, dest=dest)

    def receive(self, node_ids, data):
        """Receive from source ``node_ids`` and return it; ``data`` is unused."""
        received = self.comm.recv(source=node_ids)
        return received

    def broadcast(self, data):
        """Send ``data`` to every rank from 1 to size-1 except this one."""
        # NOTE(review): rank 0 is never a target; presumably it is reserved
        # for a coordinating node. Confirm against the callers.
        targets = (peer for peer in range(1, self.size) if peer != self.rank)
        for peer in targets:
            self.send(peer, data)

    def finalize(self):
        """Nothing to tear down for this backend."""