diff --git a/RELEASE.md b/RELEASE.md index cc8b9032c9..d9bdda2197 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -2,6 +2,7 @@ ## Major features and improvements * Kedro commands are now lazily loaded to add performance gains when running Kedro commands. +* Can use unpacking with parameter dictionaries. ## Bug fixes and other changes * Updated error message for invalid catalog entries. @@ -22,6 +23,8 @@ * Extended documentation with an example of logging customisation at runtime ## Community contributions +Many thanks to the following Kedroids for contributing PRs to this release: +* [bpmeek](https://github.com/bpmeek/) # Release 0.19.6 diff --git a/kedro/pipeline/modular_pipeline.py b/kedro/pipeline/modular_pipeline.py index 9eb4caba16..8d102da09b 100644 --- a/kedro/pipeline/modular_pipeline.py +++ b/kedro/pipeline/modular_pipeline.py @@ -123,7 +123,11 @@ def _get_dataset_names_mapping( def _normalize_param_name(name: str) -> str: """Make sure that a param name has a `params:` prefix before passing to the node""" - return name if name.startswith("params:") else f"params:{name}" + return ( + name + if name.startswith("params:") or name.startswith("**params:") + else f"params:{name}" + ) def _get_param_names_mapping( @@ -251,6 +255,11 @@ def _map_transcode_base(name: str) -> str: base_name, transcode_suffix = _transcode_split(name) return TRANSCODING_SEPARATOR.join((mapping[base_name], transcode_suffix)) + def _matches_unpackable(name: str) -> bool: + param_base = name.split(".")[0] + matches = [True for key, value in mapping.items() if f"**{param_base}" in key] + return any(matches) + def _rename(name: str) -> str: rules = [ # if name mapped to new name, update with new name @@ -259,6 +268,8 @@ def _rename(name: str) -> str: (_is_all_parameters, lambda n: n), # if transcode base is mapped to a new name, update with new base (_is_transcode_base_in_mapping, _map_transcode_base), + # if name refers to dictionary to be unpacked, leave as is + (lambda n: _matches_unpackable(name), lambda n: n), # if name refers to a single parameter and a namespace is given, apply prefix (lambda n: bool(namespace) and _is_single_parameter(n), _prefix_param), # if namespace given for a dataset, prefix name using that namespace diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index 1b718689c5..9be2dfa49a 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -119,6 +119,8 @@ def __init__( # noqa: PLR0913 _node_error_message("it must have some 'inputs' or 'outputs'.") ) + inputs = _unpacked_params(func, inputs) + self._validate_inputs(func, inputs) self._func = func @@ -683,3 +685,31 @@ def _get_readable_func_name(func: Callable) -> str: name = "" return name + + +def _unpacked_params( + func: Callable, inputs: None | str | list[str] | dict[str, str] +) -> None | str | list[str] | dict[str, str]: + """Iterate over Node inputs to see if they need to be unpacked. + + Returns: + Either original inputs if no input was unpacked or a list of inputs if an input was unpacked. + """ + use_new = False + new_inputs = [] + _func_arguments = [arg for arg in inspect.signature(func).parameters] + for idx, _input in enumerate(_to_list(inputs)): + if _input.startswith("**params"): + if "**" in str(inspect.signature(func)): + raise TypeError( + "Function side dictionary unpacking is currently incompatible with parameter dictionary unpacking." + ) + use_new = True + dict_root = _input.split(":")[-1] + for param in _func_arguments[idx:]: + new_inputs.append(f"params:{dict_root}.{param}") + else: + new_inputs.append(_input) + if use_new: + return new_inputs + return inputs diff --git a/tests/pipeline/test_node.py b/tests/pipeline/test_node.py index 8798faa273..7c1861beed 100644 --- a/tests/pipeline/test_node.py +++ b/tests/pipeline/test_node.py @@ -24,6 +24,14 @@ def triconcat(input1: str, input2: str, input3: str): return input1 + input2 + input3 # pragma: no cover +def dict_unpack(): + return dict(input2="input2", input3="another node") + + +def kwargs_input(**kwargs): + return kwargs + + @pytest.fixture def simple_tuple_node_list(): return [ @@ -125,6 +133,47 @@ def test_inputs_list(self): ) assert dummy_node.inputs == ["input1", "input2", "another node"] + def test_inputs_unpack_str(self): + dummy_node = node(triconcat, inputs="**params:dict_unpack", outputs="output1") + assert dummy_node.inputs == [ + "params:dict_unpack.input1", + "params:dict_unpack.input2", + "params:dict_unpack.input3", + ] + + def test_inputs_unpack_list(self): + dummy_node = node( + triconcat, + inputs=["input1", "**params:dict_unpack"], + outputs=["output1", "output2", "last node"], + ) + assert dummy_node.inputs == [ + "input1", + "params:dict_unpack.input2", + "params:dict_unpack.input3", + ] + + def test_inputs_unpack_dict(self): + dummy_node = node( + triconcat, + inputs={"input1": "**params:dict_unpack"}, + outputs=["output1", "output2", "last node"], + ) + assert dummy_node.inputs == [ + "params:dict_unpack.input1", + "params:dict_unpack.input2", + "params:dict_unpack.input3", + ] + + def test_kwargs_node_negative(self): + pattern = "parameter dictionary unpacking" + with pytest.raises(TypeError, match=pattern): + node( + kwargs_input, + inputs="**params:dict_unpack", + outputs="output1", + ) + def test_outputs_none(self): dummy_node = node(identity, "input", None) assert dummy_node.outputs == []