Skip to content

Commit

Permalink
Merge branch 'main' into get-federation-options
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Nov 7, 2024
2 parents 93da4be + d83edbb commit 938d8b0
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 90 deletions.
3 changes: 2 additions & 1 deletion src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package flwr.proto;

import "flwr/proto/fab.proto";
import "flwr/proto/transport.proto";
import "flwr/proto/recordset.proto";

service Exec {
// Start run upon request
Expand All @@ -31,7 +32,7 @@ service Exec {
message StartRunRequest {
Fab fab = 1;
map<string, Scalar> override_config = 2;
map<string, Scalar> federation_config = 3;
ConfigsRecord federation_options = 3;
}
message StartRunResponse { uint64 run_id = 1; }
message StreamLogsRequest {
Expand Down
21 changes: 16 additions & 5 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,18 @@
validate_federation_in_project_config,
validate_project_config,
)
from flwr.common.config import flatten_dict, parse_config_args
from flwr.common.config import (
flatten_dict,
parse_config_args,
user_config_to_configsrecord,
)
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.common.serde import fab_to_proto, user_config_to_proto
from flwr.common.serde import (
configs_record_to_proto,
fab_to_proto,
user_config_to_proto,
)
from flwr.common.typing import Fab
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub
Expand Down Expand Up @@ -94,6 +102,7 @@ def run(
_run_without_exec_api(app, federation_config, config_overrides, federation)


# pylint: disable-next=too-many-locals
def _run_with_exec_api(
app: Path,
federation_config: dict[str, Any],
Expand All @@ -118,12 +127,14 @@ def _run_with_exec_api(
content = Path(fab_path).read_bytes()
fab = Fab(fab_hash, content)

# Construct a `ConfigsRecord` out of a flattened `UserConfig`
fed_conf = flatten_dict(federation_config.get("options", {}))
c_record = user_config_to_configsrecord(fed_conf)

req = StartRunRequest(
fab=fab_to_proto(fab),
override_config=user_config_to_proto(parse_config_args(config_overrides)),
federation_config=user_config_to_proto(
flatten_dict(federation_config.get("options"))
),
federation_options=configs_record_to_proto(c_record),
)
res = stub.StartRun(req)

Expand Down
10 changes: 10 additions & 0 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tomli

from flwr.cli.config_utils import get_fab_config, validate_fields
from flwr.common import ConfigsRecord
from flwr.common.constant import (
APP_DIR,
FAB_CONFIG_FILE,
Expand Down Expand Up @@ -229,3 +230,12 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
config["project"]["version"],
f"{config['tool']['flwr']['app']['publisher']}/{config['project']['name']}",
)


def user_config_to_configsrecord(config: UserConfig) -> ConfigsRecord:
"""Construct a `ConfigsRecord` out of a `UserConfig`."""
c_record = ConfigsRecord()
for k, v in config.items():
c_record[k] = v

return c_record
73 changes: 40 additions & 33 deletions src/py/flwr/common/object_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def validate(
specified attribute within it.
project_dir : Optional[Union[str, Path]] (default: None)
The directory containing the module. If None, the current working directory
is used. If `check_module` is True, the `project_dir` will be inserted into
the system path, and the previously inserted `project_dir` will be removed.
is used. If `check_module` is True, the `project_dir` will be temporarily
inserted into the system path and then removed after the validation is complete.
Returns
-------
Expand All @@ -66,8 +66,8 @@ def validate(
Note
----
This function will modify `sys.path` by inserting the provided `project_dir`
and removing the previously inserted `project_dir`.
This function will temporarily modify `sys.path` by inserting the provided
`project_dir`, which will be removed after the validation is complete.
"""
module_str, _, attributes_str = module_attribute_str.partition(":")
if not module_str:
Expand All @@ -82,11 +82,19 @@ def validate(
)

if check_module:
if project_dir is None:
project_dir = Path.cwd()
project_dir = Path(project_dir).absolute()
# Set the system path
_set_sys_path(project_dir)
sys.path.insert(0, str(project_dir))

# Load module
module = find_spec(module_str)

# Unset the system path
sys.path.remove(str(project_dir))

# Check if the module and the attribute exist
if module and module.origin:
if not _find_attribute_in_module(module.origin, attributes_str):
return (
Expand Down Expand Up @@ -133,8 +141,10 @@ def load_app( # pylint: disable= too-many-branches
Note
----
This function will modify `sys.path` by inserting the provided `project_dir`
and removing the previously inserted `project_dir`.
- This function will unload all modules in the previously provided `project_dir`,
if it is invoked again.
- This function will modify `sys.path` by inserting the provided `project_dir`
and removing the previously inserted `project_dir`.
"""
valid, error_msg = validate(module_attribute_str, check_module=False)
if not valid and error_msg:
Expand All @@ -145,33 +155,21 @@ def load_app( # pylint: disable= too-many-branches
print(f"{project_dir = }")
importlib.invalidate_caches()
try:
if _current_sys_path:
# Hack: `tabnet` does not work with reloading
if "tabnet" in sys.modules:
log(
WARN,
"Cannot reload module `%s` from disk due to compatibility issues "
"with the `tabnet` library. The module will be loaded from the "
"cache instead. If you experience issues, consider restarting "
"the application.",
module_str,
)
else:
_unload_modules(Path(_current_sys_path))
_set_sys_path(project_dir)

if module_str not in sys.modules:
module = importlib.import_module(module_str)
# Hack: `tabnet` does not work with `importlib.reload`
elif "tabnet" in sys.modules:
log(
WARN,
"Cannot reload module `%s` from disk due to compatibility issues "
"with the `tabnet` library. The module will be loaded from the "
"cache instead. If you experience issues, consider restarting "
"the application.",
module_str,
)
module = sys.modules[module_str]
else:
module = sys.modules[module_str]

if project_dir is None:
project_dir = Path.cwd()

# Reload cached modules in the project directory
for m in list(sys.modules.values()):
path: Optional[str] = getattr(m, "__file__", None)
if path is not None and path.startswith(str(project_dir)):
importlib.reload(m)

module = importlib.import_module(module_str)
except ModuleNotFoundError as err:
raise error_type(
f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
Expand All @@ -191,6 +189,15 @@ def load_app( # pylint: disable= too-many-branches
return attribute


def _unload_modules(project_dir: Path) -> None:
"""Unload modules from the project directory."""
dir_str = str(project_dir.absolute())
for name, m in list(sys.modules.items()):
path: Optional[str] = getattr(m, "__file__", None)
if path is not None and path.startswith(dir_str):
del sys.modules[name]


def _set_sys_path(directory: Optional[Union[str, Path]]) -> None:
"""Set the system path."""
if directory is None:
Expand Down
31 changes: 14 additions & 17 deletions src/py/flwr/proto/exec_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 6 additions & 20 deletions src/py/flwr/proto/exec_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ isort:skip_file
"""
import builtins
import flwr.proto.fab_pb2
import flwr.proto.recordset_pb2
import flwr.proto.transport_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
Expand All @@ -30,38 +31,23 @@ class StartRunRequest(google.protobuf.message.Message):
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

class FederationConfigEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: typing.Text
@property
def value(self) -> flwr.proto.transport_pb2.Scalar: ...
def __init__(self,
*,
key: typing.Text = ...,
value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

FAB_FIELD_NUMBER: builtins.int
OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
FEDERATION_CONFIG_FIELD_NUMBER: builtins.int
FEDERATION_OPTIONS_FIELD_NUMBER: builtins.int
@property
def fab(self) -> flwr.proto.fab_pb2.Fab: ...
@property
def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
@property
def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
def federation_options(self) -> flwr.proto.recordset_pb2.ConfigsRecord: ...
def __init__(self,
*,
fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
federation_options: typing.Optional[flwr.proto.recordset_pb2.ConfigsRecord] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_options",b"federation_options"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_options",b"federation_options","override_config",b"override_config"]) -> None: ...
global___StartRunRequest = StartRunRequest

class StartRunResponse(google.protobuf.message.Message):
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/superexec/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def start_run(
self,
fab_file: bytes,
override_config: UserConfig,
federation_config: UserConfig,
federation_options: ConfigsRecord,
) -> Optional[int]:
"""Start run using the Flower Deployment Engine."""
run_id = None
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/superexec/exec_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from flwr.common.constant import LOG_STREAM_INTERVAL, Status
from flwr.common.logger import log
from flwr.common.serde import user_config_from_proto
from flwr.common.serde import configs_record_from_proto, user_config_from_proto
from flwr.proto import exec_pb2_grpc # pylint: disable=E0611
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
StartRunRequest,
Expand Down Expand Up @@ -61,7 +61,7 @@ def StartRun(
run_id = self.executor.start_run(
request.fab.content,
user_config_from_proto(request.override_config),
user_config_from_proto(request.federation_config),
configs_record_from_proto(request.federation_options),
)

if run_id is None:
Expand Down
7 changes: 4 additions & 3 deletions src/py/flwr/superexec/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from subprocess import Popen
from typing import Optional

from flwr.common import ConfigsRecord
from flwr.common.typing import UserConfig
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.linkstate import LinkStateFactory
Expand Down Expand Up @@ -71,7 +72,7 @@ def start_run(
self,
fab_file: bytes,
override_config: UserConfig,
federation_config: UserConfig,
federation_options: ConfigsRecord,
) -> Optional[int]:
"""Start a run using the given Flower FAB ID and version.
Expand All @@ -84,8 +85,8 @@ def start_run(
The Flower App Bundle file bytes.
override_config: UserConfig
The config overrides dict sent by the user (using `flwr run`).
federation_config: UserConfig
The federation options dict sent by the user (using `flwr run`).
federation_options: ConfigsRecord
The federation options sent by the user (using `flwr run`).
Returns
-------
Expand Down
16 changes: 8 additions & 8 deletions src/py/flwr/superexec/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from flwr.cli.config_utils import load_and_validate
from flwr.cli.install import install_from_fab
from flwr.common import ConfigsRecord
from flwr.common.config import unflatten_dict
from flwr.common.constant import RUN_ID_NUM_BYTES
from flwr.common.logger import log
Expand Down Expand Up @@ -124,7 +125,7 @@ def start_run(
self,
fab_file: bytes,
override_config: UserConfig,
federation_config: UserConfig,
federation_options: ConfigsRecord,
) -> Optional[int]:
"""Start run using the Flower Simulation Engine."""
if self.num_supernodes is None:
Expand Down Expand Up @@ -163,14 +164,13 @@ def start_run(
"Config extracted from FAB's pyproject.toml is not valid"
)

# Flatten federated config
federation_config_flat = unflatten_dict(federation_config)
# Unflatten underlaying dict
fed_opt = unflatten_dict({**federation_options})

num_supernodes = federation_config_flat.get(
"num-supernodes", self.num_supernodes
)
backend_cfg = federation_config_flat.get("backend", {})
verbose: Optional[bool] = federation_config_flat.get("verbose")
# Read data
num_supernodes = fed_opt.get("num-supernodes", self.num_supernodes)
backend_cfg = fed_opt.get("backend", {})
verbose: Optional[bool] = fed_opt.get("verbose")

# In Simulation there is no SuperLink, still we create a run_id
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
Expand Down

0 comments on commit 938d8b0

Please sign in to comment.