Skip to content

Commit

Permalink
Merge branch 'main' into add-driver-get-run-proto
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Jun 12, 2024
2 parents 82a399a + 408d882 commit 7300b9c
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 28 deletions.
2 changes: 2 additions & 0 deletions src/py/flwr/cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .build import build
from .example import example
from .install import install
from .new import new
from .run import run

Expand All @@ -34,6 +35,7 @@
app.command()(example)
app.command()(run)
app.command()(build)
app.command()(install)

if __name__ == "__main__":
app()
17 changes: 2 additions & 15 deletions src/py/flwr/cli/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================
"""Flower command line interface `build` command."""

import hashlib
import os
import zipfile
from pathlib import Path
Expand All @@ -25,7 +24,7 @@
from typing_extensions import Annotated

from .config_utils import load_and_validate
from .utils import is_valid_project_name
from .utils import get_sha256_hash, is_valid_project_name


# pylint: disable=too-many-locals
Expand Down Expand Up @@ -115,7 +114,7 @@ def build(
fab_file.write(file_path, archive_path)

# Calculate file info
sha256_hash = _get_sha256_hash(file_path)
sha256_hash = get_sha256_hash(file_path)
file_size_bits = os.path.getsize(file_path) * 8 # size in bits
list_file_content += f"{archive_path},{sha256_hash},{file_size_bits}\n"

Expand All @@ -127,18 +126,6 @@ def build(
)


def _get_sha256_hash(file_path: Path) -> str:
"""Calculate the SHA-256 hash of a file."""
sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
while True:
data = f.read(65536) # Read in 64kB blocks
if not data:
break
sha256.update(data)
return sha256.hexdigest()


def _load_gitignore(directory: Path) -> pathspec.PathSpec:
"""Load and parse .gitignore file, returning a pathspec."""
gitignore_path = directory / ".gitignore"
Expand Down
15 changes: 11 additions & 4 deletions src/py/flwr/cli/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

def load_and_validate(
path: Optional[Path] = None,
check_module: bool = True,
) -> Tuple[Optional[Dict[str, Any]], List[str], List[str]]:
"""Load and validate pyproject.toml as dict.
Expand All @@ -42,7 +43,7 @@ def load_and_validate(
]
return (None, errors, [])

is_valid, errors, warnings = validate(config)
is_valid, errors, warnings = validate(config, check_module)

if not is_valid:
return (None, errors, warnings)
Expand Down Expand Up @@ -102,20 +103,26 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
return len(errors) == 0, errors, warnings


def validate(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
def validate(
config: Dict[str, Any], check_module: bool = True
) -> Tuple[bool, List[str], List[str]]:
"""Validate pyproject.toml."""
is_valid, errors, warnings = validate_fields(config)

if not is_valid:
return False, errors, warnings

# Validate serverapp
is_valid, reason = object_ref.validate(config["flower"]["components"]["serverapp"])
is_valid, reason = object_ref.validate(
config["flower"]["components"]["serverapp"], check_module
)
if not is_valid and isinstance(reason, str):
return False, [reason], []

# Validate clientapp
is_valid, reason = object_ref.validate(config["flower"]["components"]["clientapp"])
is_valid, reason = object_ref.validate(
config["flower"]["components"]["clientapp"], check_module
)

if not is_valid and isinstance(reason, str):
return False, [reason], []
Expand Down
196 changes: 196 additions & 0 deletions src/py/flwr/cli/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower command line interface `install` command."""


import os
import shutil
import tempfile
import zipfile
from pathlib import Path
from typing import Optional

import typer
from typing_extensions import Annotated

from .config_utils import load_and_validate
from .utils import get_sha256_hash


def install(
source: Annotated[
Optional[Path],
typer.Argument(metavar="source", help="The source FAB file to install."),
] = None,
flwr_dir: Annotated[
Optional[Path],
typer.Option(help="The desired install path."),
] = None,
) -> None:
"""Install a Flower App Bundle.
It can be ran with a single FAB file argument:
``flwr install ./target_project.fab``
The target install directory can be specified with ``--flwr-dir``:
``flwr install ./target_project.fab --flwr-dir ./docs/flwr``
This will install ``target_project`` to ``./docs/flwr/``. By default,
``flwr-dir`` is equal to:
- ``$FLWR_HOME/`` if ``$FLWR_HOME`` is defined
- ``$XDG_DATA_HOME/.flwr/`` if ``$XDG_DATA_HOME`` is defined
- ``$HOME/.flwr/`` in all other cases
"""
if source is None:
source = Path(typer.prompt("Enter the source FAB file"))

source = source.resolve()
if not source.exists() or not source.is_file():
typer.secho(
f"❌ The source {source} does not exist or is not a file.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

if source.suffix != ".fab":
typer.secho(
f"❌ The source {source} is not a `.fab` file.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

install_from_fab(source, flwr_dir)


def install_from_fab(
fab_file: Path, flwr_dir: Optional[Path], skip_prompt: bool = False
) -> None:
"""Install from a FAB file after extracting and validating."""
with tempfile.TemporaryDirectory() as tmpdir:
with zipfile.ZipFile(fab_file, "r") as zipf:
zipf.extractall(tmpdir)
tmpdir_path = Path(tmpdir)
info_dir = tmpdir_path / ".info"
if not info_dir.exists():
typer.secho(
"❌ FAB file has incorrect format.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

content_file = info_dir / "CONTENT"

if not content_file.exists() or not _verify_hashes(
content_file.read_text(), tmpdir_path
):
typer.secho(
"❌ File hashes couldn't be verified.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

shutil.rmtree(info_dir)

validate_and_install(tmpdir_path, fab_file.stem, flwr_dir, skip_prompt)


def validate_and_install(
project_dir: Path,
fab_name: str,
flwr_dir: Optional[Path],
skip_prompt: bool = False,
) -> None:
"""Validate TOML files and install the project to the desired directory."""
config, _, _ = load_and_validate(project_dir / "pyproject.toml", check_module=False)

if config is None:
typer.secho(
"❌ Invalid config inside FAB file.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

publisher = config["flower"]["publisher"]
project_name = config["project"]["name"]
version = config["project"]["version"]

if fab_name != f"{publisher}.{project_name}.{version.replace('.', '-')}":
typer.secho(
"❌ FAB file has incorrect name. The file name must follow the format "
"`<publisher>.<project_name>.<version>.fab`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

install_dir: Path = (
(
Path(
os.getenv(
"FLWR_HOME",
f"{os.getenv('XDG_DATA_HOME', os.getenv('HOME'))}/.flwr",
)
)
if not flwr_dir
else flwr_dir
)
/ "apps"
/ publisher
/ project_name
/ version
)
if install_dir.exists() and not skip_prompt:
if not typer.confirm(
typer.style(
f"\n💬 {project_name} version {version} is already installed, "
"do you want to reinstall it?",
fg=typer.colors.MAGENTA,
bold=True,
)
):
return

install_dir.mkdir(parents=True, exist_ok=True)

# Move contents from source directory
for item in project_dir.iterdir():
if item.is_dir():
shutil.copytree(item, install_dir / item.name, dirs_exist_ok=True)
else:
shutil.copy2(item, install_dir / item.name)

typer.secho(
f"🎊 Successfully installed {project_name} to {install_dir}.",
fg=typer.colors.GREEN,
bold=True,
)


def _verify_hashes(list_content: str, tmpdir: Path) -> bool:
"""Verify file hashes based on the LIST content."""
for line in list_content.strip().split("\n"):
rel_path, hash_expected, _ = line.split(",")
file_path = tmpdir / rel_path
if not file_path.exists() or get_sha256_hash(file_path) != hash_expected:
return False
return True
14 changes: 14 additions & 0 deletions src/py/flwr/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
# ==============================================================================
"""Flower command line interface utils."""

import hashlib
import re
from pathlib import Path
from typing import Callable, List, Optional, cast

import typer
Expand Down Expand Up @@ -122,3 +124,15 @@ def sanitize_project_name(name: str) -> str:
sanitized_name = sanitized_name[1:]

return sanitized_name


def get_sha256_hash(file_path: Path) -> str:
"""Calculate the SHA-256 hash of a file."""
sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
while True:
data = f.read(65536) # Read in 64kB blocks
if not data:
break
sha256.update(data)
return sha256.hexdigest()
22 changes: 13 additions & 9 deletions src/py/flwr/common/object_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

def validate(
module_attribute_str: str,
check_module: bool = True,
) -> Tuple[bool, Optional[str]]:
"""Validate object reference.
Expand All @@ -56,15 +57,18 @@ def validate(
f"Missing attribute in {module_attribute_str}{OBJECT_REF_HELP_STR}",
)

# Load module
module = find_spec(module_str)
if module and module.origin:
if not _find_attribute_in_module(module.origin, attributes_str):
return (
False,
f"Unable to find attribute {attributes_str} in module {module_str}"
f"{OBJECT_REF_HELP_STR}",
)
if check_module:
# Load module
module = find_spec(module_str)
if module and module.origin:
if not _find_attribute_in_module(module.origin, attributes_str):
return (
False,
f"Unable to find attribute {attributes_str} in module {module_str}"
f"{OBJECT_REF_HELP_STR}",
)
return (True, None)
else:
return (True, None)

return (
Expand Down

0 comments on commit 7300b9c

Please sign in to comment.