Skip to content

Commit

Permalink
Add new set_config method
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Jul 13, 2024
1 parent a00b951 commit 066a16c
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 17 deletions.
29 changes: 18 additions & 11 deletions src/py/flwr/superexec/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
class DeploymentEngine(Executor):
"""Deployment engine executor.
Parameters
----------
superlink: str (Default: "0.0.0.0:9091")
Address of the `SuperLink` to connect to.
root_certificates: Optional[str] (Default: None)
Expand All @@ -61,6 +63,22 @@ def __init__(
self.flwr_dir = flwr_dir
self.stub: Optional[DriverStub] = None

@override
def set_config(
self,
config: Optional[Dict[str, str]],
) -> None:
"""Update config arguments."""
if config:
if superlink_address := config.get("superlink"):
self.superlink = superlink_address
if cert_path := config.get(
"root-certificates", config.get("root_certificates")
):
self.cert_path = cert_path
if flwr_dir := config.get("flwr-dir", config.get("flwr_dir")):
self.flwr_dir = flwr_dir

def _connect(self) -> None:
if self.stub is None:
channel = create_channel(
Expand Down Expand Up @@ -94,19 +112,8 @@ def start_run(
self,
fab_file: bytes,
override_config: Dict[str, str],
config: Optional[Dict[str, str]],
) -> Optional[RunTracker]:
"""Start run using the Flower Deployment Engine."""
if config:
if superlink_address := config.get("superlink"):
self.superlink = superlink_address
if cert_path := config.get(
"root-certificates", config.get("root_certificates")
):
self.cert_path = cert_path
if flwr_dir := config.get("flwr-dir", config.get("flwr_dir")):
self.flwr_dir = flwr_dir

try:
# Install FAB to flwr dir
fab_version, fab_id = get_fab_metadata(fab_file)
Expand Down
3 changes: 1 addition & 2 deletions src/py/flwr/superexec/exec_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):

def __init__(self, executor: Executor, config: Optional[Dict[str, str]]) -> None:
self.executor = executor
self.executor.set_config(config)
self.runs: Dict[int, RunTracker] = {}
self.config = config

def StartRun(
self, request: StartRunRequest, context: grpc.ServicerContext
Expand All @@ -49,7 +49,6 @@ def StartRun(
run = self.executor.start_run(
request.fab_file,
dict(request.override_config.items()),
self.config,
)

if run is None:
Expand Down
17 changes: 13 additions & 4 deletions src/py/flwr/superexec/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,24 @@ class RunTracker:
class Executor(ABC):
"""Execute and monitor a Flower run."""

@abstractmethod
def set_config(
self,
config: Optional[Dict[str, str]],
) -> None:
"""Register provided config as class attributes.
Parameters
----------
config : Optional[Dict[str, str]]
A dictionary for configuration values.
"""

@abstractmethod
def start_run(
self,
fab_file: bytes,
override_config: Dict[str, str],
config: Optional[Dict[str, str]],
) -> Optional[RunTracker]:
"""Start a run using the given Flower FAB ID and version.
Expand All @@ -47,9 +59,6 @@ def start_run(
----------
fab_file : bytes
The Flower App Bundle file bytes.
config : Optional[Dict[str, str]]
An optional dictionary containing key-value pairs to configure the
executor.
Returns
-------
Expand Down

0 comments on commit 066a16c

Please sign in to comment.