Skip to content

Commit

Permalink
Use Value instead of ConfigsRecordValues
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Jul 11, 2024
1 parent c9719ae commit 9bb359e
Show file tree
Hide file tree
Showing 11 changed files with 36 additions and 40 deletions.
4 changes: 2 additions & 2 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.common.serde import record_value_dict_to_proto
from flwr.common.typing import ConfigsRecordValues, ValueList
from flwr.common.typing import Value, ValueList

# pylint: disable=E0611
from flwr.proto.common_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
Expand Down Expand Up @@ -133,7 +133,7 @@ def run(


def _start_superexec_run(
override_config: Dict[str, ConfigsRecordValues], directory: Optional[Path]
override_config: Dict[str, Value], directory: Optional[Path]
) -> None:
def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@

from flwr.common import Context, RecordSet
from flwr.common.config import get_fused_config
from flwr.common.typing import ConfigsRecordValues, Run
from flwr.common.typing import Value, Run


@dataclass()
class RunInfo:
"""Contains the Context and initial run_config of a Run."""

context: Context
initial_run_config: Dict[str, ConfigsRecordValues]
initial_run_config: Dict[str, Value]


class NodeState:
Expand Down
28 changes: 12 additions & 16 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from flwr.cli.config_utils import validate_fields
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
from flwr.common.typing import ConfigsRecordValues, ConfigsScalarList, Run
from flwr.common.typing import ConfigsScalarList, Run, Value


def get_flwr_dir(provided_path: Optional[str] = None) -> Path:
Expand Down Expand Up @@ -75,9 +75,9 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]:


def _fuse_dicts(
main_dict: Dict[str, ConfigsRecordValues],
override_dict: Dict[str, ConfigsRecordValues],
) -> Dict[str, ConfigsRecordValues]:
main_dict: Dict[str, Value],
override_dict: Dict[str, Value],
) -> Dict[str, Value]:
fused_dict = main_dict.copy()

for key, value in override_dict.items():
Expand All @@ -87,9 +87,7 @@ def _fuse_dicts(
return fused_dict


def get_fused_config(
run: Run, flwr_dir: Optional[Path]
) -> Dict[str, ConfigsRecordValues]:
def get_fused_config(run: Run, flwr_dir: Optional[Path]) -> Dict[str, Value]:
"""Merge the overrides from a `Run` with the config from a FAB.
Get the config using the fab_id and the fab_version, remove the nesting by adding
Expand All @@ -106,22 +104,20 @@ def get_fused_config(
return _fuse_dicts(flat_default_config, run.override_config)


def flatten_dict(
raw_dict: Dict[str, Any], parent_key: str = ""
) -> Dict[str, ConfigsRecordValues]:
def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, Value]:
"""Flatten dict by joining nested keys with a given separator."""
items: List[Tuple[str, ConfigsRecordValues]] = []
items: List[Tuple[str, Value]] = []
separator: str = "."
for k, v in raw_dict.items():
new_key = f"{parent_key}{separator}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, parent_key=new_key).items())
elif isinstance(v, get_args(ConfigsRecordValues)):
items.append((new_key, cast(ConfigsRecordValues, v)))
elif isinstance(v, get_args(Value)):
items.append((new_key, cast(Value, v)))
elif isinstance(v, list) and any(
isinstance(v, list_type) for list_type in get_args(ConfigsScalarList)
):
items.append((new_key, cast(ConfigsRecordValues, v)))
items.append((new_key, cast(Value, v)))
else:
raise ValueError(
f"The value for key {k} needs to be of type `int`, `float`, "
Expand All @@ -133,9 +129,9 @@ def flatten_dict(
def parse_config_args(
config_overrides: Optional[str],
separator: str = ",",
) -> Dict[str, ConfigsRecordValues]:
) -> Dict[str, Value]:
"""Parse separator separated list of key-value pairs separated by '='."""
overrides: Dict[str, ConfigsRecordValues] = {}
overrides: Dict[str, Value] = {}

if config_overrides is None:
return overrides
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import pytest

from flwr.common.typing import ConfigsRecordValues
from flwr.common.typing import Value

from .config import (
_fuse_dicts,
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_get_fused_config_valid(tmp_path: Path) -> None:
[flower.config.clientapp]
test = "key"
"""
overrides: Dict[str, ConfigsRecordValues] = {
overrides: Dict[str, Value] = {
"num_server_rounds": 5,
"lr": 0.2,
"serverapp.test": "overriden",
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Dict, Optional

from .record import RecordSet
from .typing import ConfigsRecordValues
from .typing import Value


@dataclass
Expand All @@ -35,7 +35,7 @@ class Context:
executing mods. It can also be used as a memory to access
at different points during the lifecycle of this entity (e.g. across
multiple rounds)
run_config : Dict[str, ConfigsRecordValues]
run_config : Dict[str, Value]
A config (key/value mapping) held by the entity in a given run and that will
stay local. It can be used at any point during the lifecycle of this entity
(e.g. across multiple rounds)
Expand All @@ -47,12 +47,12 @@ class Context:

state: RecordSet
partition_id: Optional[int]
run_config: Dict[str, ConfigsRecordValues]
run_config: Dict[str, Value]

def __init__(
self,
state: RecordSet,
run_config: Dict[str, ConfigsRecordValues],
run_config: Dict[str, Value],
partition_id: Optional[int] = None,
) -> None:
self.state = state
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/run_serverapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)
from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
from flwr.common.object_ref import load_app
from flwr.common.typing import ConfigsRecordValues
from flwr.common.typing import Value
from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
CreateRunRequest,
CreateRunResponse,
Expand All @@ -46,7 +46,7 @@
def run(
driver: Driver,
server_app_dir: str,
server_app_run_config: Dict[str, ConfigsRecordValues],
server_app_run_config: Dict[str, Value],
server_app_attr: Optional[str] = None,
loaded_server_app: Optional[ServerApp] = None,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from flwr.common import log, now
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
from flwr.common.typing import ConfigsRecordValues, Run
from flwr.common.typing import Run, Value
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.superlink.state.state import State
from flwr.server.utils import validate_task_ins_or_res
Expand Down Expand Up @@ -279,7 +279,7 @@ def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, ConfigsRecordValues],
override_config: Dict[str, Value],
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
# Sample a random int64 as run_id
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from flwr.common import log, now
from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
from flwr.common.typing import ConfigsRecordValues, Run
from flwr.common.typing import Run, Value
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
Expand Down Expand Up @@ -619,7 +619,7 @@ def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, ConfigsRecordValues],
override_config: Dict[str, Value],
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""
# Sample a random int64 as run_id
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Dict, List, Optional, Set
from uuid import UUID

from flwr.common.typing import ConfigsRecordValues, Run
from flwr.common.typing import Run, Value
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611


Expand Down Expand Up @@ -161,7 +161,7 @@ def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, ConfigsRecordValues],
override_config: Dict[str, Value],
) -> int:
"""Create a new run for the specified `fab_id` and `fab_version`."""

Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/superexec/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.common.serde import record_value_dict_to_proto
from flwr.common.typing import ConfigsRecordValues, ValueList
from flwr.common.typing import Value, ValueList

# pylint: disable=E0611
from flwr.proto.common_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
Expand Down Expand Up @@ -62,7 +62,7 @@ def _create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, ConfigsRecordValues],
override_config: Dict[str, Value],
) -> int:
if self.stub is None:
self._connect()
Expand All @@ -81,7 +81,7 @@ def _create_run(

@override
def start_run(
self, fab_file: bytes, override_config: Dict[str, ConfigsRecordValues]
self, fab_file: bytes, override_config: Dict[str, Value]
) -> Optional[RunTracker]:
"""Start run using the Flower Deployment Engine."""
try:
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/superexec/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from subprocess import Popen
from typing import Dict, Optional

from flwr.common.typing import ConfigsRecordValues
from flwr.common.typing import Value


@dataclass
Expand All @@ -35,7 +35,7 @@ class Executor(ABC):

@abstractmethod
def start_run(
self, fab_file: bytes, override_config: Dict[str, ConfigsRecordValues]
self, fab_file: bytes, override_config: Dict[str, Value]
) -> Optional[RunTracker]:
"""Start a run using the given Flower FAB ID and version.
Expand All @@ -46,7 +46,7 @@ def start_run(
----------
fab_file : bytes
The Flower App Bundle file bytes.
override_config : Dict[str, ConfigsRecordValues]
override_config : Dict[str, Value]
A dict containing key-value pairs to override the FAB config.
Returns
Expand Down

0 comments on commit 9bb359e

Please sign in to comment.