From 408d8820ba89b02fe938eafd1f6c50bacd0b41cf Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Wed, 12 Jun 2024 12:56:00 +0200 Subject: [PATCH] feat(framework) Add `flwr install` command (#3258) Co-authored-by: Heng Pan Co-authored-by: Taner Topal Co-authored-by: Javier --- src/py/flwr/cli/app.py | 2 + src/py/flwr/cli/build.py | 17 +-- src/py/flwr/cli/config_utils.py | 15 ++- src/py/flwr/cli/install.py | 196 +++++++++++++++++++++++++++++++ src/py/flwr/cli/utils.py | 14 +++ src/py/flwr/common/object_ref.py | 22 ++-- 6 files changed, 238 insertions(+), 28 deletions(-) create mode 100644 src/py/flwr/cli/install.py diff --git a/src/py/flwr/cli/app.py b/src/py/flwr/cli/app.py index e1417f1267ac..477b990bf1da 100644 --- a/src/py/flwr/cli/app.py +++ b/src/py/flwr/cli/app.py @@ -18,6 +18,7 @@ from .build import build from .example import example +from .install import install from .new import new from .run import run @@ -34,6 +35,7 @@ app.command()(example) app.command()(run) app.command()(build) +app.command()(install) if __name__ == "__main__": app() diff --git a/src/py/flwr/cli/build.py b/src/py/flwr/cli/build.py index ca7ab8686c5c..d279a8d11bc2 100644 --- a/src/py/flwr/cli/build.py +++ b/src/py/flwr/cli/build.py @@ -14,7 +14,6 @@ # ============================================================================== """Flower command line interface `build` command.""" -import hashlib import os import zipfile from pathlib import Path @@ -25,7 +24,7 @@ from typing_extensions import Annotated from .config_utils import load_and_validate -from .utils import is_valid_project_name +from .utils import get_sha256_hash, is_valid_project_name # pylint: disable=too-many-locals @@ -115,7 +114,7 @@ def build( fab_file.write(file_path, archive_path) # Calculate file info - sha256_hash = _get_sha256_hash(file_path) + sha256_hash = get_sha256_hash(file_path) file_size_bits = os.path.getsize(file_path) * 8 # size in bits list_file_content += 
f"{archive_path},{sha256_hash},{file_size_bits}\n" @@ -127,18 +126,6 @@ def build( ) -def _get_sha256_hash(file_path: Path) -> str: - """Calculate the SHA-256 hash of a file.""" - sha256 = hashlib.sha256() - with open(file_path, "rb") as f: - while True: - data = f.read(65536) # Read in 64kB blocks - if not data: - break - sha256.update(data) - return sha256.hexdigest() - - def _load_gitignore(directory: Path) -> pathspec.PathSpec: """Load and parse .gitignore file, returning a pathspec.""" gitignore_path = directory / ".gitignore" diff --git a/src/py/flwr/cli/config_utils.py b/src/py/flwr/cli/config_utils.py index d943d87e3812..ec67fefda0d2 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -24,6 +24,7 @@ def load_and_validate( path: Optional[Path] = None, + check_module: bool = True, ) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]: """Load and validate pyproject.toml as dict. @@ -42,7 +43,7 @@ def load_and_validate( ] return (None, errors, []) - is_valid, errors, warnings = validate(config) + is_valid, errors, warnings = validate(config, check_module) if not is_valid: return (None, errors, warnings) @@ -102,7 +103,9 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]] return len(errors) == 0, errors, warnings -def validate(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]: +def validate( + config: Dict[str, Any], check_module: bool = True +) -> Tuple[bool, List[str], List[str]]: """Validate pyproject.toml.""" is_valid, errors, warnings = validate_fields(config) @@ -110,12 +113,16 @@ def validate(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]: return False, errors, warnings # Validate serverapp - is_valid, reason = object_ref.validate(config["flower"]["components"]["serverapp"]) + is_valid, reason = object_ref.validate( + config["flower"]["components"]["serverapp"], check_module + ) if not is_valid and isinstance(reason, str): return False, [reason], [] # Validate 
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower command line interface `install` command."""


import os
import shutil
import tempfile
import zipfile
from pathlib import Path
from typing import NoReturn, Optional

import typer
from typing_extensions import Annotated

from .config_utils import load_and_validate
from .utils import get_sha256_hash


def _exit_with_error(message: str) -> NoReturn:
    """Print `message` in bold red and abort the command with exit code 1."""
    typer.secho(message, fg=typer.colors.RED, bold=True)
    raise typer.Exit(code=1)


def _default_flwr_dir() -> Path:
    """Return the default Flower directory derived from the environment.

    Resolution order: ``$FLWR_HOME``, then ``$XDG_DATA_HOME/.flwr``, then
    ``<user home>/.flwr``. Unset and empty variables are both ignored, so the
    result can never be the literal path ``None/.flwr``.
    """
    flwr_home = os.getenv("FLWR_HOME")
    if flwr_home:
        return Path(flwr_home)

    data_home = os.getenv("XDG_DATA_HOME")
    # `Path.home()` also works where `$HOME` is not defined (e.g. Windows).
    base_dir = Path(data_home) if data_home else Path.home()
    return base_dir / ".flwr"


def install(
    source: Annotated[
        Optional[Path],
        typer.Argument(metavar="source", help="The source FAB file to install."),
    ] = None,
    flwr_dir: Annotated[
        Optional[Path],
        typer.Option(help="The desired install path."),
    ] = None,
) -> None:
    """Install a Flower App Bundle.

    It can be run with a single FAB file argument:

        ``flwr install ./target_project.fab``

    The target install directory can be specified with ``--flwr-dir``:

        ``flwr install ./target_project.fab --flwr-dir ./docs/flwr``

    This will install ``target_project`` to ``./docs/flwr/``. By default,
    ``flwr-dir`` is equal to:

        - ``$FLWR_HOME/`` if ``$FLWR_HOME`` is defined
        - ``$XDG_DATA_HOME/.flwr/`` if ``$XDG_DATA_HOME`` is defined
        - ``$HOME/.flwr/`` in all other cases
    """
    if source is None:
        source = Path(typer.prompt("Enter the source FAB file"))

    source = source.resolve()
    if not source.exists() or not source.is_file():
        _exit_with_error(f"❌ The source {source} does not exist or is not a file.")

    if source.suffix != ".fab":
        _exit_with_error(f"❌ The source {source} is not a `.fab` file.")

    install_from_fab(source, flwr_dir)


def install_from_fab(
    fab_file: Path, flwr_dir: Optional[Path], skip_prompt: bool = False
) -> None:
    """Install from a FAB file after extracting and validating.

    The archive is unpacked into a temporary directory, every file listed in
    ``.info/CONTENT`` is checked against its recorded SHA-256 hash, the
    ``.info`` directory is removed, and the remaining files are handed to
    `validate_and_install`.

    Parameters
    ----------
    fab_file : Path
        Path to the ``.fab`` archive.
    flwr_dir : Optional[Path]
        Base install directory; ``None`` selects the environment default.
    skip_prompt : bool
        If True, overwrite an existing installation without asking.

    Raises
    ------
    typer.Exit
        With code 1 if the archive is not a valid zip file, has no ``.info``
        directory, or fails hash verification.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        try:
            with zipfile.ZipFile(fab_file, "r") as zipf:
                zipf.extractall(tmpdir)
        except zipfile.BadZipFile:
            # A file merely named `*.fab` is not necessarily a zip archive;
            # report it like any other malformed bundle instead of crashing.
            _exit_with_error("❌ FAB file has incorrect format.")

        tmpdir_path = Path(tmpdir)
        info_dir = tmpdir_path / ".info"
        if not info_dir.exists():
            _exit_with_error("❌ FAB file has incorrect format.")

        content_file = info_dir / "CONTENT"
        if not content_file.exists() or not _verify_hashes(
            content_file.read_text(), tmpdir_path
        ):
            _exit_with_error("❌ File hashes couldn't be verified.")

        # The bundle metadata must not end up in the installed project.
        shutil.rmtree(info_dir)

        # Must run while the temporary directory still exists.
        validate_and_install(tmpdir_path, fab_file.stem, flwr_dir, skip_prompt)


def validate_and_install(
    project_dir: Path,
    fab_name: str,
    flwr_dir: Optional[Path],
    skip_prompt: bool = False,
) -> None:
    """Validate TOML files and install the project to the desired directory.

    The project is copied to
    ``<flwr_dir>/apps/<publisher>/<project_name>/<version>``.

    Parameters
    ----------
    project_dir : Path
        Directory holding the extracted project, including ``pyproject.toml``.
    fab_name : str
        Stem of the FAB file name; it must equal
        ``<publisher>.<project_name>.<version>`` with the dots of the version
        replaced by dashes.
    flwr_dir : Optional[Path]
        Base install directory; ``None`` selects the environment default.
    skip_prompt : bool
        If True, overwrite an existing installation without asking.

    Raises
    ------
    typer.Exit
        With code 1 if the config is invalid or the FAB name does not match.
    """
    config, _, _ = load_and_validate(project_dir / "pyproject.toml", check_module=False)

    if config is None:
        _exit_with_error("❌ Invalid config inside FAB file.")

    publisher = config["flower"]["publisher"]
    project_name = config["project"]["name"]
    version = config["project"]["version"]

    expected_name = f"{publisher}.{project_name}.{version.replace('.', '-')}"
    if fab_name != expected_name:
        _exit_with_error(
            "❌ FAB file has incorrect name. The file name must follow the format "
            "`<publisher>.<project_name>.<version>.fab`."
        )

    base_dir = _default_flwr_dir() if flwr_dir is None else flwr_dir
    install_dir: Path = base_dir / "apps" / publisher / project_name / version

    if install_dir.exists() and not skip_prompt:
        if not typer.confirm(
            typer.style(
                f"\n💬 {project_name} version {version} is already installed, "
                "do you want to reinstall it?",
                fg=typer.colors.MAGENTA,
                bold=True,
            )
        ):
            return

    install_dir.mkdir(parents=True, exist_ok=True)

    # Copy contents from the source directory. Existing files are overwritten
    # in place; files that only exist in a previous installation are kept.
    for item in project_dir.iterdir():
        if item.is_dir():
            shutil.copytree(item, install_dir / item.name, dirs_exist_ok=True)
        else:
            shutil.copy2(item, install_dir / item.name)

    typer.secho(
        f"🎊 Successfully installed {project_name} to {install_dir}.",
        fg=typer.colors.GREEN,
        bold=True,
    )


def _verify_hashes(list_content: str, tmpdir: Path) -> bool:
    """Verify file hashes based on the CONTENT manifest.

    Each manifest line has the form ``<relative path>,<sha256>,<size>``.

    Parameters
    ----------
    list_content : str
        Text of the ``.info/CONTENT`` file.
    tmpdir : Path
        Directory the FAB was extracted to.

    Returns
    -------
    bool
        True only if the manifest lists at least one entry and every listed
        file exists and matches its recorded hash. An empty or malformed
        manifest yields False rather than an exception.
    """
    entries = list_content.strip().splitlines()
    if not entries:
        return False

    for entry in entries:
        # Split from the right: the hash and the size never contain commas,
        # whereas the path might.
        fields = entry.rsplit(",", 2)
        if len(fields) != 3:
            return False
        rel_path, hash_expected, _ = fields
        file_path = tmpdir / rel_path
        if not file_path.is_file() or get_sha256_hash(file_path) != hash_expected:
            return False
    return True
import hashlib
from pathlib import Path


def get_sha256_hash(file_path: Path, *, chunk_size: int = 65536) -> str:
    """Calculate the SHA-256 hash of a file.

    The file is read in blocks so that large files never have to fit in
    memory at once.

    Parameters
    ----------
    file_path : Path
        File to hash.
    chunk_size : int
        Number of bytes read per block (default: 64 KiB). Must be positive.

    Returns
    -------
    str
        The digest as a lowercase hexadecimal string.

    Raises
    ------
    ValueError
        If `chunk_size` is not positive.
    """
    if chunk_size <= 0:
        # A size of 0 would make the first read return b"" and silently hash
        # nothing, so reject it explicitly.
        raise ValueError("chunk_size must be a positive integer")

    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        # `iter` with a sentinel stops at the first empty read, i.e. at EOF.
        for data in iter(lambda: f.read(chunk_size), b""):
            sha256.update(data)
    return sha256.hexdigest()