diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index f2a43f0e2c8e..cb5a617164ff 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -16,11 +16,9 @@ import time import warnings -from logging import DEBUG, ERROR, WARNING +from logging import DEBUG, ERROR from typing import Iterable, List, Optional, Tuple -import grpc - from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event from flwr.common.grpc import create_channel from flwr.common.logger import log @@ -48,103 +46,94 @@ ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """ [Driver] Error: Not connected. -Call `connect()` on the `GrpcDriverHelper` instance before calling any of the other -`GrpcDriverHelper` methods. +Call `connect()` on the `GrpcDriverStub` instance before calling any of the other +`GrpcDriverStub` methods. """ -class GrpcDriverHelper: - """`GrpcDriverHelper` provides access to the gRPC Driver API/service.""" +class GrpcDriverStub(DriverStub): + """`GrpcDriverStub` provides access to the gRPC Driver API/service.""" def __init__( self, driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, root_certificates: Optional[bytes] = None, ) -> None: + event(EventType.DRIVER_CONNECT) self.driver_service_address = driver_service_address self.root_certificates = root_certificates - self.channel: Optional[grpc.Channel] = None - self.stub: Optional[DriverStub] = None - - def connect(self) -> None: - """Connect to the Driver API.""" - event(EventType.DRIVER_CONNECT) - if self.channel is not None or self.stub is not None: - log(WARNING, "Already connected") - return self.channel = create_channel( server_address=self.driver_service_address, insecure=(self.root_certificates is None), root_certificates=self.root_certificates, ) - self.stub = DriverStub(self.channel) + super().__init__(self.channel) log(DEBUG, "[Driver] Connected to %s", self.driver_service_address) def disconnect(self) -> None: """Disconnect from the Driver API.""" event(EventType.DRIVER_DISCONNECT) - if self.channel is None or self.stub is None: + if self.channel is None: log(DEBUG, "Already disconnected") return channel = self.channel self.channel = None - self.stub = None channel.close() log(DEBUG, "[Driver] Disconnected") def create_run(self, req: CreateRunRequest) -> CreateRunResponse: """Request for run ID.""" # Check if channel is open - if self.stub is None: + if self.channel is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call Driver API - res: CreateRunResponse = self.stub.CreateRun(request=req) + res: CreateRunResponse = self.CreateRun(request=req) return res def get_run(self, req: GetRunRequest) -> GetRunResponse: """Get run information.""" # Check if channel is open - if self.stub is None: + if self.channel is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API - res: GetRunResponse = self.stub.GetRun(request=req) + res: GetRunResponse = self.GetRun(request=req) return res def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: """Get client IDs.""" # Check if channel is open - if self.stub is None: + if self.channel is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API - res: GetNodesResponse = self.stub.GetNodes(request=req) + res: GetNodesResponse = self.GetNodes(request=req) return res def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: """Schedule tasks.""" # Check if channel is open - if self.stub is None: + if self.channel is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call gRPC Driver API - res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) + res: PushTaskInsResponse = self.PushTaskIns(request=req) return res def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: """Get task results.""" # Check if channel is open - if self.stub is None: + if self.channel is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise ConnectionError("`GrpcDriverHelper` instance not connected") + raise ConnectionError("`GrpcDriverStub` instance not connected") # Call Driver API - res: PullTaskResResponse = self.stub.PullTaskRes(request=req) + res: PullTaskResResponse = self.PullTaskRes(request=req) return res @@ -172,18 +161,14 @@ class GrpcDriver(Driver): def __init__( # pylint: disable=too-many-arguments self, - driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER, - root_certificates: Optional[bytes] = None, - fab_id: Optional[str] = None, - fab_version: Optional[str] = None, - run_id: Optional[int] = None, + run_id: int, + stub: Optional[GrpcDriverStub] = None, ) -> None: - self.addr = driver_service_address - self.root_certificates = root_certificates - self.driver_helper: Optional[GrpcDriverHelper] = None + self.stub = stub self._run_id = run_id - self._fab_id = fab_id if fab_id is not None else "" - self._fab_ver = fab_version if fab_version is not None else "" + self._fab_id = "" + self._fab_ver = "" + self._has_initialized = False self.node = Node(node_id=0, anonymous=True) @property @@ -196,32 +181,23 @@ def run(self) -> Run: fab_version=self._fab_ver, ) - def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverHelper, int]: - # Check if the GrpcDriverHelper is initialized - if self.driver_helper is None or self._run_id is None: + def _get_grpc_driver_helper_and_run_id(self) -> Tuple[GrpcDriverStub, int]: + # Check if the GrpcDriverStub is initialized + if not self._has_initialized or self.stub is None: # Connect and create run - self.driver_helper = GrpcDriverHelper( - driver_service_address=self.addr, - root_certificates=self.root_certificates, - ) - self.driver_helper.connect() - # Create the run if the run_id is not provided - if self._run_id is None: - create_run_req = CreateRunRequest( - fab_id=self._fab_id, fab_version=self._fab_ver - ) - create_run_res = self.driver_helper.create_run(create_run_req) - self._run_id = create_run_res.run_id - # Get the run if the run_id is provided - else: - get_run_req = GetRunRequest(run_id=self._run_id) - get_run_res = self.driver_helper.get_run(get_run_req) - if not get_run_res.HasField("run"): - raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") - self._fab_id = get_run_res.run.fab_id - self._fab_ver = get_run_res.run.fab_version - - return self.driver_helper, self._run_id + if self.stub is None: + self.stub = GrpcDriverStub() + + # Get the run info + req = GetRunRequest(run_id=self._run_id) + res = self.stub.get_run(req) + if not res.HasField("run"): + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") + self._fab_id = res.run.fab_id + self._fab_ver = res.run.fab_version + self._has_initialized = True + + return self.stub, self._run_id def _check_message(self, message: Message) -> None: # Check if the message is valid @@ -272,7 +248,7 @@ def create_message( # pylint: disable=too-many-arguments def get_node_ids(self) -> List[int]: """Get node IDs.""" grpc_driver_helper, run_id = self._get_grpc_driver_helper_and_run_id() - # Call GrpcDriverHelper method + # Call GrpcDriverStub method res = grpc_driver_helper.get_nodes(GetNodesRequest(run_id=run_id)) return [node.node_id for node in res.nodes] @@ -292,7 +268,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: taskins = message_to_taskins(msg) # Add to list task_ins_list.append(taskins) - # Call GrpcDriverHelper method + # Call GrpcDriverStub method res = grpc_driver_helper.push_task_ins( PushTaskInsRequest(task_ins_list=task_ins_list) ) @@ -345,8 +321,8 @@ def send_and_receive( def close(self) -> None: """Disconnect from the SuperLink if connected.""" - # Check if GrpcDriverHelper is initialized - if self.driver_helper is None: + # Check if GrpcDriverStub is initialized + if self.stub is None: return # Disconnect - self.driver_helper.disconnect() + self.stub.disconnect() diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 642fdbe9d8ab..775348437da6 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -36,74 +36,36 @@ class TestGrpcDriver(unittest.TestCase): """Tests for `GrpcDriver` class.""" def setUp(self) -> None: - """Initialize mock GrpcDriverHelper and Driver instance before each test.""" - mock_response = Mock() - mock_response.run_id = 61016 - self.mock_grpc_driver_helper = Mock() - self.mock_grpc_driver_helper.create_run.return_value = mock_response - self.patcher = patch( - "flwr.server.driver.grpc_driver.GrpcDriverHelper", - return_value=self.mock_grpc_driver_helper, + """Initialize mock GrpcDriverStub and Driver instance before each test.""" + mock_response = Mock( + run=Mock(run_id=61016, fab_id="mock/mock", fab_version="v1.0.0") ) - self.patcher.start() - self.driver = GrpcDriver() - - def tearDown(self) -> None: - """Cleanup after each test.""" - self.patcher.stop() - - def test_get_run(self) -> None: - """Test the GrpcDriver starting with run_id.""" - # Prepare - self.driver._run_id = 61016 # pylint: disable=protected-access - mock_response = Mock() - mock_response.run = Mock() - mock_response.run.run_id = 61016 - mock_response.run.fab_id = "mock/mock" - mock_response.run.fab_version = "v1.0.0" - self.mock_grpc_driver_helper.get_run.return_value = mock_response + self.mock_grpc_driver_stub = Mock() + self.mock_grpc_driver_stub.get_run.return_value = mock_response + self.mock_grpc_driver_stub.HasField.return_value = True + self.driver = GrpcDriver(run_id=61016, stub=self.mock_grpc_driver_stub) + def test_init_grpc_driver(self) -> None: + """Test GrpcDriverStub initialization.""" # Assert self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") self.assertEqual(self.driver.run.fab_version, "v1.0.0") - - def test_check_and_init_grpc_driver_already_initialized(self) -> None: - """Test that GrpcDriverHelper doesn't initialize if run is created.""" - # Prepare - self.driver.driver_helper = self.mock_grpc_driver_helper - self.driver._run_id = 61016 # pylint: disable=protected-access - - # Execute - # pylint: disable-next=protected-access - self.driver._get_grpc_driver_helper_and_run_id() - - # Assert - self.mock_grpc_driver_helper.connect.assert_not_called() - - def test_check_and_init_grpc_driver_needs_initialization(self) -> None: - """Test GrpcDriverHelper initialization when run is not created.""" - # Execute - # pylint: disable-next=protected-access - self.driver._get_grpc_driver_helper_and_run_id() - - # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() - self.assertEqual(self.driver.run.run_id, 61016) + self.mock_grpc_driver_stub.get_run.assert_called_once() def test_get_nodes(self) -> None: """Test retrieval of nodes.""" # Prepare mock_response = Mock() mock_response.nodes = [Mock(node_id=404), Mock(node_id=200)] - self.mock_grpc_driver_helper.get_nodes.return_value = mock_response + self.mock_grpc_driver_stub.get_nodes.return_value = mock_response # Execute node_ids = self.driver.get_node_ids() - args, kwargs = self.mock_grpc_driver_helper.get_nodes.call_args + args, kwargs = self.mock_grpc_driver_stub.get_nodes.call_args # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() + self.mock_grpc_driver_stub.get_run.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], GetNodesRequest) @@ -114,7 +76,7 @@ def test_push_messages_valid(self) -> None: """Test pushing valid messages.""" # Prepare mock_response = Mock(task_ids=["id1", "id2"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response msgs = [ self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) for _ in range(2) @@ -122,10 +84,10 @@ def test_push_messages_valid(self) -> None: # Execute msg_ids = self.driver.push_messages(msgs) - args, kwargs = self.mock_grpc_driver_helper.push_task_ins.call_args + args, kwargs = self.mock_grpc_driver_stub.push_task_ins.call_args # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() + self.mock_grpc_driver_stub.get_run.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PushTaskInsRequest) @@ -137,7 +99,7 @@ def test_push_messages_invalid(self) -> None: """Test pushing invalid messages.""" # Prepare mock_response = Mock(task_ids=["id1", "id2"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response msgs = [ self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL) for _ in range(2) @@ -161,16 +123,16 @@ def test_pull_messages_with_given_message_ids(self) -> None: ), TaskRes(task=Task(ancestry=["id3"], error=error_to_proto(Error(code=0)))), ] - self.mock_grpc_driver_helper.pull_task_res.return_value = mock_response + self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response msg_ids = ["id1", "id2", "id3"] # Execute msgs = self.driver.pull_messages(msg_ids) reply_tos = {msg.metadata.reply_to_message for msg in msgs} - args, kwargs = self.mock_grpc_driver_helper.pull_task_res.call_args + args, kwargs = self.mock_grpc_driver_stub.pull_task_res.call_args # Assert - self.mock_grpc_driver_helper.connect.assert_called_once() + self.mock_grpc_driver_stub.get_run.assert_called_once() self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PullTaskResRequest) @@ -181,14 +143,14 @@ def test_send_and_receive_messages_complete(self) -> None: """Test send and receive all messages successfully.""" # Prepare mock_response = Mock(task_ids=["id1"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response # The response message must include either `content` (i.e. a recordset) or # an `Error`. We choose the latter in this case error_proto = error_to_proto(Error(code=0)) mock_response = Mock( task_res_list=[TaskRes(task=Task(ancestry=["id1"], error=error_proto))] ) - self.mock_grpc_driver_helper.pull_task_res.return_value = mock_response + self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute @@ -203,9 +165,9 @@ def test_send_and_receive_messages_timeout(self) -> None: # Prepare sleep_fn = time.sleep mock_response = Mock(task_ids=["id1"]) - self.mock_grpc_driver_helper.push_task_ins.return_value = mock_response + self.mock_grpc_driver_stub.push_task_ins.return_value = mock_response mock_response = Mock(task_res_list=[]) - self.mock_grpc_driver_helper.pull_task_res.return_value = mock_response + self.mock_grpc_driver_stub.pull_task_res.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute @@ -227,12 +189,15 @@ def test_del_with_initialized_driver(self) -> None: self.driver.close() # Assert - self.mock_grpc_driver_helper.disconnect.assert_called_once() + self.mock_grpc_driver_stub.disconnect.assert_called_once() def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" + # Prepare + self.driver.stub = None + # Execute self.driver.close() # Assert - self.mock_grpc_driver_helper.disconnect.assert_not_called() + self.mock_grpc_driver_stub.disconnect.assert_not_called() diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index eda4051b5a70..96dc8a2a5716 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -17,7 +17,7 @@ import time import warnings -from typing import Iterable, List, Optional, cast +from typing import Iterable, List, Optional from uuid import UUID from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet @@ -44,14 +44,13 @@ class InMemoryDriver(Driver): def __init__( self, + run_id: int, state_factory: StateFactory, - fab_id: Optional[str] = None, - fab_version: Optional[str] = None, - run_id: Optional[int] = None, ) -> None: self._run_id = run_id - self._fab_id = fab_id - self._fab_ver = fab_version + self._fab_id = "" + self._fab_ver = "" + self._has_initialized = False self.node = Node(node_id=0, anonymous=True) self.state = state_factory.state() @@ -68,29 +67,23 @@ def _check_message(self, message: Message) -> None: def _init_run(self) -> None: """Initialize the run.""" - # Run ID is not provided - if self._run_id is None: - self._fab_id = "" if self._fab_id is None else self._fab_id - self._fab_ver = "" if self._fab_ver is None else self._fab_ver - self._run_id = self.state.create_run( - fab_id=self._fab_id, fab_version=self._fab_ver - ) - # Run ID is provided - elif self._fab_id is None or self._fab_ver is None: - run = self.state.get_run(self._run_id) - if run is None: - raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") - self._fab_id = run.fab_id - self._fab_ver = run.fab_version + if self._has_initialized: + return + run = self.state.get_run(self._run_id) + if run is None: + raise RuntimeError(f"Cannot find the run with ID: {self._run_id}") + self._fab_id = run.fab_id + self._fab_ver = run.fab_version + self._has_initialized = True @property def run(self) -> Run: """Run ID.""" self._init_run() return Run( - run_id=cast(int, self._run_id), - fab_id=cast(str, self._fab_id), - fab_version=cast(str, self._fab_ver), + run_id=self._run_id, + fab_id=self._fab_id, + fab_version=self._fab_ver, ) def create_message( # pylint: disable=too-many-arguments diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 8bfa14def0c2..1f457decf228 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -32,7 +32,7 @@ recordset_to_proto, ) from flwr.proto.task_pb2 import Task, TaskRes # pylint: disable=E0611 -from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.state import InMemoryState, SqliteState, StateFactory from .inmemory_driver import InMemoryDriver @@ -84,19 +84,15 @@ def setUp(self) -> None: int.from_bytes(os.urandom(8), "little", signed=True) for _ in range(self.num_nodes) ] - state_factory = MagicMock() - state_factory.state.return_value = self.state - self.driver = InMemoryDriver(state_factory) - self.driver.state = self.state - - def test_get_run(self) -> None: - """Test the InMemoryDriver starting with run_id.""" - # Prepare - self.driver._run_id = 61016 # pylint: disable=protected-access self.state.get_run.return_value = MagicMock( run_id=61016, fab_id="mock/mock", fab_version="v1.0.0" ) + state_factory = MagicMock(state=lambda: self.state) + self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory) + self.driver.state = self.state + def test_get_run(self) -> None: + """Test the InMemoryDriver starting with run_id.""" # Assert self.assertEqual(self.driver.run.run_id, 61016) self.assertEqual(self.driver.run.fab_id, "mock/mock") @@ -224,19 +220,23 @@ def test_send_and_receive_messages_timeout(self) -> None: def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: """Test tasks are deleted in sqlite state once messages are pulled.""" # Prepare - self.driver = InMemoryDriver(StateFactory("")) + state = StateFactory("").state() + self.driver = InMemoryDriver( + state.create_run("", ""), MagicMock(state=lambda: state) + ) msg_ids, node_id = push_messages(self.driver, self.num_nodes) + assert isinstance(state, SqliteState) # Check recorded - task_ins = self.driver.state.query("SELECT * FROM task_ins;") # type: ignore + task_ins = state.query("SELECT * FROM task_ins;") self.assertEqual(len(task_ins), len(list(msg_ids))) # Prepare: create replies reply_tos = get_replies(self.driver, msg_ids, node_id) # Query number of task_ins and task_res in State - task_res = self.driver.state.query("SELECT * FROM task_res;") # type: ignore - task_ins = self.driver.state.query("SELECT * FROM task_ins;") # type: ignore + task_res = state.query("SELECT * FROM task_res;") + task_ins = state.query("SELECT * FROM task_ins;") # Assert self.assertEqual(reply_tos, msg_ids) @@ -246,18 +246,19 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: """Test tasks are deleted in in-memory state once messages are pulled.""" # Prepare - self.driver = InMemoryDriver(StateFactory(":flwr-in-memory-state:")) + state_factory = StateFactory(":flwr-in-memory-state:") + state = state_factory.state() + self.driver = InMemoryDriver(state.create_run("", ""), state_factory) msg_ids, node_id = push_messages(self.driver, self.num_nodes) + assert isinstance(state, InMemoryState) # Check recorded - self.assertEqual( - len(self.driver.state.task_ins_store), len(list(msg_ids)) # type: ignore - ) + self.assertEqual(len(state.task_ins_store), len(list(msg_ids))) # Prepare: create replies reply_tos = get_replies(self.driver, msg_ids, node_id) # Assert self.assertEqual(reply_tos, msg_ids) - self.assertEqual(len(self.driver.state.task_res_store), 0) # type: ignore - self.assertEqual(len(self.driver.state.task_ins_store), 0) # type: ignore + self.assertEqual(len(state.task_res_store), 0) + self.assertEqual(len(state.task_ins_store), 0) diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index fd0214a040bc..25a523229f52 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -24,8 +24,10 @@ from flwr.common import Context, EventType, RecordSet, event from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app +from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611 -from .driver import Driver, GrpcDriver +from .driver import Driver +from .driver.grpc_driver import GrpcDriver, GrpcDriverStub from .server_app import LoadServerAppError, ServerApp ADDRESS_DRIVER_API = "0.0.0.0:9091" @@ -147,13 +149,15 @@ def run_server_app() -> None: server_app_dir = args.dir server_app_attr = getattr(args, "server-app") - # Initialize GrpcDriver - driver = GrpcDriver( - driver_service_address=args.superlink, - root_certificates=root_certificates, - fab_id=args.fab_id, - fab_version=args.fab_version, + # Create run + stub = GrpcDriverStub( + driver_service_address=args.superlink, root_certificates=root_certificates ) + req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version) + res = stub.create_run(req) + + # Initialize GrpcDriver + driver = GrpcDriver(run_id=res.run_id, stub=stub) # Run the ServerApp with the Driver run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr) diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 3e5eb266c89a..6785f3ac38b6 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -201,8 +201,11 @@ def _main_loop( f_stop = asyncio.Event() serverapp_th = None try: + # Create run (with empty fab_id and fab_version) + run_id = state_factory.state().create_run("", "") + # Initialize Driver - driver = InMemoryDriver(state_factory) + driver = InMemoryDriver(run_id=run_id, state_factory=state_factory) if run_id: _init_run_id(driver, state_factory, run_id)