From 9bb359e28931ff6f138a4e6a10de3dd72ad5d7ad Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Thu, 11 Jul 2024 22:58:34 +0200 Subject: [PATCH] Use Value instead of ConfigsRecordValues --- src/py/flwr/cli/run/run.py | 4 +-- src/py/flwr/client/node_state.py | 4 +-- src/py/flwr/common/config.py | 28 ++++++++----------- src/py/flwr/common/config_test.py | 4 +-- src/py/flwr/common/context.py | 8 +++--- src/py/flwr/server/run_serverapp.py | 4 +-- .../server/superlink/state/in_memory_state.py | 4 +-- .../server/superlink/state/sqlite_state.py | 4 +-- src/py/flwr/server/superlink/state/state.py | 4 +-- src/py/flwr/superexec/deployment.py | 6 ++-- src/py/flwr/superexec/executor.py | 6 ++-- 11 files changed, 36 insertions(+), 40 deletions(-) diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index 219c499053c8..c5a98b975f39 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -30,7 +30,7 @@ from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log from flwr.common.serde import record_value_dict_to_proto -from flwr.common.typing import ConfigsRecordValues, ValueList +from flwr.common.typing import Value, ValueList # pylint: disable=E0611 from flwr.proto.common_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue @@ -133,7 +133,7 @@ def run( def _start_superexec_run( - override_config: Dict[str, ConfigsRecordValues], directory: Optional[Path] + override_config: Dict[str, Value], directory: Optional[Path] ) -> None: def on_channel_state_change(channel_connectivity: str) -> None: """Log channel connectivity.""" diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index fcaaf3ed72f0..4a6fa4bd10d1 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -21,7 +21,7 @@ from flwr.common import Context, RecordSet from flwr.common.config import get_fused_config -from flwr.common.typing import ConfigsRecordValues, Run +from flwr.common.typing import Value, Run @dataclass() @@ -29,7 +29,7 @@ class RunInfo: """Contains the Context and initial run_config of a Run.""" context: Context - initial_run_config: Dict[str, ConfigsRecordValues] + initial_run_config: Dict[str, Value] class NodeState: diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 9a34ed45fd2d..ec28d080ec9a 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -22,7 +22,7 @@ from flwr.cli.config_utils import validate_fields from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME -from flwr.common.typing import ConfigsRecordValues, ConfigsScalarList, Run +from flwr.common.typing import ConfigsScalarList, Run, Value def get_flwr_dir(provided_path: Optional[str] = None) -> Path: @@ -75,9 +75,9 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]: def _fuse_dicts( - main_dict: Dict[str, ConfigsRecordValues], - override_dict: Dict[str, ConfigsRecordValues], -) -> Dict[str, ConfigsRecordValues]: + main_dict: Dict[str, Value], + override_dict: Dict[str, Value], +) -> Dict[str, Value]: fused_dict = main_dict.copy() for key, value in override_dict.items(): @@ -87,9 +87,7 @@ def _fuse_dicts( return fused_dict -def get_fused_config( - run: Run, flwr_dir: Optional[Path] -) -> Dict[str, ConfigsRecordValues]: +def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, Value]: """Merge the overrides from a `Run` with the config from a FAB. Get the config using the fab_id and the fab_version, remove the nesting by adding @@ -106,22 +104,20 @@ def get_fused_config( return _fuse_dicts(flat_default_config, run.override_config) -def flatten_dict( - raw_dict: Dict[str, Any], parent_key: str = "" -) -> Dict[str, ConfigsRecordValues]: +def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, Value]: """Flatten dict by joining nested keys with a given separator.""" - items: List[Tuple[str, ConfigsRecordValues]] = [] + items: List[Tuple[str, Value]] = [] separator: str = "." for k, v in raw_dict.items(): new_key = f"{parent_key}{separator}{k}" if parent_key else k if isinstance(v, dict): items.extend(flatten_dict(v, parent_key=new_key).items()) - elif isinstance(v, get_args(ConfigsRecordValues)): - items.append((new_key, cast(ConfigsRecordValues, v))) + elif isinstance(v, get_args(Value)): + items.append((new_key, cast(Value, v))) elif isinstance(v, list) and any( isinstance(v, list_type) for list_type in get_args(ConfigsScalarList) ): - items.append((new_key, cast(ConfigsRecordValues, v))) + items.append((new_key, cast(Value, v))) else: raise ValueError( f"The value for key {k} needs to be of type `int`, `float`, " @@ -133,9 +129,9 @@ def flatten_dict( def parse_config_args( config_overrides: Optional[str], separator: str = ",", -) -> Dict[str, ConfigsRecordValues]: +) -> Dict[str, Value]: """Parse separator separated list of key-value pairs separated by '='.""" - overrides: Dict[str, ConfigsRecordValues] = {} + overrides: Dict[str, Value] = {} if config_overrides is None: return overrides diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py index 8a7a520da389..486beebba062 100644 --- a/src/py/flwr/common/config_test.py +++ b/src/py/flwr/common/config_test.py @@ -22,7 +22,7 @@ import pytest -from flwr.common.typing import ConfigsRecordValues +from flwr.common.typing import Value from .config import ( _fuse_dicts, @@ -112,7 +112,7 @@ def test_get_fused_config_valid(tmp_path: Path) -> None: [flower.config.clientapp] test = "key" """ - overrides: Dict[str, ConfigsRecordValues] = { + overrides: Dict[str, Value] = { "num_server_rounds": 5, "lr": 0.2, "serverapp.test": "overriden", diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py index 33765dc98239..d41ed17c5a8f 100644 --- a/src/py/flwr/common/context.py +++ b/src/py/flwr/common/context.py @@ -19,7 +19,7 @@ from typing import Dict, Optional from .record import RecordSet -from .typing import ConfigsRecordValues +from .typing import Value @dataclass @@ -35,7 +35,7 @@ class Context: executing mods. It can also be used as a memory to access at different points during the lifecycle of this entity (e.g. across multiple rounds) - run_config : Dict[str, ConfigsRecordValues] + run_config : Dict[str, Value] A config (key/value mapping) held by the entity in a given run and that will stay local. It can be used at any point during the lifecycle of this entity (e.g. across multiple rounds) @@ -47,12 +47,12 @@ class Context: state: RecordSet partition_id: Optional[int] - run_config: Dict[str, ConfigsRecordValues] + run_config: Dict[str, Value] def __init__( self, state: RecordSet, - run_config: Dict[str, ConfigsRecordValues], + run_config: Dict[str, Value], partition_id: Optional[int] = None, ) -> None: self.state = state diff --git a/src/py/flwr/server/run_serverapp.py b/src/py/flwr/server/run_serverapp.py index c668321cf57e..5086d54633e8 100644 --- a/src/py/flwr/server/run_serverapp.py +++ b/src/py/flwr/server/run_serverapp.py @@ -30,7 +30,7 @@ ) from flwr.common.logger import log, update_console_handler, warn_deprecated_feature from flwr.common.object_ref import load_app -from flwr.common.typing import ConfigsRecordValues +from flwr.common.typing import Value from flwr.proto.driver_pb2 import ( # pylint: disable=E0611 CreateRunRequest, CreateRunResponse, @@ -46,7 +46,7 @@ def run( driver: Driver, server_app_dir: str, - server_app_run_config: Dict[str, ConfigsRecordValues], + server_app_run_config: Dict[str, Value], server_app_attr: Optional[str] = None, loaded_server_app: Optional[ServerApp] = None, ) -> None: diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 938274562b70..dbdbae8bac38 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -23,7 +23,7 @@ from flwr.common import log, now from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES -from flwr.common.typing import ConfigsRecordValues, Run +from flwr.common.typing import Run, Value from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.state.state import State from flwr.server.utils import validate_task_ins_or_res @@ -279,7 +279,7 @@ def create_run( self, fab_id: str, fab_version: str, - override_config: Dict[str, ConfigsRecordValues], + override_config: Dict[str, Value], ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index a2a489f13610..d0c3bab4e9da 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -25,7 +25,7 @@ from flwr.common import log, now from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES -from flwr.common.typing import ConfigsRecordValues, Run +from flwr.common.typing import Run, Value from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 @@ -619,7 +619,7 @@ def create_run( self, fab_id: str, fab_version: str, - override_config: Dict[str, ConfigsRecordValues], + override_config: Dict[str, Value], ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id diff --git a/src/py/flwr/server/superlink/state/state.py b/src/py/flwr/server/superlink/state/state.py index bb3a00731414..c23785f16e8e 100644 --- a/src/py/flwr/server/superlink/state/state.py +++ b/src/py/flwr/server/superlink/state/state.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Set from uuid import UUID -from flwr.common.typing import ConfigsRecordValues, Run +from flwr.common.typing import Run, Value from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -161,7 +161,7 @@ def create_run( self, fab_id: str, fab_version: str, - override_config: Dict[str, ConfigsRecordValues], + override_config: Dict[str, Value], ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 91ca27c1239e..8b7c7101c897 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -26,7 +26,7 @@ from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.common.serde import record_value_dict_to_proto -from flwr.common.typing import ConfigsRecordValues, ValueList +from flwr.common.typing import Value, ValueList # pylint: disable=E0611 from flwr.proto.common_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue @@ -62,7 +62,7 @@ def _create_run( self, fab_id: str, fab_version: str, - override_config: Dict[str, ConfigsRecordValues], + override_config: Dict[str, Value], ) -> int: if self.stub is None: self._connect() @@ -81,7 +81,7 @@ def _create_run( @override def start_run( - self, fab_file: bytes, override_config: Dict[str, ConfigsRecordValues] + self, fab_file: bytes, override_config: Dict[str, Value] ) -> Optional[RunTracker]: """Start run using the Flower Deployment Engine.""" try: diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index 9806c471307a..a3a6e79f0a8a 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ -19,7 +19,7 @@ from subprocess import Popen from typing import Dict, Optional -from flwr.common.typing import ConfigsRecordValues +from flwr.common.typing import Value @dataclass @@ -35,7 +35,7 @@ class Executor(ABC): @abstractmethod def start_run( - self, fab_file: bytes, override_config: Dict[str, ConfigsRecordValues] + self, fab_file: bytes, override_config: Dict[str, Value] ) -> Optional[RunTracker]: """Start a run using the given Flower FAB ID and version. @@ -46,7 +46,7 @@ def start_run( ---------- fab_file : bytes The Flower App Bundle file bytes. - override_config : Dict[str, ConfigsRecordValues] + override_config : Dict[str, Value] A dict containing key-value pairs to override the FAB config. Returns