Skip to content

Commit

Permalink
Move FAB validation to install.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng committed Oct 9, 2024
1 parent b5c9d53 commit 015079a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 55 deletions.
2 changes: 1 addition & 1 deletion src/py/flwr/cli/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from flwr.common.constant import FAB_ALLOWED_EXTENSIONS, FAB_DATE, FAB_HASH_TRUNCATION

from .config_utils import load_and_validate
from .utils import get_sha256_hash, is_valid_project_name
from .utils import is_valid_project_name


def write_to_zip(
Expand Down
43 changes: 39 additions & 4 deletions src/py/flwr/cli/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from flwr.common.config import (
get_flwr_dir,
get_metadata_from_config,
get_metadata_from_fab_filename,
)
from flwr.common.constant import FAB_HASH_TRUNCATION

Expand Down Expand Up @@ -142,7 +141,7 @@ def install_from_fab(
# pylint: disable=too-many-locals
def validate_and_install(
project_dir: Path,
fab_hash: str, # pylint: disable=unused-argument
fab_hash: str,
fab_name: Optional[str],
flwr_dir: Optional[Path],
skip_prompt: bool = False,
Expand All @@ -160,10 +159,10 @@ def validate_and_install(

version, fab_id = get_metadata_from_config(config)
publisher, project_name = fab_id.split("/")
config_metadata = (publisher, project_name, version)
config_metadata = (publisher, project_name, version, fab_hash)

if fab_name:
_ = get_metadata_from_fab_filename(fab_name, config_metadata)
_validate_fab_and_config_metadata(fab_name, config_metadata)

install_dir: Path = (
(get_flwr_dir() if not flwr_dir else flwr_dir)
Expand Down Expand Up @@ -226,3 +225,39 @@ def _verify_hashes(list_content: str, tmpdir: Path) -> bool:
if not file_path.exists() or get_sha256_hash(file_path) != hash_expected:
return False
return True


def _validate_fab_and_config_metadata(
fab_name: str, config_metadata: tuple[str, str, str, str]
) -> None:
"""Validate metadata from the FAB filename and config."""
publisher, project_name, version, fab_hash = config_metadata

fab_name = fab_name.removesuffix(".fab")

fab_publisher, fab_project_name, fab_version, fab_shorthash = fab_name.split(".")

# Check FAB filename format
if (
f"{fab_publisher}.{fab_project_name}.{fab_version}"
!= f"{publisher}.{project_name}.{version}"
or len(fab_shorthash) != FAB_HASH_TRUNCATION # Verify hash length
):
raise ValueError(
"❌ FAB file has incorrect name. The file name must follow the format "
"`<publisher>.<project_name>.<version>.<8hexchars>.fab`.",
)

# Verify hash is a valid hexadecimal
try:
_ = int(fab_shorthash, 16)
except Exception as e:
raise ValueError(
"❌ FAB file has an invalid hexadecimal string `{fab_shorthash}`."
) from e

# Verify shorthash matches
if fab_shorthash != fab_hash[:FAB_HASH_TRUNCATION]:
raise ValueError(
"❌ The hash in the FAB file name does not match the hash of the FAB."
)
1 change: 0 additions & 1 deletion src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================
"""Flower command line interface `run` command."""

import hashlib
import json
import subprocess
import sys
Expand Down
50 changes: 1 addition & 49 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================
"""Provide functions for managing global Flower config."""

import sys
import os
import re
from pathlib import Path
Expand All @@ -23,12 +22,7 @@
import tomli

from flwr.cli.config_utils import get_fab_config, validate_fields
from flwr.common.constant import (
APP_DIR,
FAB_CONFIG_FILE,
FAB_HASH_TRUNCATION,
FLWR_HOME,
)
from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME
from flwr.common.typing import Run, UserConfig, UserConfigValue


Expand Down Expand Up @@ -221,45 +215,3 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
config["project"]["version"],
f"{config['tool']['flwr']['app']['publisher']}/{config['project']['name']}",
)


def get_metadata_from_fab_filename(
fab_file: Union[Path, str], config_metadata: tuple[str, str, str]
) -> Optional[tuple[str, str, str, str]]:
"""Extract metadata from the FAB filename."""

fab_file_name: str
publisher, project_name, version = config_metadata

if isinstance(fab_file, Path):
fab_file_name = fab_file.stem
elif isinstance(fab_file, str):
fab_file_name = fab_file.removesuffix(".fab")

fab_publisher, fab_project_name, fab_version, fab_shorthash = fab_file_name.split(
"."
)

if (
f"{fab_publisher}.{fab_project_name}.{fab_version}"
!= f"{publisher}.{project_name}.{version}"
or len(fab_shorthash) != FAB_HASH_TRUNCATION # Verify hash length
):
raise ValueError(
"❌ FAB file has incorrect name. The file name must follow the format "
"`<publisher>.<project_name>.<version>.<8hexchars>.fab`.",
)

try:
_ = int(fab_shorthash, 16) # Verify hash is a valid hexadecimal
except Exception as e:
raise ValueError(
"❌ FAB file has an invalid hexadecimal string `{fab_shorthash}`."
) from e

return (
fab_publisher,
fab_project_name,
fab_version.replace("-", "."),
fab_shorthash,
)

0 comments on commit 015079a

Please sign in to comment.