Skip to content

Commit

Permalink
Merge branch 'main' into add-clientappio-servicer
Browse files: browse the repository at this point in the history
  • Loading branch information
chongshenng committed Aug 15, 2024
2 parents 0947b74 + 480f683 commit 32a5260
Show file tree
Hide file tree
Showing 18 changed files with 106 additions and 51 deletions.
2 changes: 1 addition & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def _on_backoff(retry_state: RetryState) -> None:
runs[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
runs[run_id] = Run(run_id, "", "", {})
runs[run_id] = Run(run_id, "", "", "", {})

# Register context for this run
node_state.register_context(
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def get_run(run_id: int) -> Run:
run_id,
get_run_response.run.fab_id,
get_run_response.run.fab_version,
get_run_response.run.fab_hash,
user_config_from_proto(get_run_response.run.override_config),
)

Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,13 @@ def get_run(run_id: int) -> Run:
# Send the request
res = _request(req, GetRunResponse, PATH_GET_RUN)
if res is None:
return Run(run_id, "", "", {})
return Run(run_id, "", "", "", {})

return Run(
run_id,
res.run.fab_id,
res.run.fab_version,
res.run.fab_hash,
user_config_from_proto(res.run.override_config),
)

Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,8 +850,8 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
run_id=run.run_id,
fab_id=run.fab_id,
fab_version=run.fab_version,
fab_hash=run.fab_hash,
override_config=user_config_to_proto(run.override_config),
fab_hash="",
)
return proto

Expand All @@ -862,6 +862,7 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
run_id=run_proto.run_id,
fab_id=run_proto.fab_id,
fab_version=run_proto.fab_version,
fab_hash=run_proto.fab_hash,
override_config=user_config_from_proto(run_proto.override_config),
)
return run
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ def test_run_serialization_deserialization() -> None:
run_id=1,
fab_id="lorem",
fab_version="ipsum",
fab_hash="hash",
override_config=maker.user_config(),
)

Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ class Run:
run_id: int
fab_id: str
fab_version: str
fab_hash: str
override_config: UserConfig


Expand Down
12 changes: 12 additions & 0 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
from flwr.common.config import get_flwr_dir
from flwr.common.constant import (
MISSING_EXTRA_REST,
TRANSPORT_TYPE_GRPC_ADAPTER,
Expand All @@ -57,6 +58,7 @@
from .server_config import ServerConfig
from .strategy import Strategy
from .superlink.driver.driver_grpc import run_driver_api_grpc
from .superlink.ffs.ffs_factory import FfsFactory
from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
from .superlink.fleet.grpc_bidi.grpc_server import (
generic_create_grpc_server,
Expand All @@ -72,6 +74,7 @@
ADDRESS_FLEET_API_REST = "0.0.0.0:9093"

DATABASE = ":flwr-in-memory-state:"
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"


def start_server( # pylint: disable=too-many-arguments,too-many-locals
Expand Down Expand Up @@ -211,10 +214,14 @@ def run_superlink() -> None:
# Initialize StateFactory
state_factory = StateFactory(args.database)

# Initialize FfsFactory
ffs_factory = FfsFactory(args.storage_dir)

# Start Driver API
driver_server: grpc.Server = run_driver_api_grpc(
address=driver_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
)

Expand Down Expand Up @@ -610,6 +617,11 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
"Flower will just create a state in memory.",
default=DATABASE,
)
parser.add_argument(
"--storage-dir",
help="The base directory to store the objects for the Flower File System.",
default=BASE_DIR,
)
parser.add_argument(
"--auth-list-public-keys",
type=str,
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def _init_run(self) -> None:
run_id=res.run.run_id,
fab_id=res.run.fab_id,
fab_version=res.run.fab_version,
fab_hash=res.run.fab_hash,
override_config=user_config_from_proto(res.run.override_config),
)

Expand Down
6 changes: 4 additions & 2 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def setUp(self) -> None:
run_id=61016,
fab_id="mock/mock",
fab_version="v1.0.0",
fab_hash="hash",
override_config={"test_key": "test_value"},
)
state_factory = MagicMock(state=lambda: self.state)
Expand All @@ -101,6 +102,7 @@ def test_get_run(self) -> None:
self.assertEqual(self.driver.run.run_id, 61016)
self.assertEqual(self.driver.run.fab_id, "mock/mock")
self.assertEqual(self.driver.run.fab_version, "v1.0.0")
self.assertEqual(self.driver.run.fab_hash, "hash")
self.assertEqual(self.driver.run.override_config["test_key"], "test_value")

def test_get_nodes(self) -> None:
Expand Down Expand Up @@ -227,7 +229,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
# Prepare
state = StateFactory("").state()
self.driver = InMemoryDriver(
state.create_run("", "", {}), MagicMock(state=lambda: state)
state.create_run("", "", "", {}), MagicMock(state=lambda: state)
)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, SqliteState)
Expand All @@ -253,7 +255,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
# Prepare
state_factory = StateFactory(":flwr-in-memory-state:")
state = state_factory.state()
self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory)
self.driver = InMemoryDriver(state.create_run("", "", "", {}), state_factory)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, InMemoryState)

Expand Down
3 changes: 3 additions & 0 deletions src/py/flwr/server/superlink/driver/driver_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
add_DriverServicer_to_server,
)
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.state import StateFactory

from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
Expand All @@ -33,12 +34,14 @@
def run_driver_api_grpc(
address: str,
state_factory: StateFactory,
ffs_factory: FfsFactory,
certificates: Optional[Tuple[bytes, bytes, bytes]],
) -> grpc.Server:
"""Run Driver API (gRPC, request-response)."""
# Create Driver API gRPC server
driver_servicer: grpc.Server = DriverServicer(
state_factory=state_factory,
ffs_factory=ffs_factory,
)
driver_add_servicer_to_server_fn = add_DriverServicer_to_server
driver_grpc_server = generic_create_grpc_server(
Expand Down
15 changes: 14 additions & 1 deletion src/py/flwr/server/superlink/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,18 @@
Run,
)
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
from flwr.server.superlink.ffs import Ffs
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.state import State, StateFactory
from flwr.server.utils.validator import validate_task_ins_or_res


class DriverServicer(driver_pb2_grpc.DriverServicer):
"""Driver API servicer."""

def __init__(self, state_factory: StateFactory) -> None:
def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
self.state_factory = state_factory
self.ffs_factory = ffs_factory

def GetNodes(
self, request: GetNodesRequest, context: grpc.ServicerContext
Expand All @@ -71,9 +74,19 @@ def CreateRun(
"""Create run ID."""
log(DEBUG, "DriverServicer.CreateRun")
state: State = self.state_factory.state()
if request.HasField("fab") and request.fab.HasField("content"):
ffs: Ffs = self.ffs_factory.ffs()
fab_hash = ffs.put(request.fab.content, {})
_raise_if(
fab_hash != request.fab.hash_str,
f"FAB ({request.fab}) hash from request doesn't match contents",
)
else:
fab_hash = ""
run_id = state.create_run(
request.fab_id,
request.fab_version,
fab_hash,
user_config_from_proto(request.override_config),
)
return CreateRunResponse(run_id=run_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._client_public_key)
)
run_id = self.state.create_run("", "", {})
run_id = self.state.create_run("", "", "", {})
request = GetRunRequest(run_id=run_id)
shared_secret = generate_shared_key(
self._client_private_key, self._server_public_key
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._client_public_key)
)
run_id = self.state.create_run("", "", {})
run_id = self.state.create_run("", "", "", {})
request = GetRunRequest(run_id=run_id)
client_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(client_private_key, self._server_public_key)
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ def register_messages_into_state(
"""Register `num_messages` into the state factory."""
state: InMemoryState = state_factory.state() # type: ignore
state.run_ids[run_id] = Run(
run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={}
run_id=run_id,
fab_id="Mock/mock",
fab_version="v1.0.0",
fab_hash="hash",
override_config={},
)
# Artificially add TaskIns to state so they can be processed
# by the Simulation Engine logic
Expand Down Expand Up @@ -192,7 +196,7 @@ def start_and_shutdown(
if not app_dir:
app_dir = _autoresolve_app_dir()

run = Run(run_id=1234, fab_id="", fab_version="", override_config={})
run = Run(run_id=1234, fab_id="", fab_version="", fab_hash="", override_config={})

start_vce(
num_supernodes=num_supernodes,
Expand Down
12 changes: 7 additions & 5 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,20 +277,22 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:

def create_run(
self,
fab_id: str,
fab_version: str,
fab_id: Optional[str],
fab_version: Optional[str],
fab_hash: Optional[str],
override_config: UserConfig,
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
"""Create a new run for the specified `fab_hash`."""
# Sample a random int64 as run_id
with self.lock:
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

if run_id not in self.run_ids:
self.run_ids[run_id] = Run(
run_id=run_id,
fab_id=fab_id,
fab_version=fab_version,
fab_id=fab_id if fab_id else "",
fab_version=fab_version if fab_version else "",
fab_hash=fab_hash if fab_hash else "",
override_config=override_config,
)
return run_id
Expand Down
24 changes: 17 additions & 7 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
run_id INTEGER UNIQUE,
fab_id TEXT,
fab_version TEXT,
fab_hash TEXT,
override_config TEXT
);
"""
Expand Down Expand Up @@ -617,8 +618,9 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:

def create_run(
self,
fab_id: str,
fab_version: str,
fab_id: Optional[str],
fab_version: Optional[str],
fab_hash: Optional[str],
override_config: UserConfig,
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
Expand All @@ -630,12 +632,19 @@ def create_run(
# If run_id does not exist
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
query = (
"INSERT INTO run (run_id, fab_id, fab_version, override_config)"
"VALUES (?, ?, ?, ?);"
)
self.query(
query, (run_id, fab_id, fab_version, json.dumps(override_config))
"INSERT INTO run "
"(run_id, fab_id, fab_version, fab_hash, override_config)"
"VALUES (?, ?, ?, ?, ?);"
)
if fab_hash:
self.query(
query, (run_id, "", "", fab_hash, json.dumps(override_config))
)
else:
self.query(
query,
(run_id, fab_id, fab_version, "", json.dumps(override_config)),
)
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down Expand Up @@ -702,6 +711,7 @@ def get_run(self, run_id: int) -> Optional[Run]:
run_id=run_id,
fab_id=row["fab_id"],
fab_version=row["fab_version"],
fab_hash=row["fab_hash"],
override_config=json.loads(row["override_config"]),
)
except sqlite3.IntegrityError:
Expand Down
7 changes: 4 additions & 3 deletions src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
@abc.abstractmethod
def create_run(
self,
fab_id: str,
fab_version: str,
fab_id: Optional[str],
fab_version: Optional[str],
fab_hash: Optional[str],
override_config: UserConfig,
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
"""Create a new run for the specified `fab_hash`."""

@abc.abstractmethod
def get_run(self, run_id: int) -> Optional[Run]:
Expand Down
Loading

0 comments on commit 32a5260

Please sign in to comment.