Skip to content

Commit

Permalink
Reset head,
Browse files Browse the repository at this point in the history
Changed name of function,
Added checks for dictionary unpacking in function
Added negative test for dictionary unpacking in function

Signed-off-by: bpmeek <[email protected]>
  • Loading branch information
bpmeek committed Jul 18, 2024
1 parent eb4e20c commit 29fd1b3
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 1 deletion.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Major features and improvements
* Kedro commands are now lazily loaded to add performance gains when running Kedro commands.
* Parameter dictionaries can now be unpacked into node inputs with the `**params:<name>` syntax.

## Bug fixes and other changes
* Updated error message for invalid catalog entries.
Expand All @@ -22,6 +23,8 @@
* Extended documentation with an example of logging customisation at runtime

## Community contributions
Many thanks to the following Kedroids for contributing PRs to this release:
* [bpmeek](https://github.com/bpmeek/)

# Release 0.19.6

Expand Down
13 changes: 12 additions & 1 deletion kedro/pipeline/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,11 @@ def _get_dataset_names_mapping(

def _normalize_param_name(name: str) -> str:
"""Make sure that a param name has a `params:` prefix before passing to the node"""
return name if name.startswith("params:") else f"params:{name}"
return (
name
if name.startswith("params:") or name.startswith("**params:")
else f"params:{name}"
)


def _get_param_names_mapping(
Expand Down Expand Up @@ -251,6 +255,11 @@ def _map_transcode_base(name: str) -> str:
base_name, transcode_suffix = _transcode_split(name)
return TRANSCODING_SEPARATOR.join((mapping[base_name], transcode_suffix))

def _matches_unpackable(name: str) -> bool:
param_base = name.split(".")[0]
matches = [True for key, value in mapping.items() if f"**{param_base}" in key]
return any(matches)

def _rename(name: str) -> str:
rules = [
# if name mapped to new name, update with new name
Expand All @@ -259,6 +268,8 @@ def _rename(name: str) -> str:
(_is_all_parameters, lambda n: n),
# if transcode base is mapped to a new name, update with new base
(_is_transcode_base_in_mapping, _map_transcode_base),
# if name refers to dictionary to be unpacked, leave as is
(lambda n: _matches_unpackable(name), lambda n: n),
# if name refers to a single parameter and a namespace is given, apply prefix
(lambda n: bool(namespace) and _is_single_parameter(n), _prefix_param),
# if namespace given for a dataset, prefix name using that namespace
Expand Down
30 changes: 30 additions & 0 deletions kedro/pipeline/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def __init__( # noqa: PLR0913
_node_error_message("it must have some 'inputs' or 'outputs'.")
)

inputs = _unpacked_params(func, inputs)

self._validate_inputs(func, inputs)

self._func = func
Expand Down Expand Up @@ -683,3 +685,31 @@ def _get_readable_func_name(func: Callable) -> str:
name = "<partial>"

return name


def _unpacked_params(
    func: Callable, inputs: None | str | list[str] | dict[str, str]
) -> None | str | list[str] | dict[str, str]:
    """Expand ``**params:<name>`` node inputs into one input per function argument.

    Each input that starts with ``**params`` is replaced by
    ``params:<name>.<arg>`` entries, one for every argument of ``func`` from
    that input's position onwards. All other inputs are kept as they are.

    Args:
        func: The function whose signature supplies the argument names.
        inputs: The node inputs to inspect.

    Returns:
        Either original inputs if no input was unpacked or a list of inputs
        if an input was unpacked.

    Raises:
        TypeError: If an input requests unpacking while ``func`` itself
            declares a ``**kwargs``-style parameter.
    """
    parameters = inspect.signature(func).parameters
    arg_names = list(parameters)
    # Look at the parameter kinds instead of searching the rendered signature
    # for "**": a default value or annotation containing "**" would also
    # show up in that text and trigger a spurious error.
    has_var_keyword = any(
        parameter.kind is inspect.Parameter.VAR_KEYWORD
        for parameter in parameters.values()
    )

    unpacked = False
    new_inputs = []
    for idx, _input in enumerate(_to_list(inputs)):
        if not _input.startswith("**params"):
            new_inputs.append(_input)
            continue
        if has_var_keyword:
            raise TypeError(
                "Function side dictionary unpacking is currently incompatible with parameter dictionary unpacking."
            )
        unpacked = True
        # Everything after the last ":" names the dictionary to unpack.
        dict_root = _input.split(":")[-1]
        new_inputs.extend(f"params:{dict_root}.{param}" for param in arg_names[idx:])

    # Hand back the caller's object untouched when nothing was expanded.
    return new_inputs if unpacked else inputs
49 changes: 49 additions & 0 deletions tests/pipeline/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ def triconcat(input1: str, input2: str, input3: str):
return input1 + input2 + input3 # pragma: no cover


def dict_unpack():
    """Return a fresh two-entry dict keyed by ``input2`` and ``input3``."""
    return {"input2": "input2", "input3": "another node"}


def kwargs_input(**kwargs):
    """Return the keyword arguments it was called with, as a dict."""
    return kwargs


@pytest.fixture
def simple_tuple_node_list():
return [
Expand Down Expand Up @@ -125,6 +133,47 @@ def test_inputs_list(self):
)
assert dummy_node.inputs == ["input1", "input2", "another node"]

def test_inputs_unpack_str(self):
    """A single ``**params:dict_unpack`` string input yields one entry per ``triconcat`` argument."""
    expected_inputs = [
        "params:dict_unpack.input1",
        "params:dict_unpack.input2",
        "params:dict_unpack.input3",
    ]
    unpacked_node = node(triconcat, inputs="**params:dict_unpack", outputs="output1")
    assert unpacked_node.inputs == expected_inputs

def test_inputs_unpack_list(self):
    """In a list, the unpack entry only covers the arguments from its own position onwards."""
    expected_inputs = [
        "input1",
        "params:dict_unpack.input2",
        "params:dict_unpack.input3",
    ]
    node_inputs = ["input1", "**params:dict_unpack"]
    node_outputs = ["output1", "output2", "last node"]
    unpacked_node = node(triconcat, inputs=node_inputs, outputs=node_outputs)
    assert unpacked_node.inputs == expected_inputs

def test_inputs_unpack_dict(self):
    """A dict input whose value is an unpack entry yields one entry per ``triconcat`` argument."""
    expected_inputs = [
        "params:dict_unpack.input1",
        "params:dict_unpack.input2",
        "params:dict_unpack.input3",
    ]
    node_inputs = {"input1": "**params:dict_unpack"}
    node_outputs = ["output1", "output2", "last node"]
    unpacked_node = node(triconcat, inputs=node_inputs, outputs=node_outputs)
    assert unpacked_node.inputs == expected_inputs

def test_kwargs_node_negative(self):
    """Unpacking a parameter dictionary into a ``**kwargs`` function raises ``TypeError``."""
    with pytest.raises(TypeError, match="parameter dictionary unpacking"):
        node(kwargs_input, inputs="**params:dict_unpack", outputs="output1")

def test_outputs_none(self):
dummy_node = node(identity, "input", None)
assert dummy_node.outputs == []
Expand Down

0 comments on commit 29fd1b3

Please sign in to comment.