Skip to content

Commit

Permalink
renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Nov 7, 2024
1 parent e770218 commit 1fb592e
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/py/flwr/server/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Driver(ABC):
"""Abstract base Driver class for the ServerAppIo API."""

@abstractmethod
def init_run(self, run_id: int) -> None:
def set_run(self, run_id: int) -> None:
"""Request a run to the SuperLink with a given `run_id`.
If a Run with the specified `run_id` exists, a local Run
Expand Down
8 changes: 2 additions & 6 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,8 @@ def _disconnect(self) -> None:
channel.close()
log(DEBUG, "[Driver] Disconnected")

def init_run(self, run_id: int) -> None:
"""Initialize the run."""
# Check if is initialized
if self._run is not None:
return

def set_run(self, run_id: int) -> None:
"""Set the run."""
# Get the run info
req = GetRunRequest(run_id=run_id)
res: GetRunResponse = self._stub.GetRun(req)
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/driver/grpc_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _mock_fn(req: GetRunRequest) -> GetRunResponse:
self.driver = GrpcDriver()
self.driver._grpc_stub = self.mock_stub # pylint: disable=protected-access
self.driver._channel = self.mock_channel # pylint: disable=protected-access
self.driver.init_run(run_id=61016)
self.driver.set_run(run_id=61016)

def test_init_grpc_driver(self) -> None:
"""Test GrpcDriverStub initialization."""
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/driver/inmemory_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _check_message(self, message: Message) -> None:
):
raise ValueError(f"Invalid message: {message}")

def init_run(self, run_id: int) -> None:
def set_run(self, run_id: int) -> None:
"""Initialize the run."""
if self._run is not None:
return
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def setUp(self) -> None:
)
state_factory = MagicMock(state=lambda: self.state)
self.driver = InMemoryDriver(state_factory=state_factory)
self.driver.init_run(run_id=61016)
self.driver.set_run(run_id=61016)
self.driver.state = self.state

def test_get_run(self) -> None:
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
state = LinkStateFactory("").state()
run_id = state.create_run("", "", "", {})
self.driver = InMemoryDriver(MagicMock(state=lambda: state))
self.driver.init_run(run_id=run_id)
self.driver.set_run(run_id=run_id)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, SqliteLinkState)

Expand All @@ -261,7 +261,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
state = state_factory.state()
run_id = state.create_run("", "", "", {})
self.driver = InMemoryDriver(state_factory)
self.driver.init_run(run_id=run_id)
self.driver.set_run(run_id=run_id)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, InMemoryLinkState)

Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/run_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def run_server_app() -> None:
root_certificates=root_certificates,
)
flwr_dir = get_flwr_dir(args.flwr_dir)
driver.init_run(args.run_id)
driver.set_run(args.run_id)
run_ = driver.run
if not run_.fab_hash:
raise ValueError("FAB hash not provided.")
Expand Down Expand Up @@ -204,7 +204,7 @@ def run_server_app() -> None:
req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
# Fetch full `Run` using `run_id`
driver.init_run(res.run_id) # pylint: disable=W0212
driver.set_run(res.run_id) # pylint: disable=W0212
run_id = res.run_id

# Obtain server app reference and the run config
Expand Down
3 changes: 1 addition & 2 deletions src/py/flwr/server/serverapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212
run = run_from_proto(res.run)
fab = fab_from_proto(res.fab)

driver._run = None
driver.init_run(run.run_id)
driver.set_run(run.run_id)

# Start log uploader for this run
log_uploader = start_log_uploader(
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _main_loop(

# Initialize Driver
driver = InMemoryDriver(state_factory=state_factory)
driver.init_run(run_id=run.run_id)
driver.set_run(run_id=run.run_id)

# Get and run ServerApp thread
serverapp_th = run_serverapp_th(
Expand Down

0 comments on commit 1fb592e

Please sign in to comment.