Skip to content

Commit

Permalink
feat(framework) Add override config to Run (#3730)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Jul 8, 2024
1 parent b735710 commit 7c2c6c3
Show file tree
Hide file tree
Showing 14 changed files with 87 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def _on_backoff(retry_state: RetryState) -> None:
run_info[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
run_info[run_id] = Run(run_id, "", "")
run_info[run_id] = Run(run_id, "", "", {})

# Register context for this run
node_state.register_context(run_id=run_id)
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def get_run(run_id: int) -> Run:
run_id,
get_run_response.run.fab_id,
get_run_response.run.fab_version,
dict(get_run_response.run.override_config.items()),
)

try:
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,12 +352,13 @@ def get_run(run_id: int) -> Run:
# Send the request
res = _request(req, GetRunResponse, PATH_GET_RUN)
if res is None:
return Run(run_id, "", "")
return Run(run_id, "", "", {})

return Run(
run_id,
res.run.fab_id,
res.run.fab_version,
dict(res.run.override_config.items()),
)

try:
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,4 @@ class Run:
run_id: int
fab_id: str
fab_version: str
override_config: Dict[str, str]
1 change: 1 addition & 0 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def _init_run(self) -> None:
run_id=res.run.run_id,
fab_id=res.run.fab_id,
fab_version=res.run.fab_version,
override_config=dict(res.run.override_config.items()),
)

@property
Expand Down
10 changes: 7 additions & 3 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def setUp(self) -> None:
for _ in range(self.num_nodes)
]
self.state.get_run.return_value = Run(
run_id=61016, fab_id="mock/mock", fab_version="v1.0.0"
run_id=61016,
fab_id="mock/mock",
fab_version="v1.0.0",
override_config={"test_key": "test_value"},
)
state_factory = MagicMock(state=lambda: self.state)
self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory)
Expand All @@ -98,6 +101,7 @@ def test_get_run(self) -> None:
self.assertEqual(self.driver.run.run_id, 61016)
self.assertEqual(self.driver.run.fab_id, "mock/mock")
self.assertEqual(self.driver.run.fab_version, "v1.0.0")
self.assertEqual(self.driver.run.override_config["test_key"], "test_value")

def test_get_nodes(self) -> None:
"""Test retrieval of nodes."""
Expand Down Expand Up @@ -223,7 +227,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
# Prepare
state = StateFactory("").state()
self.driver = InMemoryDriver(
state.create_run("", ""), MagicMock(state=lambda: state)
state.create_run("", "", {}), MagicMock(state=lambda: state)
)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, SqliteState)
Expand All @@ -249,7 +253,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
# Prepare
state_factory = StateFactory(":flwr-in-memory-state:")
state = state_factory.state()
self.driver = InMemoryDriver(state.create_run("", ""), state_factory)
self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, InMemoryState)

Expand Down
6 changes: 5 additions & 1 deletion src/py/flwr/server/superlink/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ def CreateRun(
"""Create run ID."""
log(DEBUG, "DriverServicer.CreateRun")
state: State = self.state_factory.state()
run_id = state.create_run(request.fab_id, request.fab_version)
run_id = state.create_run(
request.fab_id,
request.fab_version,
dict(request.override_config.items()),
)
return CreateRunResponse(run_id=run_id)

def PushTaskIns(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._client_public_key)
)
run_id = self.state.create_run("", "")
run_id = self.state.create_run("", "", {})
request = GetRunRequest(run_id=run_id)
shared_secret = generate_shared_key(
self._client_private_key, self._server_public_key
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._client_public_key)
)
run_id = self.state.create_run("", "")
run_id = self.state.create_run("", "", {})
request = GetRunRequest(run_id=run_id)
client_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(client_private_key, self._server_public_key)
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def register_messages_into_state(
) -> Dict[UUID, float]:
"""Register `num_messages` into the state factory."""
state: InMemoryState = state_factory.state() # type: ignore
state.run_ids[run_id] = Run(run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0")
state.run_ids[run_id] = Run(
run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={}
)
# Artificially add TaskIns to state so they can be processed
# by the Simulation Engine logic
nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes
Expand Down
12 changes: 10 additions & 2 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,23 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
"""Retrieve stored `node_id` filtered by `client_public_keys`."""
return self.public_key_to_node_id.get(client_public_key)

def create_run(self, fab_id: str, fab_version: str) -> int:
def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, str],
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
# Sample a random int64 as run_id
with self.lock:
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

if run_id not in self.run_ids:
self.run_ids[run_id] = Run(
run_id=run_id, fab_id=fab_id, fab_version=fab_version
run_id=run_id,
fab_id=fab_id,
fab_version=fab_version,
override_config=override_config,
)
return run_id
log(ERROR, "Unexpected run creation failure.")
Expand Down
29 changes: 22 additions & 7 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""SQLite based implemenation of server state."""


import json
import re
import sqlite3
import time
Expand Down Expand Up @@ -61,9 +62,10 @@

SQL_CREATE_TABLE_RUN = """
CREATE TABLE IF NOT EXISTS run(
run_id INTEGER UNIQUE,
fab_id TEXT,
fab_version TEXT
run_id INTEGER UNIQUE,
fab_id TEXT,
fab_version TEXT,
override_config TEXT
);
"""

Expand Down Expand Up @@ -613,7 +615,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
return node_id
return None

def create_run(self, fab_id: str, fab_version: str) -> int:
def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, str],
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
# Sample a random int64 as run_id
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
Expand All @@ -622,8 +629,13 @@ def create_run(self, fab_id: str, fab_version: str) -> int:
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
# If run_id does not exist
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
self.query(query, (run_id, fab_id, fab_version))
query = (
"INSERT INTO run (run_id, fab_id, fab_version, override_config)"
"VALUES (?, ?, ?, ?);"
)
self.query(
query, (run_id, fab_id, fab_version, json.dumps(override_config))
)
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down Expand Up @@ -687,7 +699,10 @@ def get_run(self, run_id: int) -> Optional[Run]:
try:
row = self.query(query, (run_id,))[0]
return Run(
run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"]
run_id=run_id,
fab_id=row["fab_id"],
fab_version=row["fab_version"],
override_config=json.loads(row["override_config"]),
)
except sqlite3.IntegrityError:
log(ERROR, "`run_id` does not exist.")
Expand Down
9 changes: 7 additions & 2 deletions src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


import abc
from typing import List, Optional, Set
from typing import Dict, List, Optional, Set
from uuid import UUID

from flwr.common.typing import Run
Expand Down Expand Up @@ -157,7 +157,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
"""Retrieve stored `node_id` filtered by `client_public_keys`."""

@abc.abstractmethod
def create_run(self, fab_id: str, fab_version: str) -> int:
def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, str],
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""

@abc.abstractmethod
Expand Down
Loading

0 comments on commit 7c2c6c3

Please sign in to comment.