Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make project creation flow parse add-ons consistently #3265

Merged
merged 15 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 42 additions & 40 deletions kedro/framework/cli/starters.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,15 @@ class KedroStarterSpec: # noqa: too-few-public-methods
kedro new --addons=none
"""

ADD_ONS_DICT = {
ADD_ONS_SHORTNAME_TO_NUMBER = {
"lint": "1",
"test": "2",
"log": "3",
"docs": "4",
"data": "5",
"pyspark": "6",
}
NUMBER_TO_ADD_ONS_NAME = {
Comment on lines +119 to +127
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit :)

Suggested change
ADD_ONS_SHORTNAME_TO_NUMBER = {
"lint": "1",
"test": "2",
"log": "3",
"docs": "4",
"data": "5",
"pyspark": "6",
}
NUMBER_TO_ADD_ONS_NAME = {
ADD_ON_SHORTNAME_TO_NUMBER = {
"lint": "1",
"test": "2",
"log": "3",
"docs": "4",
"data": "5",
"pyspark": "6",
}
NUMBER_TO_ADD_ON_NAME = {

"1": "Linting",
"2": "Testing",
"3": "Custom Logging",
Expand All @@ -125,6 +133,7 @@ class KedroStarterSpec: # noqa: too-few-public-methods
"6": "Pyspark",
}


NAME_ARG_HELP = "The name of your new Kedro project."

# noqa: unused-argument
Expand Down Expand Up @@ -213,13 +222,13 @@ def _validate_range(start, end):

def _validate_selection(add_ons: list[str]):
for add_on in add_ons:
if int(add_on) < 1 or int(add_on) > len(ADD_ONS_DICT):
if add_on not in NUMBER_TO_ADD_ONS_NAME:
message = f"'{add_on}' is not a valid selection.\nPlease select from the available add-ons: 1, 2, 3, 4, 5, 6." # nosec
click.secho(message, fg="red", err=True)
sys.exit(1)

if add_ons_str == "all":
return list(ADD_ONS_DICT)
return list(NUMBER_TO_ADD_ONS_NAME)
if add_ons_str == "none":
return []
# Guard clause if add_ons_str is None, which can happen if prompts.yml is removed
Expand Down Expand Up @@ -322,15 +331,15 @@ def new( # noqa: too-many-arguments
shutil.rmtree(tmpdir, onerror=_remove_readonly)

# Obtain config, either from a file or from interactive user prompts.
config = _get_config(
extra_context = _get_extra_context(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is passed as extra_context to cookiecutter, so I think it's better to just name it extra_context.

prompts_required=prompts_required,
config_path=config_path,
cookiecutter_context=cookiecutter_context,
selected_addons=selected_addons,
project_name=project_name,
)

cookiecutter_args = _make_cookiecutter_args(config, checkout, directory)
cookiecutter_args = _make_cookiecutter_args(extra_context, checkout, directory)

project_template = fetch_template_based_on_add_ons(template_path, cookiecutter_args)

Expand Down Expand Up @@ -367,14 +376,14 @@ def list_starters():
)


def _get_config(
def _get_extra_context(
prompts_required: dict,
config_path: str,
cookiecutter_context: OrderedDict,
selected_addons: str,
project_name: str,
) -> dict[str, str]:
"""Generates a config dictionary to be used to generate cookiecutter args, based
"""Generates a config dictionary that will be passed to cookiecutter as `extra_context`, based
on CLI flags, user prompts, or a configuration file.

Args:
Expand All @@ -392,60 +401,53 @@ def _get_config(
the prompts_required dictionary, with all the redundant information removed.
"""
if not prompts_required:
config = {}
extra_context = {}
if config_path:
config = _fetch_config_from_file(config_path)
_validate_config_file_inputs(config)
extra_context = _fetch_config_from_file(config_path)
_validate_config_file_inputs(extra_context)

elif config_path:
config = _fetch_config_from_file(config_path)
_validate_config_file_against_prompts(config, prompts_required)
_validate_config_file_inputs(config)
extra_context = _fetch_config_from_file(config_path)
_validate_config_file_against_prompts(extra_context, prompts_required)
_validate_config_file_inputs(extra_context)
else:
config = _fetch_config_from_user_prompts(prompts_required, cookiecutter_context)
extra_context = _fetch_config_from_user_prompts(
prompts_required, cookiecutter_context
)

add_ons = _get_addons_from_cli_input(selected_addons)
add_ons = _convert_addon_names_to_numbers(selected_addons)

if add_ons is not None:
config["add_ons"] = add_ons
extra_context["add_ons"] = add_ons

if project_name is not None:
config["project_name"] = project_name
extra_context["project_name"] = project_name

return config
return extra_context


def _get_addons_from_cli_input(selected_addons: str) -> str:
def _convert_addon_names_to_numbers(selected_addons: str) -> str:
"""Prepares add-on selection from the CLI input to the correct format
to be put in the project configuration, if it exists.
Replaces add-on strings with the corresponding prompt number.

Args:
selected_addons: a string containing the value for the --addons flag,
or None in case the flag wasn't used.
or None in case the flag wasn't used, i.e. lint,docs.

Returns:
String with the numbers corresponding to the desired add_ons, or
None in case the --addons flag was not used.
"""
string_to_number = {
"lint": "1",
"test": "2",
"log": "3",
"docs": "4",
"data": "5",
"pyspark": "6",
}

if selected_addons is not None:
addons = selected_addons.split(",")
for i in range(len(addons)):
addon = addons[i].strip()
if addon in string_to_number:
addons[i] = string_to_number[addon]
return ",".join(addons)
if selected_addons is None:
return None

return None
addons = []
for addon in selected_addons.split(","):
addon_short_name = addon.strip()
if addon_short_name in ADD_ONS_SHORTNAME_TO_NUMBER:
addons.append(ADD_ONS_SHORTNAME_TO_NUMBER[addon_short_name])
return ",".join(addons)


def _select_prompts_to_display(
Expand All @@ -466,7 +468,7 @@ def _select_prompts_to_display(
Returns:
the prompts_required dictionary, with all the redundant information removed.
"""
valid_addons = ["lint", "test", "log", "docs", "data", "pyspark", "all", "none"]
valid_addons = list(ADD_ONS_SHORTNAME_TO_NUMBER) + ["all", "none"]

if selected_addons is not None:
addons = re.sub(r"\s", "", selected_addons).split(",")
Expand All @@ -488,7 +490,7 @@ def _select_prompts_to_display(
del prompts_required["add_ons"]

if project_name is not None:
if bool(re.match(r"^[\w -]{2,}$", project_name)) is False:
if not re.match(r"^[\w -]{2,}$", project_name):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bool(...) wrapper is not needed: re.match returns either a match object or None, so the result can be tested for truthiness directly.

click.secho(
"Kedro project names must contain only alphanumeric symbols, spaces, underscores and hyphens and be at least 2 characters long",
fg="red",
Expand Down Expand Up @@ -558,7 +560,7 @@ def _make_cookiecutter_args(
add_ons = config.get("add_ons")
if add_ons:
config["add_ons"] = [
ADD_ONS_DICT[add_on] for add_on in _parse_add_ons_input(add_ons) # type: ignore
NUMBER_TO_ADD_ONS_NAME[add_on] for add_on in _parse_add_ons_input(add_ons) # type: ignore
]
config["add_ons"] = str(config["add_ons"])

Expand Down
20 changes: 1 addition & 19 deletions tests/framework/cli/test_starters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_OFFICIAL_STARTER_SPECS,
TEMPLATE_PATH,
KedroStarterSpec,
_convert_addon_names_to_numbers,
_parse_add_ons_input,
)

Expand Down Expand Up @@ -68,25 +69,6 @@ def _make_cli_prompt_input_without_name(
return "\n".join([add_ons, repo_name, python_package])


def _convert_addon_names_to_numbers(selected_addons: str):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function's name is actually more explicit than _get_addons_from_cli_input. It was originally used only in the tests; I've renamed the original function to match and removed this copy from the tests.

string_to_number = {
"lint": "1",
"test": "2",
"log": "3",
"docs": "4",
"data": "5",
"pyspark": "6",
}

addons = selected_addons.split(",")
for i in range(len(addons)):
addon = addons[i].strip()
if addon in string_to_number:
addons[i] = string_to_number[addon]

return ",".join(addons)


def _get_expected_files(add_ons: str):
add_ons_template_files = {
"1": 0,
Expand Down