Skip to content

Commit

Permalink
feat(framework) Add override config to SuperExec (#3731)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Jul 8, 2024
1 parent 4ce439e commit b735710
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
21 changes: 16 additions & 5 deletions src/py/flwr/superexec/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import subprocess
import sys
from logging import ERROR, INFO
from typing import Optional
from typing import Dict, Optional

from typing_extensions import override

Expand Down Expand Up @@ -53,18 +53,29 @@ def _connect(self) -> None:
)
self.stub = DriverStub(channel)

def _create_run(self, fab_id: str, fab_version: str) -> int:
def _create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, str],
) -> int:
if self.stub is None:
self._connect()

assert self.stub is not None

req = CreateRunRequest(fab_id=fab_id, fab_version=fab_version)
req = CreateRunRequest(
fab_id=fab_id,
fab_version=fab_version,
override_config=override_config,
)
res = self.stub.CreateRun(request=req)
return int(res.run_id)

@override
def start_run(self, fab_file: bytes) -> Optional[RunTracker]:
def start_run(
self, fab_file: bytes, override_config: Dict[str, str]
) -> Optional[RunTracker]:
"""Start run using the Flower Deployment Engine."""
try:
# Install FAB to flwr dir
Expand All @@ -79,7 +90,7 @@ def start_run(self, fab_file: bytes) -> Optional[RunTracker]:
)

# Call SuperLink to create run
run_id: int = self._create_run(fab_id, fab_version)
run_id: int = self._create_run(fab_id, fab_version, override_config)
log(INFO, "Created run %s", str(run_id))

# Start ServerApp
Expand Down
5 changes: 4 additions & 1 deletion src/py/flwr/superexec/exec_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def StartRun(
"""Create run ID."""
log(INFO, "ExecServicer.StartRun")

run = self.executor.start_run(request.fab_file)
run = self.executor.start_run(
request.fab_file,
dict(request.override_config.items()),
)

if run is None:
log(ERROR, "Executor failed to start run")
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/superexec/exec_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_start_run() -> None:
run_res.proc = proc

executor = MagicMock()
executor.start_run = lambda _: run_res
executor.start_run = lambda _, __: run_res

context_mock = MagicMock()

Expand Down
5 changes: 2 additions & 3 deletions src/py/flwr/superexec/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from subprocess import Popen
from typing import Optional
from typing import Dict, Optional


@dataclass
Expand All @@ -33,8 +33,7 @@ class Executor(ABC):

@abstractmethod
def start_run(
self,
fab_file: bytes,
self, fab_file: bytes, override_config: Dict[str, str]
) -> Optional[RunTracker]:
"""Start a run using the given Flower FAB ID and version.
Expand Down

0 comments on commit b735710

Please sign in to comment.