Skip to content

Commit

Permalink
better
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Nov 7, 2024
1 parent 0181223 commit bf67547
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
22 changes: 13 additions & 9 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import subprocess
from logging import DEBUG
from pathlib import Path
from typing import Annotated, Any, Optional, cast
from typing import Annotated, Any, Optional

import typer

Expand All @@ -29,16 +29,19 @@
validate_federation_in_project_config,
validate_project_config,
)
from flwr.common import ConfigsRecord
from flwr.common.config import flatten_dict, parse_config_args
from flwr.common.config import (
flatten_dict,
parse_config_args,
user_config_to_configsrecord,
)
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.common.serde import (
configs_record_to_proto,
fab_to_proto,
user_config_to_proto,
)
from flwr.common.typing import ConfigsRecordValues, Fab
from flwr.common.typing import Fab
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub

Expand Down Expand Up @@ -99,6 +102,7 @@ def run(
_run_without_exec_api(app, federation_config, config_overrides, federation)


# pylint: disable-next=too-many-locals
def _run_with_exec_api(
app: Path,
federation_config: dict[str, Any],
Expand All @@ -123,14 +127,14 @@ def _run_with_exec_api(
content = Path(fab_path).read_bytes()
fab = Fab(fab_hash, content)

# Cast UserConfig as a dict compatible with what `ConfigsRecord` expects
user_config_flat = cast(
dict[str, ConfigsRecordValues], flatten_dict(federation_config.get("options"))
)
# Construct a `ConfigsRecord` out of a flattened `UserConfig`
fed_conf = flatten_dict(federation_config.get("options", {}))
c_record = user_config_to_configsrecord(fed_conf)

req = StartRunRequest(
fab=fab_to_proto(fab),
override_config=user_config_to_proto(parse_config_args(config_overrides)),
federation_options=configs_record_to_proto(ConfigsRecord(user_config_flat)),
federation_options=configs_record_to_proto(c_record),
)
res = stub.StartRun(req)

Expand Down
10 changes: 10 additions & 0 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tomli

from flwr.cli.config_utils import get_fab_config, validate_fields
from flwr.common import ConfigsRecord
from flwr.common.constant import (
APP_DIR,
FAB_CONFIG_FILE,
Expand Down Expand Up @@ -229,3 +230,12 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
config["project"]["version"],
f"{config['tool']['flwr']['app']['publisher']}/{config['project']['name']}",
)


def user_config_to_configsrecord(config: UserConfig) -> ConfigsRecord:
"""Construct a `ConfigsRecord` out of a `UserConfig`."""
c_record = ConfigsRecord()
for k, v in config.items():
c_record[k] = v

return c_record
10 changes: 7 additions & 3 deletions src/py/flwr/superexec/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from flwr.common.config import unflatten_dict
from flwr.common.constant import RUN_ID_NUM_BYTES
from flwr.common.logger import log
from flwr.common.typing import UserConfig
from flwr.common.typing import ConfigsRecordValues, UserConfig
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.linkstate import LinkStateFactory
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
Expand Down Expand Up @@ -165,8 +165,12 @@ def start_run(
)

# Unflatten underlaying dict
fed_opt = unflatten_dict(federation_options.__dict__["_data"])
print(f"{fed_opt = }")
c_record_data: dict[str, ConfigsRecordValues] = federation_options.__dict__[
"_data"
]
fed_opt = unflatten_dict(c_record_data)

# Read data
num_supernodes = fed_opt.get("num-supernodes", self.num_supernodes)
backend_cfg = fed_opt.get("backend", {})
verbose: Optional[bool] = fed_opt.get("verbose")
Expand Down

0 comments on commit bf67547

Please sign in to comment.