Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Implement DriverAPI GetRun #3580

Merged
merged 39 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
82a399a
add run proto
panh99 Jun 11, 2024
019d490
implement getrun in driver_servicer and grpc_driver
panh99 Jun 11, 2024
54b185a
Merge branch 'main' into impl-get-run-driver
panh99 Jun 12, 2024
e6cbede
Merge branch 'main' into impl-get-run-driver
panh99 Jun 13, 2024
bf3cd3b
Merge branch 'main' into impl-get-run-driver
panh99 Jun 14, 2024
5bc1c9c
Merge branch 'main' into impl-get-run-driver
danieljanes Jun 18, 2024
7c82082
amend driver class and in mem driver
panh99 Jun 18, 2024
f9f8a10
Merge branch 'main' into impl-get-run-driver
danieljanes Jun 18, 2024
1150474
Merge branch 'main' into impl-get-run-driver
danieljanes Jun 18, 2024
6ef8ceb
update with main
panh99 Jun 19, 2024
bf10f81
fix the test for in mem driver
panh99 Jun 19, 2024
b57eeb7
Merge branch 'main' into impl-get-run-driver
panh99 Jun 19, 2024
0a09223
Merge branch 'main' into impl-get-run-driver
panh99 Jun 19, 2024
5461e5d
make run_id mandatory
panh99 Jun 19, 2024
77f8927
Merge branch 'main' into impl-get-run-driver
panh99 Jun 19, 2024
dcc6cfd
fix a bug in _init_run_id in simulation
panh99 Jun 19, 2024
dc03386
format
panh99 Jun 19, 2024
4d73b9c
Merge branch 'main' into impl-get-run-driver
danieljanes Jun 19, 2024
5aa7ced
update doc string
panh99 Jun 19, 2024
30aa173
Merge remote-tracking branch 'refs/remotes/origin/impl-get-run-driver…
panh99 Jun 19, 2024
084e22c
update doc string
panh99 Jun 19, 2024
db16408
update GrpcDriverStub
panh99 Jun 19, 2024
6c9657e
use _run & _run_id
panh99 Jun 20, 2024
fe31189
Merge branch 'main' into impl-get-run-driver
panh99 Jun 20, 2024
5110873
update sim
panh99 Jun 20, 2024
d60d295
fix in run_simulation() (#3654)
jafermarq Jun 20, 2024
e65af5f
fix doc string
panh99 Jun 20, 2024
f914b2c
Update src/py/flwr/server/driver/grpc_driver.py
danieljanes Jun 20, 2024
c62f496
fix naming conflicts
panh99 Jun 20, 2024
729d418
Merge remote-tracking branch 'refs/remotes/origin/impl-get-run-driver…
panh99 Jun 20, 2024
b8515e3
fix a bug that driver stub not connected
panh99 Jun 20, 2024
fe5a8db
quick fix
panh99 Jun 20, 2024
aeebcfb
Merge branch 'main' into impl-get-run-driver
panh99 Jun 20, 2024
25e6354
update naming
panh99 Jun 20, 2024
62c3c81
Merge remote-tracking branch 'refs/remotes/origin/impl-get-run-driver…
panh99 Jun 20, 2024
05945f0
fix unit tests
panh99 Jun 20, 2024
1efc53d
fix get_run
panh99 Jun 20, 2024
2cdaf83
update in mem driver
panh99 Jun 20, 2024
bba27ea
Merge branch 'main' into impl-get-run-driver
panh99 Jun 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/py/flwr/server/compat/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _update_client_manager(
node_id=node_id,
driver=driver,
anonymous=False,
run_id=driver.run_id, # type: ignore
run_id=driver.run.run_id,
)
if client_manager.register(client_proxy):
registered_nodes[node_id] = client_proxy
Expand Down
6 changes: 6 additions & 0 deletions src/py/flwr/server/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@
from typing import Iterable, List, Optional

from flwr.common import Message, RecordSet
from flwr.common.typing import Run


class Driver(ABC):
"""Abstract base Driver class for the Driver API."""

@property
@abstractmethod
def run(self) -> Run:
"""Run information."""

@abstractmethod
def create_message( # pylint: disable=too-many-arguments
self,
Expand Down
57 changes: 47 additions & 10 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.serde import message_from_taskres, message_to_taskins
from flwr.common.typing import Run
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
CreateRunRequest,
CreateRunResponse,
Expand All @@ -37,6 +38,7 @@
)
from flwr.proto.driver_pb2_grpc import DriverStub # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611

from .driver import Driver
Expand Down Expand Up @@ -101,6 +103,17 @@ def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
res: CreateRunResponse = self.stub.CreateRun(request=req)
return res

def get_run(self, req: GetRunRequest) -> GetRunResponse:
"""Get run information."""
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise ConnectionError("`GrpcDriverHelper` instance not connected")

# Call gRPC Driver API
res: GetRunResponse = self.stub.GetRun(request=req)
return res

def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
"""Get client IDs."""
# Check if channel is open
Expand Down Expand Up @@ -157,39 +170,63 @@ class GrpcDriver(Driver):
The version of the FAB used in the run.
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
root_certificates: Optional[bytes] = None,
fab_id: Optional[str] = None,
fab_version: Optional[str] = None,
run_id: Optional[int] = None,
) -> None:
self.addr = driver_service_address
self.root_certificates = root_certificates
self.driver_helper: Optional[GrpcDriverHelper] = None
self.run_id: Optional[int] = None
self.fab_id = fab_id if fab_id is not None else ""
self.fab_version = fab_version if fab_version is not None else ""
self._run_id = run_id
self._fab_id = fab_id if fab_id is not None else ""
self._fab_ver = fab_version if fab_version is not None else ""
self.node = Node(node_id=0, anonymous=True)

@property
def run(self) -> Run:
"""Run information."""
_, run_id = self._get_grpc_driver_helper_and_run_id()
return Run(
run_id=run_id,
fab_id=self._fab_id,
fab_version=self._fab_ver,
)

def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]:
# Check if the GrpcDriverHelper is initialized
if self.driver_helper is None or self.run_id is None:
if self.driver_helper is None or self._run_id is None:
# Connect and create run
self.driver_helper = GrpcDriverHelper(
driver_service_address=self.addr,
root_certificates=self.root_certificates,
)
self.driver_helper.connect()
req = CreateRunRequest(fab_id=self.fab_id, fab_version=self.fab_version)
res = self.driver_helper.create_run(req)
self.run_id = res.run_id
return self.driver_helper, self.run_id
# Create the run if the run_id is not provided
if self._run_id is None:
create_run_req = CreateRunRequest(
fab_id=self._fab_id, fab_version=self._fab_ver
)
create_run_res = self.driver_helper.create_run(create_run_req)
self._run_id = create_run_res.run_id
# Get the run if the run_id is provided
else:
get_run_req = GetRunRequest(run_id=self._run_id)
get_run_res = self.driver_helper.get_run(get_run_req)
if not get_run_res.HasField("run"):
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
self._fab_id = get_run_res.run.fab_id
self._fab_ver = get_run_res.run.fab_version

return self.driver_helper, self._run_id

def _check_message(self, message: Message) -> None:
# Check if the message is valid
if not (
message.metadata.run_id == self.run_id
message.metadata.run_id == self._run_id
and message.metadata.src_node_id == self.node.node_id
and message.metadata.message_id == ""
and message.metadata.reply_to_message == ""
Expand Down
20 changes: 18 additions & 2 deletions src/py/flwr/server/driver/grpc_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,27 @@ def tearDown(self) -> None:
"""Cleanup after each test."""
self.patcher.stop()

def test_get_run(self) -> None:
"""Test the GrpcDriver starting with run_id."""
# Prepare
self.driver._run_id = 61016 # pylint: disable=protected-access
mock_response = Mock()
mock_response.run = Mock()
mock_response.run.run_id = 61016
mock_response.run.fab_id = "mock/mock"
mock_response.run.fab_version = "v1.0.0"
self.mock_grpc_driver_helper.get_run.return_value = mock_response

# Assert
self.assertEqual(self.driver.run.run_id, 61016)
self.assertEqual(self.driver.run.fab_id, "mock/mock")
self.assertEqual(self.driver.run.fab_version, "v1.0.0")

def test_check_and_init_grpc_driver_already_initialized(self) -> None:
"""Test that GrpcDriverHelper doesn't initialize if run is created."""
# Prepare
self.driver.driver_helper = self.mock_grpc_driver_helper
self.driver.run_id = 61016
self.driver._run_id = 61016 # pylint: disable=protected-access

# Execute
# pylint: disable-next=protected-access
Expand All @@ -73,7 +89,7 @@ def test_check_and_init_grpc_driver_needs_initialization(self) -> None:

# Assert
self.mock_grpc_driver_helper.connect.assert_called_once()
self.assertEqual(self.driver.run_id, 61016)
self.assertEqual(self.driver.run.run_id, 61016)

def test_get_nodes(self) -> None:
"""Test retrieval of nodes."""
Expand Down
52 changes: 34 additions & 18 deletions src/py/flwr/server/driver/inmemory_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

import time
import warnings
from typing import Iterable, List, Optional
from typing import Iterable, List, Optional, cast
from uuid import UUID

from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
from flwr.common.serde import message_from_taskres, message_to_taskins
from flwr.common.typing import Run
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.server.superlink.state import StateFactory

Expand All @@ -46,34 +47,51 @@ def __init__(
state_factory: StateFactory,
fab_id: Optional[str] = None,
fab_version: Optional[str] = None,
run_id: Optional[int] = None,
) -> None:
self.run_id: Optional[int] = None
self.fab_id = fab_id if fab_id is not None else ""
self.fab_version = fab_version if fab_version is not None else ""
self._run_id = run_id
self._fab_id = fab_id
self._fab_ver = fab_version
self.node = Node(node_id=0, anonymous=True)
self.state = state_factory.state()

def _check_message(self, message: Message) -> None:
# Check if the message is valid
if not (
message.metadata.run_id == self.run_id
message.metadata.run_id == self.run.run_id
and message.metadata.src_node_id == self.node.node_id
and message.metadata.message_id == ""
and message.metadata.reply_to_message == ""
and message.metadata.ttl > 0
):
raise ValueError(f"Invalid message: {message}")

def _get_run_id(self) -> int:
"""Return run_id.

If unset, create a new run.
"""
if self.run_id is None:
self.run_id = self.state.create_run(
fab_id=self.fab_id, fab_version=self.fab_version
def _init_run(self) -> None:
"""Initialize the run."""
# Run ID is not provided
if self._run_id is None:
self._fab_id = "" if self._fab_id is None else self._fab_id
self._fab_ver = "" if self._fab_ver is None else self._fab_ver
self._run_id = self.state.create_run(
fab_id=self._fab_id, fab_version=self._fab_ver
)
return self.run_id
# Run ID is provided
elif self._fab_id is None or self._fab_ver is None:
run = self.state.get_run(self._run_id)
if run is None:
raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
self._fab_id = run.fab_id
self._fab_ver = run.fab_version

@property
def run(self) -> Run:
"""Run ID."""
self._init_run()
return Run(
run_id=cast(int, self._run_id),
fab_id=cast(str, self._fab_id),
fab_version=cast(str, self._fab_ver),
)

def create_message( # pylint: disable=too-many-arguments
self,
Expand All @@ -88,7 +106,6 @@ def create_message( # pylint: disable=too-many-arguments
This method constructs a new `Message` with given content and metadata.
The `run_id` and `src_node_id` will be set automatically.
"""
run_id = self._get_run_id()
if ttl:
warnings.warn(
"A custom TTL was set, but note that the SuperLink does not enforce "
Expand All @@ -99,7 +116,7 @@ def create_message( # pylint: disable=too-many-arguments
ttl_ = DEFAULT_TTL if ttl is None else ttl

metadata = Metadata(
run_id=run_id,
run_id=self.run.run_id,
message_id="", # Will be set by the server
src_node_id=self.node.node_id,
dst_node_id=dst_node_id,
Expand All @@ -112,8 +129,7 @@ def create_message( # pylint: disable=too-many-arguments

def get_node_ids(self) -> List[int]:
"""Get node IDs."""
run_id = self._get_run_id()
return list(self.state.get_nodes(run_id))
return list(self.state.get_nodes(self.run.run_id))

def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
"""Push messages to specified node IDs.
Expand Down
34 changes: 25 additions & 9 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,28 @@ def setUp(self) -> None:
"""
# Create driver
self.num_nodes = 42
self.driver = InMemoryDriver(StateFactory(""))
self.driver.state = MagicMock()
self.driver.state.get_nodes.return_value = [
self.state = MagicMock()
self.state.get_nodes.return_value = [
int.from_bytes(os.urandom(8), "little", signed=True)
for _ in range(self.num_nodes)
]
state_factory = MagicMock()
state_factory.state.return_value = self.state
self.driver = InMemoryDriver(state_factory)
self.driver.state = self.state

def test_get_run(self) -> None:
"""Test the InMemoryDriver starting with run_id."""
# Prepare
self.driver._run_id = 61016 # pylint: disable=protected-access
self.state.get_run.return_value = MagicMock(
run_id=61016, fab_id="mock/mock", fab_version="v1.0.0"
)

# Assert
self.assertEqual(self.driver.run.run_id, 61016)
self.assertEqual(self.driver.run.fab_id, "mock/mock")
self.assertEqual(self.driver.run.fab_version, "v1.0.0")

def test_get_nodes(self) -> None:
"""Test retrieval of nodes."""
Expand All @@ -104,7 +120,7 @@ def test_push_messages_valid(self) -> None:
]

taskins_ids = [uuid4() for _ in range(num_messages)]
self.driver.state.store_task_ins.side_effect = taskins_ids # type: ignore
self.state.store_task_ins.side_effect = taskins_ids

# Execute
msg_ids = list(self.driver.push_messages(msgs))
Expand Down Expand Up @@ -141,7 +157,7 @@ def test_pull_messages_with_given_message_ids(self) -> None:
task=Task(ancestry=[msg_ids[1]], error=error_to_proto(Error(code=0)))
),
]
self.driver.state.get_task_res.return_value = task_res_list # type: ignore
self.state.get_task_res.return_value = task_res_list

# Execute
pulled_msgs = list(self.driver.pull_messages(msg_ids))
Expand All @@ -167,8 +183,8 @@ def test_send_and_receive_messages_complete(self) -> None:
task=Task(ancestry=[msg_ids[1]], error=error_to_proto(Error(code=0)))
),
]
self.driver.state.store_task_ins.side_effect = msg_ids # type: ignore
self.driver.state.get_task_res.return_value = task_res_list # type: ignore
self.state.store_task_ins.side_effect = msg_ids
self.state.get_task_res.return_value = task_res_list

# Execute
ret_msgs = list(self.driver.send_and_receive(msgs))
Expand All @@ -193,8 +209,8 @@ def test_send_and_receive_messages_timeout(self) -> None:
task=Task(ancestry=[msg_ids[1]], error=error_to_proto(Error(code=0)))
),
]
self.driver.state.store_task_ins.side_effect = msg_ids # type: ignore
self.driver.state.get_task_res.return_value = task_res_list # type: ignore
self.state.store_task_ins.side_effect = msg_ids
self.state.get_task_res.return_value = task_res_list

# Execute
with patch("time.sleep", side_effect=lambda t: time.sleep(t * 0.01)):
Expand Down
16 changes: 14 additions & 2 deletions src/py/flwr/server/superlink/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
PushTaskInsResponse,
)
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
GetRunRequest,
GetRunResponse,
Run,
)
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
from flwr.server.superlink.state import State, StateFactory
from flwr.server.utils.validator import validate_task_ins_or_res
Expand Down Expand Up @@ -134,7 +138,15 @@ def GetRun(
self, request: GetRunRequest, context: grpc.ServicerContext
) -> GetRunResponse:
"""Get run information."""
raise NotImplementedError
log(DEBUG, "DriverServicer.GetRun")

# Init state
state: State = self.state_factory.state()

# Retrieve run information
run = state.get_run(request.run_id)
run_proto = None if run is None else Run(**vars(run))
return GetRunResponse(run=run_proto)


def _raise_if(validation_error: bool, detail: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _init_run_id(driver: InMemoryDriver, state: StateFactory, run_id: int) -> No
"""Create a run with a given `run_id`."""
log(DEBUG, "Pre-registering run with id %s", run_id)
state.state().run_ids[run_id] = ("", "") # type: ignore
driver.run_id = run_id
driver._run_id = run_id # pylint: disable=protected-access


# pylint: disable=too-many-locals
Expand Down