From e770218d61f6ef47c48a058f13450fd66dfc48d2 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Wed, 6 Nov 2024 20:28:17 +0000 Subject: [PATCH 1/3] reset driver._run --- src/py/flwr/server/serverapp/app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/py/flwr/server/serverapp/app.py b/src/py/flwr/server/serverapp/app.py index 6ae63734d0df..1a50c97e2b09 100644 --- a/src/py/flwr/server/serverapp/app.py +++ b/src/py/flwr/server/serverapp/app.py @@ -189,6 +189,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212 run = run_from_proto(res.run) fab = fab_from_proto(res.fab) + driver._run = None driver.init_run(run.run_id) # Start log uploader for this run From 1fb592e798140a07ae01caaa1bbb35cdecd648fb Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 7 Nov 2024 13:22:56 +0000 Subject: [PATCH 2/3] renaming --- src/py/flwr/server/driver/driver.py | 2 +- src/py/flwr/server/driver/grpc_driver.py | 8 ++------ src/py/flwr/server/driver/grpc_driver_test.py | 2 +- src/py/flwr/server/driver/inmemory_driver.py | 2 +- src/py/flwr/server/driver/inmemory_driver_test.py | 6 +++--- src/py/flwr/server/run_serverapp.py | 4 ++-- src/py/flwr/server/serverapp/app.py | 3 +-- src/py/flwr/simulation/run_simulation.py | 2 +- 8 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index c1b0caf2f378..e7176e4515ec 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -27,7 +27,7 @@ class Driver(ABC): """Abstract base Driver class for the ServerAppIo API.""" @abstractmethod - def init_run(self, run_id: int) -> None: + def set_run(self, run_id: int) -> None: """Request a run to the SuperLink with a given `run_id`. If a Run with the specified `run_id` exists, a local Run diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 25d25f633d28..9dbc72c189f8 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -112,12 +112,8 @@ def _disconnect(self) -> None: channel.close() log(DEBUG, "[Driver] Disconnected") - def init_run(self, run_id: int) -> None: - """Initialize the run.""" - # Check if is initialized - if self._run is not None: - return - + def set_run(self, run_id: int) -> None: + """Set the run.""" # Get the run info req = GetRunRequest(run_id=run_id) res: GetRunResponse = self._stub.GetRun(req) diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 222e46b0f7f1..ad61deadb331 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -59,7 +59,7 @@ def _mock_fn(req: GetRunRequest) -> GetRunResponse: self.driver = GrpcDriver() self.driver._grpc_stub = self.mock_stub # pylint: disable=protected-access self.driver._channel = self.mock_channel # pylint: disable=protected-access - self.driver.init_run(run_id=61016) + self.driver.set_run(run_id=61016) def test_init_grpc_driver(self) -> None: """Test GrpcDriverStub initialization.""" diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 1fe69d79b5da..5c04af644e8c 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -62,7 +62,7 @@ def _check_message(self, message: Message) -> None: ): raise ValueError(f"Invalid message: {message}") - def init_run(self, run_id: int) -> None: + def set_run(self, run_id: int) -> None: """Initialize the run.""" if self._run is not None: return diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index c10c57648900..4c4f50db91e3 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -97,7 +97,7 @@ def setUp(self) -> None: ) state_factory = MagicMock(state=lambda: self.state) self.driver = InMemoryDriver(state_factory=state_factory) - self.driver.init_run(run_id=61016) + self.driver.set_run(run_id=61016) self.driver.state = self.state def test_get_run(self) -> None: @@ -234,7 +234,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: state = LinkStateFactory("").state() run_id = state.create_run("", "", "", {}) self.driver = InMemoryDriver(MagicMock(state=lambda: state)) - self.driver.init_run(run_id=run_id) + self.driver.set_run(run_id=run_id) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, SqliteLinkState) @@ -261,7 +261,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: state = state_factory.state() run_id = state.create_run("", "", "", {}) self.driver = InMemoryDriver(state_factory) - self.driver.init_run(run_id=run_id) + self.driver.set_run(run_id=run_id) msg_ids, node_id = push_messages(self.driver, self.num_nodes) assert isinstance(state, InMemoryLinkState) diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index 2215b87295b6..293a4d649632 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -174,7 +174,7 @@ def run_server_app() -> None: root_certificates=root_certificates, ) flwr_dir = get_flwr_dir(args.flwr_dir) - driver.init_run(args.run_id) + driver.set_run(args.run_id) run_ = driver.run if not run_.fab_hash: raise ValueError("FAB hash not provided.") @@ -204,7 +204,7 @@ def run_server_app() -> None: req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version) res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212 # Fetch full `Run` using `run_id` - driver.init_run(res.run_id) # pylint: disable=W0212 + driver.set_run(res.run_id) # pylint: disable=W0212 run_id = res.run_id # Obtain server app reference and the run config diff --git a/src/py/flwr/server/serverapp/app.py b/src/py/flwr/server/serverapp/app.py index 1a50c97e2b09..8e57a6046b97 100644 --- a/src/py/flwr/server/serverapp/app.py +++ b/src/py/flwr/server/serverapp/app.py @@ -189,8 +189,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212 run = run_from_proto(res.run) fab = fab_from_proto(res.fab) - driver._run = None - driver.init_run(run.run_id) + driver.set_run(run.run_id) # Start log uploader for this run log_uploader = start_log_uploader( diff --git a/src/py/flwr/simulation/run_simulation.py b/src/py/flwr/simulation/run_simulation.py index 88d3fc8b213c..a1312b2e2734 100644 --- a/src/py/flwr/simulation/run_simulation.py +++ b/src/py/flwr/simulation/run_simulation.py @@ -347,7 +347,7 @@ def _main_loop( # Initialize Driver driver = InMemoryDriver(state_factory=state_factory) - driver.init_run(run_id=run.run_id) + driver.set_run(run_id=run.run_id) # Get and run ServerApp thread serverapp_th = run_serverapp_th( From 8766a1b483fddaa5f0cb9546968d5ec76758cecc Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 7 Nov 2024 20:41:44 +0000 Subject: [PATCH 3/3] amend in mem driver --- src/py/flwr/server/driver/inmemory_driver.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 5c04af644e8c..d9189002e5ad 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -64,8 +64,6 @@ def _check_message(self, message: Message) -> None: def set_run(self, run_id: int) -> None: """Initialize the run.""" - if self._run is not None: - return run = self.state.get_run(run_id) if run is None: raise RuntimeError(f"Cannot find the run with ID: {run_id}")