Skip to content

Commit

Permalink
implement the new communication library
Browse files Browse the repository at this point in the history
  • Loading branch information
tremblerz committed Aug 11, 2024
1 parent d0cd873 commit 1b96ffe
Show file tree
Hide file tree
Showing 9 changed files with 328 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"python.analysis.extraPaths": [
"./src"
],
"python.analysis.autoImportCompletions": true,
"python.analysis.typeCheckingMode": "basic"
}
1 change: 1 addition & 0 deletions grpc_expts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def run_client(args: argparse.Namespace):
log_utils = LogUtils(config)
log_utils.log_console(f'user got {user_id.id} {user_id.num}')
node_id = user_id.num % TEMP_TOTAL_NODES
print(node_id, user_id.num)
device = f'cuda:{node_id + device_offset}'
dset_obj = get_dataset(dset, dpath=dpath)
train_dset = dset_obj.train_dset
Expand Down
25 changes: 25 additions & 0 deletions src/utils/communication/comm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod

class CommunicationInterface(ABC):
    """Abstract contract that every communication backend must implement.

    Subclasses must override all five operations below before they can be
    instantiated.
    """

    def __init__(self):
        """Base constructor; holds no state of its own."""

    @abstractmethod
    def initialize(self):
        """Prepare the backend for use."""

    @abstractmethod
    def send(self, dest, data):
        """Send ``data`` to the destination ``dest``."""

    @abstractmethod
    def receive(self, node_ids, data):
        """Receive data.

        NOTE(review): the role of ``node_ids`` and ``data`` is not fixed by
        this interface -- presumably the sender(s) to listen to and an
        optional payload/buffer; confirm against the implementations.
        """

    @abstractmethod
    def broadcast(self, data):
        """Send ``data`` to multiple peers."""

    @abstractmethod
    def finalize(self):
        """Release whatever ``initialize`` set up."""
17 changes: 17 additions & 0 deletions src/utils/communication/grpc/comm.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
syntax = "proto3";

// Peer-to-peer service used by the gRPC communication backend.
service CommunicationServer {
  // Request/response types match the generated Python stub (ID -> Data).
  rpc GetModel (ID) returns (Data) {}
  // Pushes a Data message (carrying a serialized model) to the receiver.
  rpc SendData (Data) returns (Empty) {}
}

message Empty {}

// Serialized model weights. The Python side reads and writes `buffer` as
// raw bytes. NOTE(review): field number assumed to be 1 -- confirm against
// the checked-in comm_pb2.py before regenerating.
message Model {
  bytes buffer = 1;
}

message Data {
  string id = 1;
  Model model = 2;
}

message ID {
  int32 num = 1;
  string id = 2;
}
34 changes: 34 additions & 0 deletions src/utils/communication/grpc/comm_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

145 changes: 145 additions & 0 deletions src/utils/communication/grpc/comm_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

# The gRPC plugin emits a bare ``import comm_pb2``, which only resolves when
# this directory itself is on sys.path. Prefer the package-relative import so
# the module also works when imported through its package, and fall back to
# the bare import for script-style use.
try:
    from . import comm_pb2 as comm__pb2
except ImportError:
    import comm_pb2 as comm__pb2

# Version of grpcio-tools that generated this file, and the installed runtime.
GRPC_GENERATED_VERSION = '1.64.0'
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = '1.65.0'
SCHEDULED_RELEASE_DATE = 'June 25, 2024'
_version_not_supported = False

# Flag runtimes older than the generator; a runtime too old to even provide
# the comparison helper is treated as unsupported.
try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    warnings.warn(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in comm_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
        RuntimeWarning
    )


class CommunicationServerStub(object):
    """Client-side stub for the CommunicationServer service.

    Generated code: the .proto file carries no documentation comment.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # GetModel: unary-unary call that sends an ID and receives a Data.
        self.GetModel = channel.unary_unary(
                '/CommunicationServer/GetModel',
                request_serializer=comm__pb2.ID.SerializeToString,
                response_deserializer=comm__pb2.Data.FromString,
                _registered_method=True)
        # SendData: unary-unary call that sends a Data and receives an Empty.
        self.SendData = channel.unary_unary(
                '/CommunicationServer/SendData',
                request_serializer=comm__pb2.Data.SerializeToString,
                response_deserializer=comm__pb2.Empty.FromString,
                _registered_method=True)


class CommunicationServerServicer(object):
    """Server-side base class for the CommunicationServer service.

    Generated code: both handlers report UNIMPLEMENTED until a subclass
    overrides them.
    """

    def GetModel(self, request, context):
        """Default handler: marks the call UNIMPLEMENTED and raises."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendData(self, request, context):
        """Default handler: marks the call UNIMPLEMENTED and raises."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_CommunicationServerServicer_to_server(servicer, server):
    """Register ``servicer``'s GetModel/SendData handlers on ``server``.

    Args:
        servicer: Object providing ``GetModel`` and ``SendData`` methods.
        server: The gRPC server the handlers are attached to.
    """
    # Map each RPC name to a unary-unary handler with its (de)serializers.
    rpc_method_handlers = {
            'GetModel': grpc.unary_unary_rpc_method_handler(
                    servicer.GetModel,
                    request_deserializer=comm__pb2.ID.FromString,
                    response_serializer=comm__pb2.Data.SerializeToString,
            ),
            'SendData': grpc.unary_unary_rpc_method_handler(
                    servicer.SendData,
                    request_deserializer=comm__pb2.Data.FromString,
                    response_serializer=comm__pb2.Empty.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'CommunicationServer', rpc_method_handlers)
    # The same handler table is registered through both server entry points.
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('CommunicationServer', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class CommunicationServer(object):
    """Static one-shot invocation helpers for the CommunicationServer RPCs.

    Generated code: each helper forwards its arguments positionally to
    ``grpc.experimental.unary_unary``, so the argument order below matters.
    """

    @staticmethod
    def GetModel(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke GetModel on ``target`` (ID request, Data response)."""
        # Note: ``insecure`` is passed before ``call_credentials`` here, unlike
        # the order in this method's own signature.
        return grpc.experimental.unary_unary(
            request,
            target,
            '/CommunicationServer/GetModel',
            comm__pb2.ID.SerializeToString,
            comm__pb2.Data.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendData(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        """Invoke SendData on ``target`` (Data request, Empty response)."""
        # Note: ``insecure`` is passed before ``call_credentials`` here, unlike
        # the order in this method's own signature.
        return grpc.experimental.unary_unary(
            request,
            target,
            '/CommunicationServer/SendData',
            comm__pb2.Data.SerializeToString,
            comm__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
15 changes: 15 additions & 0 deletions src/utils/communication/grpc/grpc_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from collections import OrderedDict
import io
import torch

def serialize_model(state_dict: OrderedDict) -> bytes:
    """Serialize a model state dict into raw bytes using ``torch.save``.

    Args:
        state_dict: The state dict to serialize.

    Returns:
        The complete ``torch.save`` output as a ``bytes`` object.
    """
    with io.BytesIO() as stream:
        torch.save(state_dict, stream)
        # getvalue() returns the whole buffer regardless of stream position.
        return stream.getvalue()

def deserialize_model(model_bytes: bytes, map_location=None) -> OrderedDict:
    """Rebuild a model state dict from bytes produced by ``torch.save``.

    Args:
        model_bytes: Serialized state dict.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            the previous behaviour (tensors are restored to the device they
            were saved from); pass e.g. ``"cpu"`` when the sender's device
            does not exist on the receiving node.

    Returns:
        The deserialized state dict.
    """
    # SECURITY: torch.load is pickle-based. When the bytes come from another
    # peer over the network they are untrusted input; consider restricting
    # loading (e.g. ``weights_only=True`` where the installed torch supports
    # it) or authenticating peers before accepting payloads.
    return torch.load(io.BytesIO(model_bytes), map_location=map_location)
59 changes: 59 additions & 0 deletions src/utils/communication/grpc/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from concurrent import futures
import queue
import grpc
from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model
import utils.communication.grpc.comm_pb2 as comm_pb2
import utils.communication.grpc.comm_pb2_grpc as comm_pb2_grpc
from utils.communication.comm_utils import CommunicationInterface

class Servicer(comm_pb2_grpc.CommunicationServerServicer):
    """Servicer that queues every model delivered through ``SendData``."""

    def __init__(self):
        # FIFO queue of deserialized payloads waiting to be consumed.
        self.received_data = queue.Queue()

    def SendData(self, request, context):
        """Deserialize the request's model buffer, enqueue it, and reply Empty."""
        payload = deserialize_model(request.model.buffer)
        self.received_data.put(payload)
        return comm_pb2.Empty()

class GRPCCommunication(CommunicationInterface):
    """gRPC-backed implementation of ``CommunicationInterface``.

    ``initialize`` starts a local gRPC server whose servicer queues every
    model pushed to it; ``send`` opens a short-lived insecure channel per
    call.

    NOTE(review): gRPC limits received messages to 4 MB by default; larger
    models will likely need ``grpc.max_send_message_length`` /
    ``grpc.max_receive_message_length`` options on channel and server.
    """

    def __init__(self, port):
        """Store the listening port and create the servicer.

        Args:
            port: Port the local gRPC server binds to in ``initialize``.
        """
        super().__init__()
        self.port = port
        # Kept for backward compatibility; the running server is tracked in
        # ``self.listener``.
        self.server = None
        # Defined up front so ``finalize`` is safe before ``initialize``.
        self.listener = None
        self.servicer = Servicer()
        # TODO: Implement this later by creating a super node
        # that maintains a list of all peers
        # all peers will send their IDs to this super node
        # when they start.
        self.all_peer_ids = []

    def initialize(self):
        """Start the local gRPC server on ``self.port`` (all interfaces, insecure)."""
        self.listener = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        comm_pb2_grpc.add_CommunicationServerServicer_to_server(self.servicer, self.listener)
        self.listener.add_insecure_port(f'[::]:{self.port}')
        self.listener.start()
        print(f'Started server on port {self.port}')

    def send(self, dest, data):
        """Serialize ``data``'s state dict and push it to the peer at ``dest``.

        Args:
            dest: Target string handed as-is to ``grpc.insecure_channel``.
            data: A torch model (anything exposing ``state_dict()``).

        RPC failures are printed rather than raised (best-effort delivery).
        """
        buffer = serialize_model(data.state_dict())
        # SendData is declared over ``Data`` messages and the receiving
        # servicer reads ``request.model.buffer``, so the model must be
        # wrapped in a Data envelope rather than sent as a bare Model.
        request = comm_pb2.Data(model=comm_pb2.Model(buffer=buffer))
        try:
            with grpc.insecure_channel(dest) as channel:
                stub = comm_pb2_grpc.CommunicationServerStub(channel)
                stub.SendData(request)
        except grpc.RpcError as e:
            print(f"RPC failed: {e}")

    def receive(self, node_ids, data):
        """Return the next received model, blocking until one is available.

        ``node_ids`` and ``data`` are accepted for interface compatibility
        but are not used: items are returned in arrival order from any peer.
        """
        # this .get() will block until
        # at least 1 item is received in the queue
        return self.servicer.received_data.get()

    def broadcast(self, data):
        """Send ``data`` to every peer listed in ``self.all_peer_ids``."""
        for peer_id in self.all_peer_ids:
            self.send(peer_id, data)

    def finalize(self):
        """Stop the local server if it was started; otherwise do nothing."""
        if self.listener is not None:
            self.listener.stop(0)
            # Drop the handle so a repeated finalize() is a no-op.
            self.listener = None
            print(f'Stopped server on port {self.port}')
25 changes: 25 additions & 0 deletions src/utils/communication/mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from mpi4py import MPI
from utils.communication.comm_utils import CommunicationInterface

class MPICommUtils(CommunicationInterface):
    """MPI-backed implementation of ``CommunicationInterface`` over COMM_WORLD."""

    def __init__(self):
        world = MPI.COMM_WORLD
        self.comm = world
        self.rank = world.Get_rank()
        self.size = world.Get_size()

    def initialize(self):
        """No setup is performed by this backend."""

    def send(self, dest, data):
        """Send ``data`` to rank ``dest`` using the communicator's ``send``."""
        self.comm.send(data, dest=dest)

    def receive(self, node_ids, data):
        """Receive and return one message from source ``node_ids``.

        ``node_ids`` is handed directly to ``recv`` as its ``source``;
        ``data`` is not used.
        """
        message = self.comm.recv(source=node_ids)
        return message

    def broadcast(self, data):
        """Send ``data`` to every rank from 1 upward except this one.

        NOTE(review): rank 0 is never a target because the range starts at 1
        -- presumably it is reserved for a coordinator; confirm.
        """
        for peer in range(1, self.size):
            if peer == self.rank:
                continue
            self.send(peer, data)

    def finalize(self):
        """No teardown is performed by this backend."""

0 comments on commit 1b96ffe

Please sign in to comment.