diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 3dcb4f13..7f8482bc 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -5,7 +5,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import is_dataclass, MISSING +from dataclasses import dataclass as as_dataclass, is_dataclass, MISSING import itertools from typing import Any, ClassVar, Optional @@ -380,16 +380,13 @@ def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: def dataclass_node_factory( dataclass: type, use_cache: bool = True, / ) -> type[DataclassNode]: - if not is_dataclass(dataclass): - raise TypeError( - f"{DataclassNode} expected to get a dataclass but {dataclass} failed " - f"`dataclasses.is_dataclass`." - ) if type(dataclass) is not type: raise TypeError( f"{DataclassNode} expected to get a dataclass but {dataclass} is not " f"type `type`." ) + if not is_dataclass(dataclass): + dataclass = as_dataclass(dataclass) return ( f"{DataclassNode.__name__}{dataclass.__name__}", (DataclassNode,), @@ -415,7 +412,9 @@ def as_dataclass_node(dataclass: type, use_cache: bool = True): channel values at class defintion (instantiation). Args: - dataclass (type): A dataclass, i.e. class passing `dataclasses.is_dataclass`. + dataclass (type): A dataclass, i.e. class passing `dataclasses.is_dataclass`, + or class definition that will be automatically wrapped with + `dataclasses.dataclass`. use_cache (bool): Whether nodes of this type should default to caching their values. (Default is True.) @@ -432,7 +431,6 @@ def as_dataclass_node(dataclass: type, use_cache: bool = True): ... return [1, 2, 3] >>> >>> @Workflow.wrap.as_dataclass_node - ... @dataclass ... class Foo: ... necessary: str ... bar: str = "bar" @@ -471,7 +469,9 @@ def dataclass_node(dataclass: type, use_cache: bool = True, *node_args, **node_k channel values at class defintion (instantiation). Args: - dataclass (type): A dataclass, i.e. class passing `dataclasses.is_dataclass`. + dataclass (type): A dataclass, i.e. class passing `dataclasses.is_dataclass`, + or class variable that will be automatically passed to + `dataclasses.dataclass`. use_cache (bool): Whether this node should default to caching its values. (Default is True.) *node_args: Other :class:`Node` positional arguments. @@ -489,8 +489,8 @@ def dataclass_node(dataclass: type, use_cache: bool = True, *node_args, **node_k >>> def some_list(): ... return [1, 2, 3] >>> - >>> @dataclass - ... class Foo: + >>> #@dataclass # Works on actual dataclasses as well as dataclass-like classes + >>> class Foo: ... necessary: str ... bar: str = "bar" ... answer: int = 42 diff --git a/tests/unit/nodes/test_transform.py b/tests/unit/nodes/test_transform.py index 363e0abe..b2d9718b 100644 --- a/tests/unit/nodes/test_transform.py +++ b/tests/unit/nodes/test_transform.py @@ -136,8 +136,9 @@ def some_generator(): return [1, 2, 3] with self.subTest("From instantiator"): - @dataclass + class DC: + """Doesn't even have to be an actual dataclass, just dataclass-like""" necessary: str with_default: int = 42 with_factory: list = field(default_factory=some_generator) @@ -148,6 +149,19 @@ class DC: DC, msg="Underlying dataclass should be accessible" ) + self.assertTrue( + is_dataclass(n.dataclass), + msg="Underlying dataclass should be a real dataclass" + ) + self.assertTrue( + is_dataclass(DC), + msg="Note that passing the underlying dataclass variable through the " + "`dataclasses.dataclass` operator actually transforms it, so it " + "too is now a real dataclass, even though it wasn't defined as " + "one! This is just a side effect. I don't see it being harmful, " + "but in case it gives some future reader trouble, I want to " + "explicitly note the side effect here in the tests." + ) self.assertListEqual( list(DC.__dataclass_fields__.keys()), n.inputs.labels, @@ -196,35 +210,44 @@ class DecoratedDC: with_default: int = 42 with_factory: list = field(default_factory=some_generator) - n_cls = DecoratedDC(label="decorated_instance") + @as_dataclass_node + class DecoratedDCLike: + necessary: str + with_default: int = 42 + with_factory: list = field(default_factory=some_generator) - self.assertTrue( - is_dataclass(n_cls.dataclass), - msg="Underlying dataclass should be available on node class" - ) - prev = n_cls.preview_inputs() - key = random.choice(list(prev.keys())) - self.assertIs( - n_cls._dataclass_fields[key].type, - prev[key][0], - msg="Spot-check input type hints are pulled from dataclass fields" - ) - self.assertIs( - prev["necessary"][1], - NOT_DATA, - msg="Field has no default" - ) - self.assertEqual( - n_cls._dataclass_fields["with_default"].default, - prev["with_default"][1], - msg="Fields with default should get scraped" - ) - self.assertIs( - prev["with_factory"][1], - NOT_DATA, - msg="Fields with default factory won't see their default until " - "instantiation" - ) + for n_cls, style in zip( + [DecoratedDC(label="dcinst"), DecoratedDCLike(label="dcinst")], + ["Actual dataclass", "Dataclass-like class"] + ): + with self.subTest(style): + self.assertTrue( + is_dataclass(n_cls.dataclass), + msg="Underlying dataclass should be available on node class" + ) + prev = n_cls.preview_inputs() + key = random.choice(list(prev.keys())) + self.assertIs( + n_cls._dataclass_fields[key].type, + prev[key][0], + msg="Spot-check input type hints are pulled from dataclass fields" + ) + self.assertIs( + prev["necessary"][1], + NOT_DATA, + msg="Field has no default" + ) + self.assertEqual( + n_cls._dataclass_fields["with_default"].default, + prev["with_default"][1], + msg="Fields with default should get scraped" + ) + self.assertIs( + prev["with_factory"][1], + NOT_DATA, + msg="Fields with default factory won't see their default until " + "instantiation" + ) if __name__ == '__main__':