Skip to content

Commit

Permalink
Merge branch 'add-override-config-superexec-configrecordvalue' into c…
Browse files Browse the repository at this point in the history
…hange-override-config-type
  • Loading branch information
charlesbvll committed Jul 8, 2024
2 parents 4a81569 + 4dbac7d commit 7712971
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
13 changes: 10 additions & 3 deletions src/py/flwr/superexec/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from flwr.cli.install import install_from_fab
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.serde import record_value_dict_to_proto
from flwr.common.typing import ConfigsRecordValues, ValueList

# pylint: disable=E0611
from flwr.proto.common_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
from flwr.proto.driver_pb2_grpc import DriverStub
from flwr.server.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER
Expand Down Expand Up @@ -57,7 +62,7 @@ def _create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, str],
override_config: Dict[str, ConfigsRecordValues],
) -> int:
if self.stub is None:
self._connect()
Expand All @@ -67,14 +72,16 @@ def _create_run(
req = CreateRunRequest(
fab_id=fab_id,
fab_version=fab_version,
override_config=override_config,
override_config=record_value_dict_to_proto(
override_config, ValueList, ProtoConfigsRecordValue
),
)
res = self.stub.CreateRun(request=req)
return int(res.run_id)

@override
def start_run(
self, fab_file: bytes, override_config: Dict[str, str]
self, fab_file: bytes, override_config: Dict[str, ConfigsRecordValues]
) -> Optional[RunTracker]:
"""Start run using the Flower Deployment Engine."""
try:
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/superexec/exec_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import grpc

from flwr.common.logger import log
from flwr.common.serde import record_value_dict_from_proto
from flwr.proto import exec_pb2_grpc # pylint: disable=E0611
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
StartRunRequest,
Expand All @@ -46,8 +47,7 @@ def StartRun(
log(INFO, "ExecServicer.StartRun")

run = self.executor.start_run(
request.fab_file,
dict(request.override_config.items()),
request.fab_file, record_value_dict_from_proto(request.override_config)
)

if run is None:
Expand Down
6 changes: 5 additions & 1 deletion src/py/flwr/superexec/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from subprocess import Popen
from typing import Dict, Optional

from flwr.common.typing import ConfigsRecordValues


@dataclass
class RunTracker:
Expand All @@ -33,7 +35,7 @@ class Executor(ABC):

@abstractmethod
def start_run(
self, fab_file: bytes, override_config: Dict[str, str]
self, fab_file: bytes, override_config: Dict[str, ConfigsRecordValues]
) -> Optional[RunTracker]:
"""Start a run using the given Flower FAB ID and version.
Expand All @@ -44,6 +46,8 @@ def start_run(
----------
fab_file : bytes
The Flower App Bundle file bytes.
override_config : Dict[str, ConfigsRecordValues]
A dict containing key-value pairs to override the FAB config.
Returns
-------
Expand Down

0 comments on commit 7712971

Please sign in to comment.