Skip to content

Commit

Permalink
Improve the flower.toml loading module (#3136)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanertopal authored Mar 13, 2024
1 parent f316049 commit d510b35
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 79 deletions.
51 changes: 42 additions & 9 deletions src/py/flwr/cli/flower_toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,45 @@

import tomli

from flwr.common.object_ref import validate
from flwr.common import object_ref


def load_flower_toml(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
def load_and_validate_with_defaults(
path: Optional[str] = None,
) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]:
"""Load and validate flower.toml as dict.
Returns
-------
Tuple[Optional[config], List[str], List[str]]
A tuple with the optional config in case it exists and is valid
and associated errors and warnings.
"""
config = load(path)

if config is None:
errors = [
"Project configuration could not be loaded. flower.toml does not exist."
]
return (None, errors, [])

is_valid, errors, warnings = validate(config)

if not is_valid:
return (None, errors, warnings)

# Apply defaults
defaults = {
"flower": {
"engine": {"name": "simulation", "simulation": {"supernode": {"num": 2}}}
}
}
config = apply_defaults(config, defaults)

return (config, errors, warnings)


def load(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Load flower.toml and return as dict."""
if path is None:
cur_dir = os.getcwd()
Expand All @@ -38,9 +73,7 @@ def load_flower_toml(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
return data


def validate_flower_toml_fields(
config: Dict[str, Any]
) -> Tuple[bool, List[str], List[str]]:
def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
"""Validate flower.toml fields."""
errors = []
warnings = []
Expand Down Expand Up @@ -72,20 +105,20 @@ def validate_flower_toml_fields(
return len(errors) == 0, errors, warnings


def validate_flower_toml(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
def validate(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
"""Validate flower.toml."""
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

if not is_valid:
return False, errors, warnings

# Validate serverapp
is_valid, reason = validate(config["flower"]["components"]["serverapp"])
is_valid, reason = object_ref.validate(config["flower"]["components"]["serverapp"])
if not is_valid and isinstance(reason, str):
return False, [reason], []

# Validate clientapp
is_valid, reason = validate(config["flower"]["components"]["clientapp"])
is_valid, reason = object_ref.validate(config["flower"]["components"]["clientapp"])

if not is_valid and isinstance(reason, str):
return False, [reason], []
Expand Down
24 changes: 10 additions & 14 deletions src/py/flwr/cli/flower_toml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@
import textwrap
from typing import Any, Dict

from .flower_toml import (
load_flower_toml,
validate_flower_toml,
validate_flower_toml_fields,
)
from .flower_toml import load, validate, validate_fields


def test_load_flower_toml_load_from_cwd(tmp_path: str) -> None:
Expand Down Expand Up @@ -68,7 +64,7 @@ def test_load_flower_toml_load_from_cwd(tmp_path: str) -> None:
f.write(textwrap.dedent(flower_toml_content))

# Execute
config = load_flower_toml()
config = load()

# Assert
assert config == expected_config
Expand Down Expand Up @@ -119,7 +115,7 @@ def test_load_flower_toml_from_path(tmp_path: str) -> None:
f.write(textwrap.dedent(flower_toml_content))

# Execute
config = load_flower_toml(path=os.path.join(tmp_path, "flower.toml"))
config = load(path=os.path.join(tmp_path, "flower.toml"))

# Assert
assert config == expected_config
Expand All @@ -133,7 +129,7 @@ def test_validate_flower_toml_fields_empty() -> None:
config: Dict[str, Any] = {}

# Execute
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
Expand All @@ -155,7 +151,7 @@ def test_validate_flower_toml_fields_no_flower() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
Expand All @@ -178,7 +174,7 @@ def test_validate_flower_toml_fields_no_flower_components() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
Expand All @@ -201,7 +197,7 @@ def test_validate_flower_toml_fields_no_server_and_client_app() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
Expand All @@ -224,7 +220,7 @@ def test_validate_flower_toml_fields() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

# Assert
assert is_valid
Expand Down Expand Up @@ -252,7 +248,7 @@ def test_validate_flower_toml() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml(config)
is_valid, errors, warnings = validate(config)

# Assert
assert is_valid
Expand Down Expand Up @@ -280,7 +276,7 @@ def test_validate_flower_toml_fail() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml(config)
is_valid, errors, warnings = validate(config)

# Assert
assert not is_valid
Expand Down
78 changes: 22 additions & 56 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,64 +18,34 @@

import typer

from flwr.cli.flower_toml import apply_defaults, load_flower_toml, validate_flower_toml
from flwr.cli import flower_toml
from flwr.simulation.run_simulation import _run_simulation


def run() -> None:
"""Run Flower project."""
print(
typer.style("Loading project configuration... ", fg=typer.colors.BLUE),
end="",
)
config = load_flower_toml()
if not config:
print(
typer.style(
"Project configuration could not be loaded. "
"flower.toml does not exist.",
fg=typer.colors.RED,
bold=True,
)
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)

config, errors, warnings = flower_toml.load_and_validate_with_defaults()

if config is None:
typer.secho(
"Project configuration could not be loaded.\nflower.toml is invalid:\n"
+ "\n".join([f"- {line}" for line in errors]),
fg=typer.colors.RED,
bold=True,
)
sys.exit()
print(typer.style("Success", fg=typer.colors.GREEN))

print(
typer.style("Validating project configuration... ", fg=typer.colors.BLUE),
end="",
)
is_valid, errors, warnings = validate_flower_toml(config)
if warnings:
print(
typer.style(
"Project configuration is missing the following "
"recommended properties:\n"
+ "\n".join([f"- {line}" for line in warnings]),
fg=typer.colors.RED,
bold=True,
)
)

if not is_valid:
print(
typer.style(
"Project configuration could not be loaded.\nflower.toml is invalid:\n"
+ "\n".join([f"- {line}" for line in errors]),
fg=typer.colors.RED,
bold=True,
)
typer.secho(
"Project configuration is missing the following "
"recommended properties:\n" + "\n".join([f"- {line}" for line in warnings]),
fg=typer.colors.RED,
bold=True,
)
sys.exit()
print(typer.style("Success", fg=typer.colors.GREEN))

# Apply defaults
defaults = {
"flower": {
"engine": {"name": "simulation", "simulation": {"supernode": {"num": 2}}}
}
}
config = apply_defaults(config, defaults)
typer.secho("Success", fg=typer.colors.GREEN)

server_app_ref = config["flower"]["components"]["serverapp"]
client_app_ref = config["flower"]["components"]["clientapp"]
Expand All @@ -84,19 +54,15 @@ def run() -> None:
if engine == "simulation":
num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]

print(
typer.style("Starting run... ", fg=typer.colors.BLUE),
)
typer.secho("Starting run... ", fg=typer.colors.BLUE)
_run_simulation(
server_app_attr=server_app_ref,
client_app_attr=client_app_ref,
num_supernodes=num_supernodes,
)
else:
print(
typer.style(
f"Engine '{engine}' is not yet supported in `flwr run`",
fg=typer.colors.RED,
bold=True,
)
typer.secho(
f"Engine '{engine}' is not yet supported in `flwr run`",
fg=typer.colors.RED,
bold=True,
)

0 comments on commit d510b35

Please sign in to comment.