From 8ad08c38b1e5f75e35a0dbc644c4157f14c68f90 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Sun, 7 Jul 2024 15:02:28 +0200 Subject: [PATCH 1/8] feat(framework) Add utility functions for config parsing and handling --- src/py/flwr/common/config.py | 41 +++++++- src/py/flwr/common/config_test.py | 164 ++++++++++++++++++++++++++++++ 2 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 src/py/flwr/common/config_test.py diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 20de00a6fba9..2c1bfda40c54 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -16,7 +16,7 @@ import os from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import tomli @@ -30,7 +30,7 @@ def get_flwr_dir(provided_path: Optional[str] = None) -> Path: return Path( os.getenv( FLWR_HOME, - f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr", + Path(f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}") / ".flwr", ) ) return Path(provided_path).absolute() @@ -71,3 +71,40 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]: ) return config + + +def flatten_dict( + raw_dict: Dict[str, Any], parent_key: str = "", sep: str = "." +) -> Dict[str, str]: + """Flatten dict by joining nested keys with a given separator.""" + items: List[Tuple[str, str]] = [] + for k, v in raw_dict.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, parent_key=new_key, sep=sep).items()) + else: + items.append((new_key, str(v))) + return dict(items) + + +def parse_config_args( + config_overrides: Optional[str], +) -> Dict[str, str]: + """Parse comma separated list of key-value pairs separated by '='.""" + overrides: Dict[str, str] = {} + + if config_overrides is not None: + overrides_list = config_overrides.split(",") + if ( + len(overrides_list) == 1 + and "=" not in overrides_list + and overrides_list[0].endswith(".toml") + ): + with Path(overrides_list[0]).open("rb") as config_file: + overrides = flatten_dict(tomli.load(config_file)) + else: + for kv_pair in overrides_list: + key, value = kv_pair.split("=") + overrides[key] = value + + return overrides diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py new file mode 100644 index 000000000000..7a7793a861ac --- /dev/null +++ b/src/py/flwr/common/config_test.py @@ -0,0 +1,164 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test util functions handling Flower config.""" + +import os +import textwrap +from pathlib import Path +from unittest.mock import patch + +import pytest + +from .config import ( + flatten_dict, + get_flwr_dir, + get_project_config, + get_project_dir, + parse_config_args, +) + +# Mock constants +FAB_CONFIG_FILE = "pyproject.toml" + + +def test_get_flwr_dir_with_provided_path() -> None: + """Test get_flwr_dir with a provided valid path.""" + provided_path = "." + assert get_flwr_dir(provided_path) == Path(provided_path).absolute() + + +def test_get_flwr_dir_without_provided_path() -> None: + """Test get_flwr_dir without a provided path, using default home directory.""" + with patch.dict(os.environ, {"HOME": "/home/user"}): + assert get_flwr_dir() == Path("/home/user/.flwr") + + +def test_get_flwr_dir_with_flwr_home() -> None: + """Test get_flwr_dir with FLWR_HOME environment variable set.""" + with patch.dict(os.environ, {"FLWR_HOME": "/custom/flwr/home"}): + assert get_flwr_dir() == Path("/custom/flwr/home") + + +def test_get_flwr_dir_with_xdg_data_home() -> None: + """Test get_flwr_dir with FLWR_HOME environment variable set.""" + with patch.dict(os.environ, {"XDG_DATA_HOME": "/custom/data/home"}): + assert get_flwr_dir() == Path("/custom/data/home/.flwr") + + +def test_get_project_dir_invalid_fab_id() -> None: + """Test get_project_dir with an invalid fab_id.""" + with pytest.raises(ValueError): + get_project_dir("invalid_fab_id", "1.0.0") + + +def test_get_project_dir_valid() -> None: + """Test get_project_dir with an valid fab_id and version.""" + app_path = get_project_dir("app_name/user", "1.0.0", flwr_dir=".") + assert app_path == Path("apps") / "app_name" / "user" / "1.0.0" + + +def test_get_project_config_file_not_found() -> None: + """Test get_project_config when the configuration file is not found.""" + with pytest.raises(FileNotFoundError): + get_project_config("/invalid/dir") + + +def test_get_project_config_file_valid(tmp_path: Path) -> None: + """Test get_project_config when the configuration file is not found.""" + pyproject_toml_content = """ + [build-system] + requires = ["hatchling"] + build-backend = "hatchling.build" + + [project] + name = "fedgpt" + version = "1.0.0" + description = "" + license = {text = "Apache License (2.0)"} + dependencies = [ + "flwr[simulation]>=1.9.0,<2.0", + "numpy>=1.21.0", + ] + + [flower] + publisher = "flwrlabs" + + [flower.components] + serverapp = "fedgpt.server:app" + clientapp = "fedgpt.client:app" + + [flower.engine] + name = "simulation" # optional + + [flower.engine.simulation.supernode] + count = 10 # optional + """ + expected_config = { + "build-system": {"build-backend": "hatchling.build", "requires": ["hatchling"]}, + "project": { + "name": "fedgpt", + "version": "1.0.0", + "description": "", + "license": {"text": "Apache License (2.0)"}, + "dependencies": ["flwr[simulation]>=1.9.0,<2.0", "numpy>=1.21.0"], + }, + "flower": { + "publisher": "flwrlabs", + "components": { + "serverapp": "fedgpt.server:app", + "clientapp": "fedgpt.client:app", + }, + "engine": { + "name": "simulation", + "simulation": {"supernode": {"count": 10}}, + }, + }, + } + # Current directory + origin = Path.cwd() + + try: + # Change into the temporary directory + os.chdir(tmp_path) + with open(FAB_CONFIG_FILE, "w", encoding="utf-8") as f: + f.write(textwrap.dedent(pyproject_toml_content)) + + # Execute + config = get_project_config(tmp_path) + + # Assert + assert config == expected_config + finally: + os.chdir(origin) + + +def test_flatten_dict() -> None: + """Test flatten_dict with a nested dictionary.""" + raw_dict = {"a": {"b": {"c": "d"}}, "e": "f"} + expected = {"a.b.c": "d", "e": "f"} + assert flatten_dict(raw_dict) == expected + + +def test_parse_config_args_none() -> None: + """Test parse_config_args with None as input.""" + assert not parse_config_args(None) + + +def test_parse_config_args_overrides() -> None: + """Test parse_config_args with key-value pairs.""" + assert parse_config_args("key1=value1,key2=value2") == { + "key1": "value1", + "key2": "value2", + } From bbeeb6681fe98bd1b43013837abc5fcfb19edfd6 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 8 Jul 2024 12:52:06 +0200 Subject: [PATCH 2/8] Update config.py --- src/py/flwr/common/config.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 2c1bfda40c54..ccc3bd542acb 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -82,8 +82,12 @@ def flatten_dict( new_key = f"{parent_key}{sep}{k}" if parent_key else k if isinstance(v, dict): items.extend(flatten_dict(v, parent_key=new_key, sep=sep).items()) + if isinstance(v, str): + items.append((new_key, v)) else: - items.append((new_key, str(v))) + raise ValueError( + f"The value for {k} needs to be a string.", + ) return dict(items) From cb158a2aab7c7d7c7975a80dfc3ed1b6fc51312d Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 8 Jul 2024 12:58:55 +0200 Subject: [PATCH 3/8] Fix conditions --- src/py/flwr/common/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index ccc3bd542acb..b534c18351c5 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -82,7 +82,7 @@ def flatten_dict( new_key = f"{parent_key}{sep}{k}" if parent_key else k if isinstance(v, dict): items.extend(flatten_dict(v, parent_key=new_key, sep=sep).items()) - if isinstance(v, str): + elif isinstance(v, str): items.append((new_key, v)) else: raise ValueError( From fa0654bec7470f551e0f6268acc15e7f5d4e276a Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 8 Jul 2024 16:18:04 +0200 Subject: [PATCH 4/8] Minor refactoring --- src/py/flwr/common/config.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index b534c18351c5..d359fd915a4e 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -97,18 +97,20 @@ def parse_config_args( """Parse comma separated list of key-value pairs separated by '='.""" overrides: Dict[str, str] = {} - if config_overrides is not None: - overrides_list = config_overrides.split(",") - if ( - len(overrides_list) == 1 - and "=" not in overrides_list - and overrides_list[0].endswith(".toml") - ): - with Path(overrides_list[0]).open("rb") as config_file: - overrides = flatten_dict(tomli.load(config_file)) - else: - for kv_pair in overrides_list: - key, value = kv_pair.split("=") - overrides[key] = value + if config_overrides is None: + return overrides + + overrides_list = config_overrides.split(",") + if ( + len(overrides_list) == 1 + and "=" not in overrides_list + and overrides_list[0].endswith(".toml") + ): + with Path(overrides_list[0]).open("rb") as config_file: + overrides = flatten_dict(tomli.load(config_file)) + else: + for kv_pair in overrides_list: + key, value = kv_pair.split("=") + overrides[key] = value return overrides From e6b6d6173f4594d1c67bcc3fa88baa842004f9f3 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 8 Jul 2024 16:27:23 +0200 Subject: [PATCH 5/8] Add separator --- src/py/flwr/common/config.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index d359fd915a4e..bea145c69828 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -74,14 +74,16 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]: def flatten_dict( - raw_dict: Dict[str, Any], parent_key: str = "", sep: str = "." + raw_dict: Dict[str, Any], parent_key: str = "", separator: str = "." ) -> Dict[str, str]: """Flatten dict by joining nested keys with a given separator.""" items: List[Tuple[str, str]] = [] for k, v in raw_dict.items(): - new_key = f"{parent_key}{sep}{k}" if parent_key else k + new_key = f"{parent_key}{separator}{k}" if parent_key else k if isinstance(v, dict): - items.extend(flatten_dict(v, parent_key=new_key, sep=sep).items()) + items.extend( + flatten_dict(v, parent_key=new_key, separator=separator).items() + ) elif isinstance(v, str): items.append((new_key, v)) else: @@ -93,14 +95,15 @@ def flatten_dict( def parse_config_args( config_overrides: Optional[str], + separator: str = ",", ) -> Dict[str, str]: - """Parse comma separated list of key-value pairs separated by '='.""" + """Parse separator separated list of key-value pairs separated by '='.""" overrides: Dict[str, str] = {} if config_overrides is None: return overrides - overrides_list = config_overrides.split(",") + overrides_list = config_overrides.split(separator) if ( len(overrides_list) == 1 and "=" not in overrides_list From 67b1044d5379df99d4096b96adf461b0a0ce91f2 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 8 Jul 2024 16:35:17 +0200 Subject: [PATCH 6/8] Improve error message --- src/py/flwr/common/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index bea145c69828..34de5392cfa6 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -88,7 +88,7 @@ def flatten_dict( items.append((new_key, v)) else: raise ValueError( - f"The value for {k} needs to be a string.", + f"The value for key {k} needs to be a `str` or a `dict`.", ) return dict(items) From a82360a312dd8da1f3f3715c8b73d7a69d72ee0c Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 8 Jul 2024 16:42:05 +0200 Subject: [PATCH 7/8] Update test --- src/py/flwr/common/config_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py index 7a7793a861ac..d30f0d5cc755 100644 --- a/src/py/flwr/common/config_test.py +++ b/src/py/flwr/common/config_test.py @@ -99,11 +99,10 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None: serverapp = "fedgpt.server:app" clientapp = "fedgpt.client:app" - [flower.engine] - name = "simulation" # optional - - [flower.engine.simulation.supernode] - count = 10 # optional + [flower.config] + num_server_rounds = "10" + momentum = "0.1" + lr = "0.01" """ expected_config = { "build-system": {"build-backend": "hatchling.build", "requires": ["hatchling"]}, @@ -120,9 +119,10 @@ def test_get_project_config_file_valid(tmp_path: Path) -> None: "serverapp": "fedgpt.server:app", "clientapp": "fedgpt.client:app", }, - "engine": { - "name": "simulation", - "simulation": {"supernode": {"count": 10}}, + "config": { + "num_server_rounds": "10", + "momentum": "0.1", + "lr": "0.01", }, }, } From 6d67e2aaf847ceb578df5bd5c723d36f492ba8cf Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 8 Jul 2024 16:46:17 +0200 Subject: [PATCH 8/8] Remove configurable sep --- src/py/flwr/common/config.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index 34de5392cfa6..e615e497b808 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -73,17 +73,14 @@ def get_project_config(project_dir: Union[str, Path]) -> Dict[str, Any]: return config -def flatten_dict( - raw_dict: Dict[str, Any], parent_key: str = "", separator: str = "." -) -> Dict[str, str]: +def flatten_dict(raw_dict: Dict[str, Any], parent_key: str = "") -> Dict[str, str]: """Flatten dict by joining nested keys with a given separator.""" items: List[Tuple[str, str]] = [] + separator: str = "." for k, v in raw_dict.items(): new_key = f"{parent_key}{separator}{k}" if parent_key else k if isinstance(v, dict): - items.extend( - flatten_dict(v, parent_key=new_key, separator=separator).items() - ) + items.extend(flatten_dict(v, parent_key=new_key).items()) elif isinstance(v, str): items.append((new_key, v)) else: