Skip to content

Commit

Permalink
make run_id mandatory
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jun 19, 2024
1 parent 0a09223 commit 5461e5d
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 189 deletions.
124 changes: 50 additions & 74 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@

import time
import warnings
from logging import DEBUG, ERROR, WARNING
from logging import DEBUG, ERROR
from typing import Iterable, List, Optional, Tuple

import grpc

from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
from flwr.common.grpc import create_channel
from flwr.common.logger import log
Expand Down Expand Up @@ -48,103 +46,94 @@
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
[Driver] Error: Not connected.
Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other
`GrpcDriverHelper` methods.
Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
`GrpcDriverStub` methods.
"""


class GrpcDriverHelper:
"""`GrpcDriverHelper` provides access to the gRPC Driver API/service."""
class GrpcDriverStub(DriverStub):
"""`GrpcDriverStub` provides access to the gRPC Driver API/service."""

def __init__(
self,
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
root_certificates: Optional[bytes] = None,
) -> None:
event(EventType.DRIVER_CONNECT)
self.driver_service_address = driver_service_address
self.root_certificates = root_certificates
self.channel: Optional[grpc.Channel] = None
self.stub: Optional[DriverStub] = None

def connect(self) -> None:
"""Connect to the Driver API."""
event(EventType.DRIVER_CONNECT)
if self.channel is not None or self.stub is not None:
log(WARNING, "Already connected")
return
self.channel = create_channel(
server_address=self.driver_service_address,
insecure=(self.root_certificates is None),
root_certificates=self.root_certificates,
)
self.stub = DriverStub(self.channel)
super().__init__(self.channel)
log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)

def disconnect(self) -> None:
"""Disconnect from the Driver API."""
event(EventType.DRIVER_DISCONNECT)
if self.channel is None or self.stub is None:
if self.channel is None:
log(DEBUG, "Already disconnected")
return
channel = self.channel
self.channel = None
self.stub = None
channel.close()
log(DEBUG, "[Driver] Disconnected")

def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
"""Request for run ID."""
# Check if channel is open
if self.stub is None:
if self.channel is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverHelper` instance not connected")
raise ConnectionError("`GrpcDriverStub` instance not connected")

# Call Driver API
res: CreateRunResponse = self.stub.CreateRun(request=req)
res: CreateRunResponse = self.CreateRun(request=req)
return res

def get_run(self, req: GetRunRequest) -> GetRunResponse:
"""Get run information."""
# Check if channel is open
if self.stub is None:
if self.channel is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverHelper` instance not connected")
raise ConnectionError("`GrpcDriverStub` instance not connected")

# Call gRPC Driver API
res: GetRunResponse = self.stub.GetRun(request=req)
res: GetRunResponse = self.GetRun(request=req)
return res

def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
"""Get client IDs."""
# Check if channel is open
if self.stub is None:
if self.channel is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverHelper` instance not connected")
raise ConnectionError("`GrpcDriverStub` instance not connected")

# Call gRPC Driver API
res: GetNodesResponse = self.stub.GetNodes(request=req)
res: GetNodesResponse = self.GetNodes(request=req)
return res

def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
"""Schedule tasks."""
# Check if channel is open
if self.stub is None:
if self.channel is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverHelper` instance not connected")
raise ConnectionError("`GrpcDriverStub` instance not connected")

# Call gRPC Driver API
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
res: PushTaskInsResponse = self.PushTaskIns(request=req)
return res

def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
"""Get task results."""
# Check if channel is open
if self.stub is None:
if self.channel is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverHelper` instance not connected")
raise ConnectionError("`GrpcDriverStub` instance not connected")

# Call Driver API
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
res: PullTaskResResponse = self.PullTaskRes(request=req)
return res


Expand Down Expand Up @@ -172,18 +161,14 @@ class GrpcDriver(Driver):

def __init__( # pylint: disable=too-many-arguments
self,
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
root_certificates: Optional[bytes] = None,
fab_id: Optional[str] = None,
fab_version: Optional[str] = None,
run_id: Optional[int] = None,
run_id: int,
stub: Optional[GrpcDriverStub] = None,
) -> None:
self.addr = driver_service_address
self.root_certificates = root_certificates
self.driver_helper: Optional[GrpcDriverHelper] = None
self.stub = stub
self._run_id = run_id
self._fab_id = fab_id if fab_id is not None else ""
self._fab_ver = fab_version if fab_version is not None else ""
self._fab_id = ""
self._fab_ver = ""
self._has_initialized = False
self.node = Node(node_id=0, anonymous=True)

@property
Expand All @@ -196,32 +181,23 @@ def run(self) -> Run:
fab_version=self._fab_ver,
)

def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
# Check if the GrpcDriverHelper is initialized
if self.driver_helper is None or self._run_id is None:
def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
# Check if the GrpcDriverStub is initialized
if not self._has_initialized or self.stub is None:
# Connect and create run
self.driver_helper = GrpcDriverHelper(
driver_service_address=self.addr,
root_certificates=self.root_certificates,
)
self.driver_helper.connect()
# Create the run if the run_id is not provided
if self._run_id is None:
create_run_req = CreateRunRequest(
fab_id=self._fab_id, fab_version=self._fab_ver
)
create_run_res = self.driver_helper.create_run(create_run_req)
self._run_id = create_run_res.run_id
# Get the run if the run_id is provided
else:
get_run_req = GetRunRequest(run_id=self._run_id)
get_run_res = self.driver_helper.get_run(get_run_req)
if not get_run_res.HasField("run"):
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
self._fab_id = get_run_res.run.fab_id
self._fab_ver = get_run_res.run.fab_version

return self.driver_helper, self._run_id
if self.stub is None:
self.stub = GrpcDriverStub()

# Get the run info
req = GetRunRequest(run_id=self._run_id)
res = self.stub.get_run(req)
if not res.HasField("run"):
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
self._fab_id = res.run.fab_id
self._fab_ver = res.run.fab_version
self._has_initialized = True

return self.stub, self._run_id

def _check_message(self, message: Message) -> None:
# Check if the message is valid
Expand Down Expand Up @@ -272,7 +248,7 @@ def create_message( # pylint: disable=too-many-arguments
def get_node_ids(self) -> List[int]:
"""Get node IDs."""
grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id()
# Call GrpcDriverHelper method
# Call GrpcDriverStub method
res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id))
return [node.node_id for node in res.nodes]

Expand All @@ -292,7 +268,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
taskins = message_to_taskins(msg)
# Add to list
task_ins_list.append(taskins)
# Call GrpcDriverHelper method
# Call GrpcDriverStub method
res = grpc_driver_helper.push_task_ins(
PushTaskInsRequest(task_ins_list=task_ins_list)
)
Expand Down Expand Up @@ -345,8 +321,8 @@ def send_and_receive(

def close(self) -> None:
"""Disconnect from the SuperLink if connected."""
# Check if GrpcDriverHelper is initialized
if self.driver_helper is None:
# Check if GrpcDriverStub is initialized
if self.stub is None:
return
# Disconnect
self.driver_helper.disconnect()
self.stub.disconnect()
Loading

0 comments on commit 5461e5d

Please sign in to comment.