Skip to content

Commit

Permalink
[patch] Remove the need to double-decorate (#398)
Browse files Browse the repository at this point in the history
If you're making something a dataclass node, you can now pass in a dataclass-compliant class that is not yet a dataclass and we'll take care of turning it into one for you. That means you only need `@Workflow.wrap.as_dataclass_node` and not `@Workflow.wrap.as_dataclass_node;@dataclass`. Per Joerg's wishlist.
  • Loading branch information
liamhuber authored Jul 31, 2024
1 parent 2f7a7d0 commit 4f1fd93
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 40 deletions.
22 changes: 11 additions & 11 deletions pyiron_workflow/nodes/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import is_dataclass, MISSING
from dataclasses import dataclass as as_dataclass, is_dataclass, MISSING
import itertools
from typing import Any, ClassVar, Optional

Expand Down Expand Up @@ -380,16 +380,13 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]:
def dataclass_node_factory(
dataclass: type, use_cache: bool = True, /
) -> type[DataclassNode]:
if not is_dataclass(dataclass):
raise TypeError(
f"{DataclassNode} expected to get a dataclass but {dataclass} failed "
f"`dataclasses.is_dataclass`."
)
if type(dataclass) is not type:
raise TypeError(
f"{DataclassNode} expected to get a dataclass but {dataclass} is not "
f"type `type`."
)
if not is_dataclass(dataclass):
dataclass = as_dataclass(dataclass)
return (
f"{DataclassNode.__name__}{dataclass.__name__}",
(DataclassNode,),
Expand All @@ -415,7 +412,9 @@ def as_dataclass_node(dataclass: type, use_cache: bool = True):
channel values at class defintion (instantiation).
Args:
dataclass (type): A dataclass, i.e. class passing `dataclasses.is_dataclass`.
dataclass (type): A dataclass, i.e. class passing `dataclasses.is_dataclass`,
or class definition that will be automatically wrapped with
`dataclasses.dataclass`.
use_cache (bool): Whether nodes of this type should default to caching their
values. (Default is True.)
Expand All @@ -432,7 +431,6 @@ def as_dataclass_node(dataclass: type, use_cache: bool = True):
... return [1, 2, 3]
>>>
>>> @Workflow.wrap.as_dataclass_node
... @dataclass
... class Foo:
... necessary: str
... bar: str = "bar"
Expand Down Expand Up @@ -471,7 +469,9 @@ def dataclass_node(dataclass: type, use_cache: bool = True, *node_args, **node_k
channel values at class defintion (instantiation).
Args:
dataclass (type): A dataclass, i.e. class passing `dataclasses.is_dataclass`.
dataclass (type): A dataclass, i.e. class passing `dataclasses.is_dataclass`,
or class variable that will be automatically passed to
`dataclasses.dataclass`.
use_cache (bool): Whether this node should default to caching its values.
(Default is True.)
*node_args: Other :class:`Node` positional arguments.
Expand All @@ -489,8 +489,8 @@ def dataclass_node(dataclass: type, use_cache: bool = True, *node_args, **node_k
>>> def some_list():
... return [1, 2, 3]
>>>
>>> @dataclass
... class Foo:
>>> #@dataclass # Works on actual dataclasses as well as dataclass-like classes
>>> class Foo:
... necessary: str
... bar: str = "bar"
... answer: int = 42
Expand Down
81 changes: 52 additions & 29 deletions tests/unit/nodes/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,9 @@ def some_generator():
return [1, 2, 3]

with self.subTest("From instantiator"):
@dataclass

class DC:
"""Doesn't even have to be an actual dataclass, just dataclass-like"""
necessary: str
with_default: int = 42
with_factory: list = field(default_factory=some_generator)
Expand All @@ -148,6 +149,19 @@ class DC:
DC,
msg="Underlying dataclass should be accessible"
)
self.assertTrue(
is_dataclass(n.dataclass),
msg="Underlying dataclass should be a real dataclass"
)
self.assertTrue(
is_dataclass(DC),
msg="Note that passing the underlying dataclass variable through the "
"`dataclasses.dataclass` operator actually transforms it, so it "
"too is now a real dataclass, even though it wasn't defined as "
"one! This is just a side effect. I don't see it being harmful, "
"but in case it gives some future reader trouble, I want to "
"explicitly note the side effect here in the tests."
)
self.assertListEqual(
list(DC.__dataclass_fields__.keys()),
n.inputs.labels,
Expand Down Expand Up @@ -196,35 +210,44 @@ class DecoratedDC:
with_default: int = 42
with_factory: list = field(default_factory=some_generator)

n_cls = DecoratedDC(label="decorated_instance")
@as_dataclass_node
class DecoratedDCLike:
necessary: str
with_default: int = 42
with_factory: list = field(default_factory=some_generator)

self.assertTrue(
is_dataclass(n_cls.dataclass),
msg="Underlying dataclass should be available on node class"
)
prev = n_cls.preview_inputs()
key = random.choice(list(prev.keys()))
self.assertIs(
n_cls._dataclass_fields[key].type,
prev[key][0],
msg="Spot-check input type hints are pulled from dataclass fields"
)
self.assertIs(
prev["necessary"][1],
NOT_DATA,
msg="Field has no default"
)
self.assertEqual(
n_cls._dataclass_fields["with_default"].default,
prev["with_default"][1],
msg="Fields with default should get scraped"
)
self.assertIs(
prev["with_factory"][1],
NOT_DATA,
msg="Fields with default factory won't see their default until "
"instantiation"
)
for n_cls, style in zip(
[DecoratedDC(label="dcinst"), DecoratedDCLike(label="dcinst")],
["Actual dataclass", "Dataclass-like class"]
):
with self.subTest(style):
self.assertTrue(
is_dataclass(n_cls.dataclass),
msg="Underlying dataclass should be available on node class"
)
prev = n_cls.preview_inputs()
key = random.choice(list(prev.keys()))
self.assertIs(
n_cls._dataclass_fields[key].type,
prev[key][0],
msg="Spot-check input type hints are pulled from dataclass fields"
)
self.assertIs(
prev["necessary"][1],
NOT_DATA,
msg="Field has no default"
)
self.assertEqual(
n_cls._dataclass_fields["with_default"].default,
prev["with_default"][1],
msg="Fields with default should get scraped"
)
self.assertIs(
prev["with_factory"][1],
NOT_DATA,
msg="Fields with default factory won't see their default until "
"instantiation"
)


if __name__ == '__main__':
Expand Down

0 comments on commit 4f1fd93

Please sign in to comment.