Skip to content

Commit

Permalink
Add project name validation to flwr new (#3241)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Apr 11, 2024
1 parent af1c375 commit 7fdc309
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 8 deletions.
20 changes: 16 additions & 4 deletions src/py/flwr/cli/new/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
import typer
from typing_extensions import Annotated

from ..utils import prompt_options, prompt_text
from ..utils import (
is_valid_project_name,
prompt_options,
prompt_text,
sanitize_project_name,
)


class MlFramework(str, Enum):
Expand Down Expand Up @@ -81,6 +86,16 @@ def new(
] = None,
) -> None:
"""Create new Flower project."""
if project_name is None:
project_name = prompt_text("Please provide project name")
if not is_valid_project_name(project_name):
project_name = prompt_text(
"Please provide a name that only contains "
"characters in {'_', 'a-zA-Z', '0-9'}",
predicate=is_valid_project_name,
default=sanitize_project_name(project_name),
)

print(
typer.style(
f"🔨 Creating Flower project {project_name}...",
Expand All @@ -89,9 +104,6 @@ def new(
)
)

if project_name is None:
project_name = prompt_text("Please provide project name")

if framework is not None:
framework_str = str(framework.value)
else:
Expand Down
64 changes: 60 additions & 4 deletions src/py/flwr/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,23 @@
# ==============================================================================
"""Flower command line interface utils."""

from typing import List, cast
from typing import Callable, List, Optional, cast

import typer


def prompt_text(text: str) -> str:
def prompt_text(
text: str,
predicate: Callable[[str], bool] = lambda _: True,
default: Optional[str] = None,
) -> str:
"""Ask user to enter text input."""
while True:
result = typer.prompt(
typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True)
typer.style(f"\n💬 {text}", fg=typer.colors.MAGENTA, bold=True),
default=default,
)
if len(result) > 0:
if predicate(result) and len(result) > 0:
break
print(typer.style("❌ Invalid entry", fg=typer.colors.RED, bold=True))

Expand Down Expand Up @@ -65,3 +70,54 @@ def prompt_options(text: str, options: List[str]) -> str:

result = options[int(index)]
return result


def is_valid_project_name(name: str) -> bool:
"""Check if the given string is a valid Python module name.
A valid module name must start with a letter or an underscore, and can only contain
letters, digits, and underscores.
"""
if not name:
return False

# Check if the first character is a letter or underscore
if not (name[0].isalpha() or name[0] == "_"):
return False

# Check if the rest of the characters are valid (letter, digit, or underscore)
for char in name[1:]:
if not (char.isalnum() or char == "_"):
return False

return True


def sanitize_project_name(name: str) -> str:
"""Sanitize the given string to make it a valid Python module name.
This version replaces hyphens with underscores, removes any characters not allowed
in Python module names, makes the string lowercase, and ensures it starts with a
valid character.
"""
# Replace '-' with '_'
name_with_underscores = name.replace("-", "_").replace(" ", "_")

# Allowed characters in a module name: letters, digits, underscore
allowed_chars = set(
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
)

# Make the string lowercase
sanitized_name = name_with_underscores.lower()

# Remove any characters not allowed in Python module names
sanitized_name = "".join(c for c in sanitized_name if c in allowed_chars)

# Ensure the first character is a letter or underscore
if sanitized_name and (
sanitized_name[0].isdigit() or sanitized_name[0] not in allowed_chars
):
sanitized_name = "_" + sanitized_name

return sanitized_name

0 comments on commit 7fdc309

Please sign in to comment.