Skip to content

Commit

Permalink
Add config overrides and improve typing
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Jul 6, 2024
1 parent 55b36a5 commit 39a1946
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 54 deletions.
39 changes: 25 additions & 14 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,27 @@
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.common.serde import record_value_dict_to_proto
from flwr.common.typing import ConfigsRecordValues
from flwr.common.typing import ConfigsRecordValues, ValueList
from flwr.proto.common_pb2 import ConfigsRecordValue as PCRV # pylint: disable=E0611
from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub
from flwr.simulation.run_simulation import _run_simulation


def _parse_config_overrides(
config_overrides: Optional[List[str]],
) -> Dict[str, ConfigsRecordValues]:
"""Parse the -c arguments and return the overrides as a dict."""
overrides: Dict[str, ConfigsRecordValues] = {}

if config_overrides is not None:
for kv_pair in config_overrides:
key, value = kv_pair.split("=")
overrides[key] = value

return overrides


class Engine(str, Enum):
"""Enum defining the engine to run on."""

Expand All @@ -61,10 +75,18 @@ def run(
Optional[Path],
typer.Option(help="Path of the Flower project to run"),
] = None,
config_overrides: Annotated[
Optional[List[str]],
typer.Option(
"--config",
"-c",
help="Override configuration key-value pairs",
),
] = None,
) -> None:
"""Run Flower project."""
if use_superexec:
_start_superexec_run({}, directory)
_start_superexec_run(_parse_config_overrides(config_overrides), directory)
return

typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
Expand Down Expand Up @@ -141,18 +163,7 @@ def on_channel_state_change(channel_connectivity: str) -> None:
fab_file=Path(fab_path).read_bytes(),
override_config=record_value_dict_to_proto(
override_config,
[
bool,
int,
float,
str,
bytes,
List[int],
List[float],
List[str],
List[bytes],
List[bool],
],
ValueList,
PCRV,
),
)
Expand Down
18 changes: 15 additions & 3 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:
return config


def _flatten_dict(raw_dict, sep="."):
items = []
for k, v in raw_dict.items():
if isinstance(v, dict):
items.extend(_flatten_dict(v, sep=sep).items())
else:
items.append((k, v))
return dict(items)


def get_fused_config(
run: Run, flwr_dir: Optional[Path]
) -> Dict[str, ConfigsRecordValues]:
Expand All @@ -84,9 +94,11 @@ def get_fused_config(
return {}

final_config = {}
default_config = get_project_config(
get_project_dir(run.fab_id, run.fab_version, flwr_dir)
)["flower"]["config"]
default_config = _flatten_dict(
get_project_config(get_project_dir(run.fab_id, run.fab_version, flwr_dir))[
"flower"
]["config"]
)

for key in default_config.keys():
if key in run.override_config:
Expand Down
38 changes: 14 additions & 24 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@
"""ProtoBuf serialization and deserialization."""


from typing import Any, Dict, List, MutableMapping, OrderedDict, Type, TypeVar, cast
from typing import (
Any,
Dict,
List,
MutableMapping,
OrderedDict,
Type,
TypeVar,
Union,
cast,
)

from google.protobuf.message import Message as GrpcMessage

Expand Down Expand Up @@ -412,27 +422,7 @@ def _record_value_from_proto(value_proto: GrpcMessage) -> Any:


def record_value_dict_to_proto(
value_dict: Dict[str, ConfigsRecordValues],
allowed_types: List[type],
value_proto_class: Type[T],
) -> Dict[str, T]:
"""Serialize the record value dict to ProtoBuf.
Note: `bool` MUST be put in the front of allowd_types if it exists.
"""
# Move bool to the front
if bool in allowed_types and allowed_types[0] != bool:
allowed_types.remove(bool)
allowed_types.insert(0, bool)

def proto(_v: Any) -> T:
return _record_value_to_proto(_v, allowed_types, value_proto_class)

return {k: proto(v) for k, v in value_dict.items()}


def _record_value_dict_to_proto(
value_dict: TypedDict[str, Any],
value_dict: Union[TypedDict[str, Any], Dict[str, ConfigsRecordValues]],
allowed_types: List[type],
value_proto_class: Type[T],
) -> Dict[str, T]:
Expand Down Expand Up @@ -496,7 +486,7 @@ def parameters_record_from_proto(
def metrics_record_to_proto(record: MetricsRecord) -> ProtoMetricsRecord:
"""Serialize MetricsRecord to ProtoBuf."""
return ProtoMetricsRecord(
data=_record_value_dict_to_proto(record, [float, int], ProtoMetricsRecordValue)
data=record_value_dict_to_proto(record, [float, int], ProtoMetricsRecordValue)
)


Expand All @@ -514,7 +504,7 @@ def metrics_record_from_proto(record_proto: ProtoMetricsRecord) -> MetricsRecord
def configs_record_to_proto(record: ConfigsRecord) -> ProtoConfigsRecord:
"""Serialize ConfigsRecord to ProtoBuf."""
return ProtoConfigsRecord(
data=_record_value_dict_to_proto(
data=record_value_dict_to_proto(
record,
[bool, int, float, str, bytes],
ProtoConfigsRecordValue,
Expand Down
13 changes: 13 additions & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@
List[int],
List[str],
]
ValueList = [
bool,
bytes,
float,
int,
str,
List[bool],
List[bytes],
List[float],
List[int],
List[str],
]


# Value types for common.MetricsRecord
MetricsScalar = Union[int, float]
Expand Down
15 changes: 2 additions & 13 deletions src/py/flwr/superexec/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.serde import record_value_dict_to_proto
from flwr.common.typing import ConfigsRecordValues
from flwr.common.typing import ConfigsRecordValues, ValueList
from flwr.proto.common_pb2 import ConfigsRecordValue as PCRV # pylint: disable=E0611
from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
from flwr.proto.driver_pb2_grpc import DriverStub
Expand Down Expand Up @@ -72,18 +72,7 @@ def _create_run(
fab_version=fab_version,
override_config=record_value_dict_to_proto(
override_config,
[
bool,
int,
float,
str,
bytes,
List[int],
List[float],
List[str],
List[bytes],
List[bool],
],
ValueList,
PCRV,
),
)
Expand Down

0 comments on commit 39a1946

Please sign in to comment.