From 066a16c830035ef330bea1578886986145cb7861 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Sat, 13 Jul 2024 17:43:07 +0200 Subject: [PATCH] Add new `set_config` method --- src/py/flwr/superexec/deployment.py | 29 ++++++++++++++++---------- src/py/flwr/superexec/exec_servicer.py | 3 +-- src/py/flwr/superexec/executor.py | 17 +++++++++++---- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index fb0d38e2eb1e..6032582d2cb3 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -36,6 +36,8 @@ class DeploymentEngine(Executor): """Deployment engine executor. + Parameters + ---------- superlink: str (Default: "0.0.0.0:9091") Address of the `SuperLink` to connect to. root_certificates: Optional[str] (Default: None) @@ -61,6 +63,22 @@ def __init__( self.flwr_dir = flwr_dir self.stub: Optional[DriverStub] = None + @override + def set_config( + self, + config: Optional[Dict[str, str]], + ) -> None: + """Update config arguments.""" + if config: + if superlink_address := config.get("superlink"): + self.superlink = superlink_address + if cert_path := config.get( + "root-certificates", config.get("root_certificates") + ): + self.cert_path = cert_path + if flwr_dir := config.get("flwr-dir", config.get("flwr_dir")): + self.flwr_dir = flwr_dir + def _connect(self) -> None: if self.stub is None: channel = create_channel( @@ -94,19 +112,8 @@ def start_run( self, fab_file: bytes, override_config: Dict[str, str], - config: Optional[Dict[str, str]], ) -> Optional[RunTracker]: """Start run using the Flower Deployment Engine.""" - if config: - if superlink_address := config.get("superlink"): - self.superlink = superlink_address - if cert_path := config.get( - "root-certificates", config.get("root_certificates") - ): - self.cert_path = cert_path - if flwr_dir := config.get("flwr-dir", config.get("flwr_dir")): - self.flwr_dir = flwr_dir - try: # Install FAB to flwr dir fab_version, fab_id = get_fab_metadata(fab_file) diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 3671b72cc126..7175bfaf4805 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -37,8 +37,8 @@ class ExecServicer(exec_pb2_grpc.ExecServicer): def __init__(self, executor: Executor, config: Optional[Dict[str, str]]) -> None: self.executor = executor + self.executor.set_config(config) self.runs: Dict[int, RunTracker] = {} - self.config = config def StartRun( self, request: StartRunRequest, context: grpc.ServicerContext @@ -49,7 +49,6 @@ def StartRun( run = self.executor.start_run( request.fab_file, dict(request.override_config.items()), - self.config, ) if run is None: diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index 5cba27964be7..4e948df16010 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ -31,12 +31,24 @@ class RunTracker: class Executor(ABC): """Execute and monitor a Flower run.""" + @abstractmethod + def set_config( + self, + config: Optional[Dict[str, str]], + ) -> None: + """Register provided config as class attributes. + + Parameters + ---------- + config : Optional[Dict[str, str]] + A dictionary for configuration values. + """ + @abstractmethod def start_run( self, fab_file: bytes, override_config: Dict[str, str], - config: Optional[Dict[str, str]], ) -> Optional[RunTracker]: """Start a run using the given Flower FAB ID and version. @@ -47,9 +59,6 @@ def start_run( ---------- fab_file : bytes The Flower App Bundle file bytes. - config : Optional[Dict[str, str]] - An optional dictionary containing key-value pairs to configure the - executor. Returns -------