Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng committed Oct 9, 2024
1 parent 354ddde commit b5c9d53
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/py/flwr/cli/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def build(
write_to_zip(fab_file, str(archive_path), file_contents)

# Calculate file info
sha256_hash = get_sha256_hash(file_path)
sha256_hash = hashlib.sha256(file_contents).hexdigest()
file_size_bits = os.path.getsize(file_path) * 8 # size in bits
list_file_content += f"{archive_path},{sha256_hash},{file_size_bits}\n"

Expand Down
28 changes: 2 additions & 26 deletions src/py/flwr/cli/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,34 +160,10 @@ def validate_and_install(

version, fab_id = get_metadata_from_config(config)
publisher, project_name = fab_id.split("/")
config_metadata = (publisher, project_name, version)

if fab_name:
fab_publisher, fab_project_name, fab_version, fab_shorthash = (
get_metadata_from_fab_filename(fab_name)
)
if (
f"{fab_publisher}.{fab_project_name}.{fab_version}"
!= f"{publisher}.{project_name}.{version}"
or len(fab_shorthash) != FAB_HASH_TRUNCATION # Verify hash length
):

typer.secho(
"❌ FAB file has incorrect name. The file name must follow the format "
"`<publisher>.<project_name>.<version>.<8hexchars>.fab`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

try:
_ = int(fab_shorthash, 16) # Verify hash is a valid hexadecimal
except ValueError as e:
typer.secho(
f"❌ FAB file has an invalid hexadecimal string `{fab_shorthash}`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1) from e
_ = get_metadata_from_fab_filename(fab_name, config_metadata)

install_dir: Path = (
(get_flwr_dir() if not flwr_dir else flwr_dir)
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ def _run_with_superexec(
channel.subscribe(on_channel_state_change)
stub = ExecStub(channel)

fab_path = Path(build(app)[0])
content = fab_path.read_bytes()
fab = Fab(hashlib.sha256(content).hexdigest(), content)
fab_path, fab_hash = build(app)
content = Path(fab_path).read_bytes()
fab = Fab(fab_hash, content)

req = StartRunRequest(
fab=fab_to_proto(fab),
Expand Down
46 changes: 41 additions & 5 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""Provide functions for managing global Flower config."""

import sys
import os
import re
from pathlib import Path
Expand All @@ -22,7 +23,12 @@
import tomli

from flwr.cli.config_utils import get_fab_config, validate_fields
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
from flwr.common.constant import (
APP_DIR,
FAB_CONFIG_FILE,
FAB_HASH_TRUNCATION,
FLWR_HOME,
)
from flwr.common.typing import Run, UserConfig, UserConfigValue


Expand Down Expand Up @@ -218,12 +224,42 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:


def get_metadata_from_fab_filename(
fab_file: Union[Path, str]
) -> tuple[str, str, str, str]:
fab_file: Union[Path, str], config_metadata: tuple[str, str, str]
) -> Optional[tuple[str, str, str, str]]:
"""Extract metadata from the FAB filename."""

fab_file_name: str
publisher, project_name, version = config_metadata

if isinstance(fab_file, Path):
fab_file_name = fab_file.stem
elif isinstance(fab_file, str):
fab_file_name = fab_file.removesuffix(".fab")
publisher, project_name, version, shorthash = fab_file_name.split(".")
return publisher, project_name, version.replace("-", "."), shorthash

fab_publisher, fab_project_name, fab_version, fab_shorthash = fab_file_name.split(
"."
)

if (
f"{fab_publisher}.{fab_project_name}.{fab_version}"
!= f"{publisher}.{project_name}.{version}"
or len(fab_shorthash) != FAB_HASH_TRUNCATION # Verify hash length
):
raise ValueError(
"❌ FAB file has incorrect name. The file name must follow the format "
"`<publisher>.<project_name>.<version>.<8hexchars>.fab`.",
)

try:
_ = int(fab_shorthash, 16) # Verify hash is a valid hexadecimal
except Exception as e:
raise ValueError(
"❌ FAB file has an invalid hexadecimal string `{fab_shorthash}`."
) from e

return (
fab_publisher,
fab_project_name,
fab_version.replace("-", "."),
fab_shorthash,
)

0 comments on commit b5c9d53

Please sign in to comment.