Skip to content

Commit

Permalink
Merge branch 'main' into fds-rename-node-id-to-partition-id
Browse files Browse the repository at this point in the history
  • Loading branch information
flwrmachine authored Mar 13, 2024
2 parents 9fa1437 + d510b35 commit cde5412
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 81 deletions.
51 changes: 42 additions & 9 deletions src/py/flwr/cli/flower_toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,45 @@

import tomli

from flwr.common.object_ref import validate
from flwr.common import object_ref


def load_flower_toml(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
def load_and_validate_with_defaults(
path: Optional[str] = None,
) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]:
"""Load and validate flower.toml as dict.
Returns
-------
Tuple[Optional[config], List[str], List[str]]
A tuple with the optional config in case it exists and is valid
and associated errors and warnings.
"""
config = load(path)

if config is None:
errors = [
"Project configuration could not be loaded. flower.toml does not exist."
]
return (None, errors, [])

is_valid, errors, warnings = validate(config)

if not is_valid:
return (None, errors, warnings)

# Apply defaults
defaults = {
"flower": {
"engine": {"name": "simulation", "simulation": {"supernode": {"num": 2}}}
}
}
config = apply_defaults(config, defaults)

return (config, errors, warnings)


def load(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Load flower.toml and return as dict."""
if path is None:
cur_dir = os.getcwd()
Expand All @@ -38,9 +73,7 @@ def load_flower_toml(path: Optional[str] = None) -> Optional[Dict[str, Any]]:
return data


def validate_flower_toml_fields(
config: Dict[str, Any]
) -> Tuple[bool, List[str], List[str]]:
def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
"""Validate flower.toml fields."""
errors = []
warnings = []
Expand Down Expand Up @@ -72,20 +105,20 @@ def validate_flower_toml_fields(
return len(errors) == 0, errors, warnings


def validate_flower_toml(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
def validate(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
"""Validate flower.toml."""
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

if not is_valid:
return False, errors, warnings

# Validate serverapp
is_valid, reason = validate(config["flower"]["components"]["serverapp"])
is_valid, reason = object_ref.validate(config["flower"]["components"]["serverapp"])
if not is_valid and isinstance(reason, str):
return False, [reason], []

# Validate clientapp
is_valid, reason = validate(config["flower"]["components"]["clientapp"])
is_valid, reason = object_ref.validate(config["flower"]["components"]["clientapp"])

if not is_valid and isinstance(reason, str):
return False, [reason], []
Expand Down
24 changes: 10 additions & 14 deletions src/py/flwr/cli/flower_toml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@
import textwrap
from typing import Any, Dict

from .flower_toml import (
load_flower_toml,
validate_flower_toml,
validate_flower_toml_fields,
)
from .flower_toml import load, validate, validate_fields


def test_load_flower_toml_load_from_cwd(tmp_path: str) -> None:
Expand Down Expand Up @@ -68,7 +64,7 @@ def test_load_flower_toml_load_from_cwd(tmp_path: str) -> None:
f.write(textwrap.dedent(flower_toml_content))

# Execute
config = load_flower_toml()
config = load()

# Assert
assert config == expected_config
Expand Down Expand Up @@ -119,7 +115,7 @@ def test_load_flower_toml_from_path(tmp_path: str) -> None:
f.write(textwrap.dedent(flower_toml_content))

# Execute
config = load_flower_toml(path=os.path.join(tmp_path, "flower.toml"))
config = load(path=os.path.join(tmp_path, "flower.toml"))

# Assert
assert config == expected_config
Expand All @@ -133,7 +129,7 @@ def test_validate_flower_toml_fields_empty() -> None:
config: Dict[str, Any] = {}

# Execute
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
Expand All @@ -155,7 +151,7 @@ def test_validate_flower_toml_fields_no_flower() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
Expand All @@ -178,7 +174,7 @@ def test_validate_flower_toml_fields_no_flower_components() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
Expand All @@ -201,7 +197,7 @@ def test_validate_flower_toml_fields_no_server_and_client_app() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

# Assert
assert not is_valid
Expand All @@ -224,7 +220,7 @@ def test_validate_flower_toml_fields() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml_fields(config)
is_valid, errors, warnings = validate_fields(config)

# Assert
assert is_valid
Expand Down Expand Up @@ -252,7 +248,7 @@ def test_validate_flower_toml() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml(config)
is_valid, errors, warnings = validate(config)

# Assert
assert is_valid
Expand Down Expand Up @@ -280,7 +276,7 @@ def test_validate_flower_toml_fail() -> None:
}

# Execute
is_valid, errors, warnings = validate_flower_toml(config)
is_valid, errors, warnings = validate(config)

# Assert
assert not is_valid
Expand Down
78 changes: 22 additions & 56 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,64 +18,34 @@

import typer

from flwr.cli.flower_toml import apply_defaults, load_flower_toml, validate_flower_toml
from flwr.cli import flower_toml
from flwr.simulation.run_simulation import _run_simulation


def run() -> None:
"""Run Flower project."""
print(
typer.style("Loading project configuration... ", fg=typer.colors.BLUE),
end="",
)
config = load_flower_toml()
if not config:
print(
typer.style(
"Project configuration could not be loaded. "
"flower.toml does not exist.",
fg=typer.colors.RED,
bold=True,
)
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)

config, errors, warnings = flower_toml.load_and_validate_with_defaults()

if config is None:
typer.secho(
"Project configuration could not be loaded.\nflower.toml is invalid:\n"
+ "\n".join([f"- {line}" for line in errors]),
fg=typer.colors.RED,
bold=True,
)
sys.exit()
print(typer.style("Success", fg=typer.colors.GREEN))

print(
typer.style("Validating project configuration... ", fg=typer.colors.BLUE),
end="",
)
is_valid, errors, warnings = validate_flower_toml(config)
if warnings:
print(
typer.style(
"Project configuration is missing the following "
"recommended properties:\n"
+ "\n".join([f"- {line}" for line in warnings]),
fg=typer.colors.RED,
bold=True,
)
)

if not is_valid:
print(
typer.style(
"Project configuration could not be loaded.\nflower.toml is invalid:\n"
+ "\n".join([f"- {line}" for line in errors]),
fg=typer.colors.RED,
bold=True,
)
typer.secho(
"Project configuration is missing the following "
"recommended properties:\n" + "\n".join([f"- {line}" for line in warnings]),
fg=typer.colors.RED,
bold=True,
)
sys.exit()
print(typer.style("Success", fg=typer.colors.GREEN))

# Apply defaults
defaults = {
"flower": {
"engine": {"name": "simulation", "simulation": {"supernode": {"num": 2}}}
}
}
config = apply_defaults(config, defaults)
typer.secho("Success", fg=typer.colors.GREEN)

server_app_ref = config["flower"]["components"]["serverapp"]
client_app_ref = config["flower"]["components"]["clientapp"]
Expand All @@ -84,19 +54,15 @@ def run() -> None:
if engine == "simulation":
num_supernodes = config["flower"]["engine"]["simulation"]["supernode"]["num"]

print(
typer.style("Starting run... ", fg=typer.colors.BLUE),
)
typer.secho("Starting run... ", fg=typer.colors.BLUE)
_run_simulation(
server_app_attr=server_app_ref,
client_app_attr=client_app_ref,
num_supernodes=num_supernodes,
)
else:
print(
typer.style(
f"Engine '{engine}' is not yet supported in `flwr run`",
fg=typer.colors.RED,
bold=True,
)
typer.secho(
f"Engine '{engine}' is not yet supported in `flwr run`",
fg=typer.colors.RED,
bold=True,
)
10 changes: 10 additions & 0 deletions src/py/flwr/client/mod/centraldp_mods.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""Clipping modifiers for central DP with client-side clipping."""


from logging import INFO

from flwr.client.typing import ClientAppCallable
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common import recordset_compat as compat
Expand All @@ -25,6 +27,7 @@
compute_clip_model_update,
)
from flwr.common.differential_privacy_constants import KEY_CLIPPING_NORM, KEY_NORM_BIT
from flwr.common.logger import log
from flwr.common.message import Message


Expand Down Expand Up @@ -79,6 +82,8 @@ def fixedclipping_mod(
clipping_norm,
)

log(INFO, "fixedclipping_mod: parameters are clipped by value: %s.", clipping_norm)

fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
return out_msg
Expand Down Expand Up @@ -139,6 +144,11 @@ def adaptiveclipping_mod(
server_to_client_params,
clipping_norm,
)
log(
INFO,
"adaptiveclipping_mod: parameters are clipped by value: %s.",
clipping_norm,
)

fit_res.parameters = ndarrays_to_parameters(client_to_server_params)

Expand Down
14 changes: 14 additions & 0 deletions src/py/flwr/client/mod/localdp_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
"""Local DP modifier."""


from logging import INFO

import numpy as np

from flwr.client.typing import ClientAppCallable
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common import recordset_compat as compat
Expand All @@ -24,6 +28,7 @@
add_localdp_gaussian_noise_to_params,
compute_clip_model_update,
)
from flwr.common.logger import log
from flwr.common.message import Message


Expand Down Expand Up @@ -122,13 +127,22 @@ def __call__(
server_to_client_params,
self.clipping_norm,
)
log(
INFO, "LocalDpMod: parameters are clipped by value: %s.", self.clipping_norm
)

fit_res.parameters = ndarrays_to_parameters(client_to_server_params)

# Add noise to model params
add_localdp_gaussian_noise_to_params(
fit_res.parameters, self.sensitivity, self.epsilon, self.delta
)
log(
INFO,
"LocalDpMod: local DP noise with "
"standard deviation: %s added to parameters.",
self.sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon,
)

out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
return out_msg
Loading

0 comments on commit cde5412

Please sign in to comment.