diff --git a/src/py/flwr/cli/build.py b/src/py/flwr/cli/build.py index 1031a100514b..4c9dca4ebcf1 100644 --- a/src/py/flwr/cli/build.py +++ b/src/py/flwr/cli/build.py @@ -29,7 +29,7 @@ from flwr.common.constant import FAB_ALLOWED_EXTENSIONS, FAB_DATE, FAB_HASH_TRUNCATION from .config_utils import load_and_validate -from .utils import get_sha256_hash, is_valid_project_name +from .utils import is_valid_project_name def write_to_zip( diff --git a/src/py/flwr/cli/install.py b/src/py/flwr/cli/install.py index 67186bca0676..7ec200918088 100644 --- a/src/py/flwr/cli/install.py +++ b/src/py/flwr/cli/install.py @@ -28,7 +28,6 @@ from flwr.common.config import ( get_flwr_dir, get_metadata_from_config, - get_metadata_from_fab_filename, ) from flwr.common.constant import FAB_HASH_TRUNCATION @@ -142,7 +141,7 @@ def install_from_fab( # pylint: disable=too-many-locals def validate_and_install( project_dir: Path, - fab_hash: str, # pylint: disable=unused-argument + fab_hash: str, fab_name: Optional[str], flwr_dir: Optional[Path], skip_prompt: bool = False, @@ -160,10 +159,10 @@ def validate_and_install( version, fab_id = get_metadata_from_config(config) publisher, project_name = fab_id.split("/") - config_metadata = (publisher, project_name, version) + config_metadata = (publisher, project_name, version, fab_hash) if fab_name: - _ = get_metadata_from_fab_filename(fab_name, config_metadata) + _validate_fab_and_config_metadata(fab_name, config_metadata) install_dir: Path = ( (get_flwr_dir() if not flwr_dir else flwr_dir) @@ -226,3 +225,39 @@ def _verify_hashes(list_content: str, tmpdir: Path) -> bool: if not file_path.exists() or get_sha256_hash(file_path) != hash_expected: return False return True + + +def _validate_fab_and_config_metadata( + fab_name: str, config_metadata: tuple[str, str, str, str] +) -> None: + """Validate metadata from the FAB filename and config.""" + publisher, project_name, version, fab_hash = config_metadata + + fab_name = fab_name.removesuffix(".fab") + + fab_publisher, fab_project_name, fab_version, fab_shorthash = fab_name.split(".") + + # Check FAB filename format + if ( + f"{fab_publisher}.{fab_project_name}.{fab_version}" + != f"{publisher}.{project_name}.{version}" + or len(fab_shorthash) != FAB_HASH_TRUNCATION # Verify hash length + ): + raise ValueError( + "❌ FAB file has incorrect name. The file name must follow the format " + "`...<8hexchars>.fab`.", + ) + + # Verify hash is a valid hexadecimal + try: + _ = int(fab_shorthash, 16) + except Exception as e: + raise ValueError( + "❌ FAB file has an invalid hexadecimal string `{fab_shorthash}`." + ) from e + + # Verify shorthash matches + if fab_shorthash != fab_hash[:FAB_HASH_TRUNCATION]: + raise ValueError( + "❌ The hash in the FAB file name does not match the hash of the FAB." + ) diff --git a/src/py/flwr/cli/run/run.py b/src/py/flwr/cli/run/run.py index 26379b482b87..bc9877ec8fb0 100644 --- a/src/py/flwr/cli/run/run.py +++ b/src/py/flwr/cli/run/run.py @@ -14,7 +14,6 @@ # ============================================================================== """Flower command line interface `run` command.""" -import hashlib import json import subprocess import sys diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index aee48facd320..071d41a3ab5e 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -14,7 +14,6 @@ # ============================================================================== """Provide functions for managing global Flower config.""" -import sys import os import re from pathlib import Path @@ -23,12 +22,7 @@ import tomli from flwr.cli.config_utils import get_fab_config, validate_fields -from flwr.common.constant import ( - APP_DIR, - FAB_CONFIG_FILE, - FAB_HASH_TRUNCATION, - FLWR_HOME, -) +from flwr.common.constant import APP_DIR, FAB_CONFIG_FILE, FLWR_HOME from flwr.common.typing import Run, UserConfig, UserConfigValue @@ -221,45 +215,3 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]: config["project"]["version"], f"{config['tool']['flwr']['app']['publisher']}/{config['project']['name']}", ) - - -def get_metadata_from_fab_filename( - fab_file: Union[Path, str], config_metadata: tuple[str, str, str] -) -> Optional[tuple[str, str, str, str]]: - """Extract metadata from the FAB filename.""" - - fab_file_name: str - publisher, project_name, version = config_metadata - - if isinstance(fab_file, Path): - fab_file_name = fab_file.stem - elif isinstance(fab_file, str): - fab_file_name = fab_file.removesuffix(".fab") - - fab_publisher, fab_project_name, fab_version, fab_shorthash = fab_file_name.split( - "." - ) - - if ( - f"{fab_publisher}.{fab_project_name}.{fab_version}" - != f"{publisher}.{project_name}.{version}" - or len(fab_shorthash) != FAB_HASH_TRUNCATION # Verify hash length - ): - raise ValueError( - "❌ FAB file has incorrect name. The file name must follow the format " - "`...<8hexchars>.fab`.", - ) - - try: - _ = int(fab_shorthash, 16) # Verify hash is a valid hexadecimal - except Exception as e: - raise ValueError( - "❌ FAB file has an invalid hexadecimal string `{fab_shorthash}`." - ) from e - - return ( - fab_publisher, - fab_project_name, - fab_version.replace("-", "."), - fab_shorthash, - )