Skip to content

Commit

Permalink
Add on_output to pipe_output
Browse files Browse the repository at this point in the history
When multiple end nodes are availible we can choose onto which we will
apply pipe_output. Can be multiple.

Changes:
- Applicable gets the method on_output that allows to specify target nodes
- pipe_output inherits from NodeTransformer
- pipe_output the namespace is resolved in Node level name to avoid
  collisions
- if on_output is global it limits the all the steps to the same subset,
  we prohibit additionally for steps to have clear specification of the
  end nodes

Tested and documented.
  • Loading branch information
jernejfrank committed Oct 6, 2024
1 parent d4831ec commit 37dedd3
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 13 deletions.
135 changes: 127 additions & 8 deletions hamilton/function_modifiers/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def __init__(
_resolvers: List[ConfigResolver] = None,
_name: Optional[str] = None,
_namespace: Union[str, None, EllipsisType] = ...,
_target: base.TargetType = None,
):
"""Instantiates an Applicable.
Expand All @@ -338,6 +339,8 @@ def __init__(
:param _resolvers: Resolvers to use for the function
:param _name: Name of the node to be created
:param _namespace: Namespace of the node to be created -- currently only single-level namespaces are supported
:param _target: Selects which target nodes it will be appended onto. By default all.
:param kwargs: Kwargs (**kwargs) to pass to the function
"""
self.fn = fn
if "_name" in kwargs:
Expand All @@ -349,6 +352,7 @@ def __init__(
self.resolvers = _resolvers if _resolvers is not None else []
self.name = _name
self.namespace = _namespace
self.target = _target

def _with_resolvers(self, *additional_resolvers: ConfigResolver) -> "Applicable":
"""Helper function for the .when* group"""
Expand Down Expand Up @@ -450,6 +454,26 @@ def named(self, name: str, namespace: NamespaceType = ...) -> "Applicable":
kwargs=self.kwargs,
)

def on_output(self, target: base.TargetType) -> "Applicable":
"""Add Target on a single function level.
This determines to which node(s) it will applies. Should match the same naming convention
as the NodeTransorfmLifecycle child class (for example NodeTransformer).
:param target: Which node(s) to apply on top of
:return: The Applicable with specified target
"""
return Applicable(
fn=self.fn,
_resolvers=self.resolvers,
_name=self.name,
_namespace=self.namespace,
_target=target if target is not None else self.target,
args=self.args,
kwargs=self.kwargs,
target_fn=self.target_fn,
)

def get_config_elements(self) -> List[str]:
"""Returns the config elements that this Applicable uses"""
out = []
Expand Down Expand Up @@ -883,7 +907,11 @@ def __init__(
# super(flow, self).__init__(*transforms, collapse=collapse, _chain=False)


class pipe_output(base.SingleNodeNodeTransformer):
class SingleTargetError(Exception):
pass


class pipe_output(base.NodeTransformer):
"""Running a series of transformation on the output of the function.
The decorated function declares the dependency, the body of the function gets executed, and then
Expand Down Expand Up @@ -913,24 +941,88 @@ def B(...):
3. You want to use the same function multiple times, but with different parameters -- while `@does`/`@parameterize` can
do this, this presents an easier way to do this, especially in a chain.
The rules for chaining nodes as the same as for pipe.
The rules for chaining nodes are the same as for pipe_input.
For extra control in case of multiple output nodes, for example after extract_field / extract_columns we can also specify the output node that we wish to mutate.
The following apply *A* to all fields while *B* only to "field_1"
.. code-block:: python
:name: Simple @pipe_output example targeting specific nodes
@extract_columns("col_1", "col_2")
def A(...):
return ...
def B(...):
return ...
@pipe_output(
step(A),
step(B).on_output("field_1"),
)
@extract_fields(
{"field_1":int, "field_2":int, "field_3":int}
)
def foo(a:int)->Dict[str,int]:
return {"field_1":1, "field_2":2, "field_3":3}
We can also do this on the global level (but cannot do on both levels at the same time). The following would apply function *A* and function *B* to only "field_1" and "field_2"
.. code-block:: python
:name: Simple @pipe_output targeting specific nodes local
@pipe_output(
step(A),
step(B),
on_output = ["field_1","field_2]
)
@extract_fields(
{"field_1":int, "field_2":int, "field_3":int}
)
def foo(a:int)->Dict[str,int]:
return {"field_1":1, "field_2":2, "field_3":3}
"""

@classmethod
def _validate_single_target_level(cls, target: base.TargetType, transforms: Tuple[Applicable]):
"""We want to make sure that target gets applied on a single level.
Either choose for each step individually what it targets or set it on the global level where
all steps will target the same node(s).
"""
if target is not None:
for transform in transforms:
if transform.target is not None:
raise SingleTargetError("Cannot have target set on pipe_output and step level.")

def __init__(
self,
*transforms: Applicable,
namespace: NamespaceType = ...,
on_output: base.TargetType = None,
collapse=False,
_chain=False,
):
"""Instantiates a `@pipe_output` decorator.
Warning: if there is a global pipe_output target, the individual Applicable.target only chooses
from the subset pre-selected from the global pipe_output target. Leave global pipe_output target
empty if you want to choose between all the nodes on the individual Applicable level.
:param transforms: step transformations to be applied, in order
:param namespace: namespace to apply to all nodes in the pipe. This can be "..." (the default), which resolves to the name of the decorated function, None (which means no namespace), or a string (which means that all nodes will be namespaced with that string). Note that you can either use this *or* namespaces inside pipe()...
:param on_output: setting the target node for all steps in the pipe. Leave empty to select all the output nodes.
:param collapse: Whether to collapse this into a single node. This is not currently supported.
:param _chain: Whether to chain the first parameter. This is the only mode that is supported. Furthermore, this is not externally exposed. @flow will make use of this.
"""
super(pipe_output, self).__init__()
pipe_output._validate_single_target_level(target=on_output, transforms=transforms)

if on_output == ...:
raise ValueError(
"Cannot apply Elipsis(...) to on_output. Use None, single string or list of strings."
)

super(pipe_output, self).__init__(target=on_output)
self.transforms = transforms
self.collapse = collapse
self.chain = _chain
Expand All @@ -944,6 +1036,27 @@ def __init__(
if self.chain:
raise NotImplementedError("@flow() is not yet supported -- this is ")

def _check_individual_target(self, node_):
"""Resolves target option on the transform level.
Adds option that we can decide for each applicable which output node it will target.
:param node_: The current output node.
:return: The set of transforms that target this node
"""
selected_transforms = []
for transform in self.transforms:
target = transform.target
if isinstance(target, str):
if node_.name == target:
selected_transforms.append(transform)
elif isinstance(target, Collection):
if node_.name in target:
selected_transforms.append(transform)
else:
selected_transforms.append(transform)

return tuple(selected_transforms)

def transform_node(
self, node_: node.Node, config: Dict[str, Any], fn: Callable
) -> Collection[node.Node]:
Expand All @@ -954,22 +1067,28 @@ def transform_node(
The last node is an identity to the previous one with the original name `function_name` to
represent an exit point of `pipe_output`.
"""

if len(self.transforms) < 1:
transforms = self._check_individual_target(node_)
if len(transforms) < 1:
# in case no functions in pipeline we short-circuit and return the original node
return [node_]

if self.namespace is None:
_namespace = None
elif self.namespace is ...:
_namespace = node_.name
else:
_namespace = self.namespace

original_node = node_.copy_with(name=f"{node_.name}_raw")

def __identity(foo: Any) -> Any:
return foo

transforms = self.transforms + (step(__identity).named(fn.__name__),)

transforms = transforms + (step(__identity).named(fn.__name__),)
nodes, _ = chain_transforms(
first_arg=original_node.name,
transforms=transforms,
namespace=self.namespace,
namespace=_namespace, # self.namespace,
config=config,
fn=fn,
)
Expand Down
107 changes: 102 additions & 5 deletions tests/function_modifiers/test_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,6 @@ def _test_apply_function(foo: int, bar: int, baz: int = 100) -> int:
return foo + bar + baz


def _test_apply_function_2(foo: int) -> int:
return foo + 1


@pytest.mark.parametrize(
"args,kwargs,chain_first_param",
[
Expand Down Expand Up @@ -463,6 +459,16 @@ def result_from_downstream_function() -> int:
return 2


def test_pipe_output_single_target_level_error():
with pytest.raises(hamilton.function_modifiers.macros.SingleTargetError):
pipe_output(
step(_test_apply_function, source("bar_upstream"), baz=value(100)).on_output(
"some_node"
),
on_output="some_other_node",
)


def test_pipe_output_shortcircuit():
n = node.Node.from_fn(result_from_downstream_function)
decorator = pipe_output()
Expand Down Expand Up @@ -531,6 +537,97 @@ def test_pipe_output_inherits_null_namespace():
assert "result_from_downstream_function" in {item.name for item in nodes}


def test_pipe_output_global_on_output_all():
n1 = node.Node.from_fn(result_from_downstream_function, name="node_1")
n2 = node.Node.from_fn(result_from_downstream_function, name="node_2")

decorator = pipe_output(
step(_test_apply_function, source("bar_upstream"), baz=value(100)),
)
nodes = decorator.select_nodes(decorator.target, [n1, n2])
assert len(nodes) == 2
assert [node_.name for node_ in nodes] == ["node_1", "node_2"]


def test_pipe_output_global_on_output_string():
n1 = node.Node.from_fn(result_from_downstream_function, name="node_1")
n2 = node.Node.from_fn(result_from_downstream_function, name="node_2")

decorator = pipe_output(
step(_test_apply_function, source("bar_upstream"), baz=value(100)), on_output="node_2"
)
nodes = decorator.select_nodes(decorator.target, [n1, n2])
assert len(nodes) == 1
assert nodes[0].name == "node_2"


def test_pipe_output_global_on_output_list_strings():
n1 = node.Node.from_fn(result_from_downstream_function, name="node_1")
n2 = node.Node.from_fn(result_from_downstream_function, name="node_2")
n3 = node.Node.from_fn(result_from_downstream_function, name="node_3")

decorator = pipe_output(
step(_test_apply_function, source("bar_upstream"), baz=value(100)),
on_output=["node_1", "node_2"],
)
nodes = decorator.select_nodes(decorator.target, [n1, n2, n3])
assert len(nodes) == 2
assert [node_.name for node_ in nodes] == ["node_1", "node_2"]


def test_pipe_output_elipsis_error():
with pytest.raises(ValueError):
pipe_output(
step(_test_apply_function, source("bar_upstream"), baz=value(100)), on_output=...
)


def test_pipe_output_local_on_output_string():
n1 = node.Node.from_fn(result_from_downstream_function, name="node_1")
n2 = node.Node.from_fn(result_from_downstream_function, name="node_2")

decorator = pipe_output(
step(_test_apply_function, source("bar_upstream"), baz=value(100))
.named("correct_transform")
.on_output("node_2"),
step(_test_apply_function, source("bar_upstream"), baz=value(100))
.named("wrong_transform")
.on_output("node_3"),
)
steps = decorator._check_individual_target(n1)
assert len(steps) == 0
steps = decorator._check_individual_target(n2)
assert len(steps) == 1
assert steps[0].name == "correct_transform"


def test_pipe_output_local_on_output_list_string():
n1 = node.Node.from_fn(result_from_downstream_function, name="node_1")
n2 = node.Node.from_fn(result_from_downstream_function, name="node_2")
n3 = node.Node.from_fn(result_from_downstream_function, name="node_3")

decorator = pipe_output(
step(_test_apply_function, source("bar_upstream"), baz=value(100))
.named("correct_transform_list")
.on_output(["node_2", "node_3"]),
step(_test_apply_function, source("bar_upstream"), baz=value(100))
.named("correct_transform_string")
.on_output("node_2"),
step(_test_apply_function, source("bar_upstream"), baz=value(100))
.named("wrong_transform")
.on_output("node_5"),
)
steps = decorator._check_individual_target(n1)
assert len(steps) == 0
steps = decorator._check_individual_target(n2)
assert len(steps) == 2
assert steps[0].name == "correct_transform_list"
assert steps[1].name == "correct_transform_string"
steps = decorator._check_individual_target(n3)
assert len(steps) == 1
assert steps[0].name == "correct_transform_list"


def test_pipe_output_end_to_end_simple():
dr = driver.Builder().with_config({"calc_c": True}).build()

Expand All @@ -552,7 +649,7 @@ def test_pipe_output_end_to_end_simple():
assert result["downstream_f"] == result["chain_not_using_pipe_output"]


def test_pipe_output_end_to_end_1():
def test_pipe_output_end_to_end():
dr = (
driver.Builder()
.with_modules(tests.resources.pipe_output)
Expand Down

0 comments on commit 37dedd3

Please sign in to comment.