diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 7eb47e3e3548..bf7ca95ea021 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -22,7 +22,12 @@ import typer from typing_extensions import Annotated -from ..utils import prompt_options, prompt_text +from ..utils import ( + is_valid_project_name, + prompt_options, + prompt_text, + sanitize_project_name, +) class MlFramework(str, Enum): @@ -81,6 +86,16 @@ def new( ] = None, ) -> None: """Create new Flower project.""" + if project_name is None: + project_name = prompt_text("Please provide project name") + if not is_valid_project_name(project_name): + project_name = prompt_text( + "Please provide a name that only contains " + "characters in {'_', 'a-zA-Z', '0-9'}", + predicate=is_valid_project_name, + default=sanitize_project_name(project_name), + ) + print( typer.style( f"šŸ”Ø Creating Flower project {project_name}...", @@ -89,9 +104,6 @@ def new( ) ) - if project_name is None: - project_name = prompt_text("Please provide project name") - if framework is not None: framework_str = str(framework.value) else: diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py index 4e86f0c3b8c8..6e41abd1492e 100644 --- a/src/py/flwr/cli/utils.py +++ b/src/py/flwr/cli/utils.py @@ -14,18 +14,23 @@ # ============================================================================== """Flower command line interface utils.""" -from typing import List, cast +from typing import Callable, List, Optional, cast import typer -def prompt_text(text: str) -> str: +def prompt_text( + text: str, + predicate: Callable[[str], bool] = lambda _: True, + default: Optional[str] = None, +) -> str: """Ask user to enter text input.""" while True: result = typer.prompt( - typer.style(f"\nšŸ’¬ {text}", fg=typer.colors.MAGENTA, bold=True) + typer.style(f"\nšŸ’¬ {text}", fg=typer.colors.MAGENTA, bold=True), + default=default, ) - if len(result) > 0: + if predicate(result) and len(result) > 0: break print(typer.style("āŒ Invalid entry", fg=typer.colors.RED, bold=True)) @@ -65,3 +70,56 @@ def prompt_options(text: str, options: List[str]) -> str: result = options[int(index)] return result + + +def is_valid_project_name(name: str) -> bool: + """ + Check if the given string is a valid Python module name. + + A valid module name must start with a letter or an underscore, + and can only contain letters, digits, and underscores. + """ + + if not name: + return False + + # Check if the first character is a letter or underscore + if not (name[0].isalpha() or name[0] == "_"): + return False + + # Check if the rest of the characters are valid (letter, digit, or underscore) + for char in name[1:]: + if not (char.isalnum() or char == "_"): + return False + + return True + + +def sanitize_project_name(name: str) -> str: + """ + Sanitize the given string to make it a valid Python module name. + This version replaces hyphens with underscores, removes any characters + not allowed in Python module names, makes the string lowercase, and + ensures it starts with a valid character. + """ + # Replace '-' with '_' + name_with_underscores = name.replace("-", "_") + + # Allowed characters in a module name: letters, digits, underscore + allowed_chars = set( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" + ) + + # Make the string lowercase + sanitized_name = name_with_underscores.lower() + + # Remove any characters not allowed in Python module names + sanitized_name = "".join(c for c in sanitized_name if c in allowed_chars) + + # Ensure the first character is a letter or underscore + if sanitized_name and ( + sanitized_name[0].isdigit() or sanitized_name[0] not in allowed_chars + ): + sanitized_name = "_" + sanitized_name + + return sanitized_name