From 427733bb144b72b51b626c4cf920083e1fb95698 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Tue, 3 Sep 2024 18:53:18 -0700 Subject: [PATCH] [patch] Add a utility to the API for finding node classes (#439) * [patch] Add a utility to the API for finding node classes Per @JNmpi's request https://github.com/pyiron/pyiron-xyflow/commit/679faf35ed47b0efb6b194015b8342cfac9015dd#r146023624 * Remove finding from the API So we can iterate on it while still making weaker semantic version upgrades * Correctly verify local definition And refactor the boolean flag to be in line with the other two * Add tests * Semi-expose the method in the API (there but private) * Fix relative pathing * Format black --------- Co-authored-by: pyiron-runner --- pyiron_workflow/__init__.py | 3 ++ pyiron_workflow/find.py | 57 +++++++++++++++++++++++++++++++++++ tests/static/demo_nodes.py | 7 +++++ tests/unit/test_find.py | 59 +++++++++++++++++++++++++++++++++++++ 4 files changed, 126 insertions(+) create mode 100644 pyiron_workflow/find.py create mode 100644 tests/unit/test_find.py diff --git a/pyiron_workflow/__init__.py b/pyiron_workflow/__init__.py index 91cd87657..d7cfc9364 100644 --- a/pyiron_workflow/__init__.py +++ b/pyiron_workflow/__init__.py @@ -57,3 +57,6 @@ available_backends, TypeNotFoundError, ) +from pyiron_workflow.find import ( + find_nodes as _find_nodes, # Not formally in API -- don't rely on interface +) diff --git a/pyiron_workflow/find.py b/pyiron_workflow/find.py new file mode 100644 index 000000000..07b49e9dd --- /dev/null +++ b/pyiron_workflow/find.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import inspect +import importlib.util +from pathlib import Path +import sys +from types import ModuleType + +from pyiron_workflow.node import Node + + +def _get_subclasses( + source: str | Path | ModuleType, + base_class: type, + get_private: bool = False, + get_abstract: bool = False, + get_imports_too: bool = False, +): + if isinstance(source, (str, Path)): + source = 
Path(source) + if source.is_file(): + # Load the module from the file + module_name = source.stem + spec = importlib.util.spec_from_file_location(module_name, str(source)) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + else: + raise ValueError("File path does not point to a valid file") + elif inspect.ismodule(source): + module = source + else: + raise ValueError("Input must be a module or a valid file path") + + return [ + obj + for name, obj in inspect.getmembers(module, inspect.isclass) + if ( + issubclass(obj, base_class) + and (get_private or not name.startswith("_")) + and (get_abstract or not inspect.isabstract(obj)) + and (get_imports_too or _locally_defined(obj, module)) + ) + ] + + +def _locally_defined(obj, module): + obj_module_name = obj.__module__ + obj_module = importlib.import_module(obj_module_name) + return obj_module.__file__ == module.__file__ + + +def find_nodes(source: str | Path | ModuleType) -> list[type[Node]]: + """ + Get a list of all public, non-abstract nodes defined in the source. 
class TestFind(unittest.TestCase):
    def test_find_nodes(self):
        """
        Searching the demo nodes by string path, by `Path`, and by module object
        must all give the same public, locally defined nodes.

        We compare names instead of direct `is` comparisons with the imported objects
        because the class factories are being forced to create new classes on repeated
        import, i.e. we don't leverage classfactory's ability to make the dynamic
        classes be the same object.
        This is because users might _intentionally_ be re-calling the factories, e.g.
        with new output labels, and we then _want_ new classes to get generated.
        There is probably a workaround that lets us have our cake and eat it too (i.e.
        only generate new classes when they are strictly needed), but we don't have it
        now.
        """
        demo_nodes_file = str(
            Path(__file__).parent.joinpath("..", "static", "demo_nodes.py").resolve()
        )
        found_by_string = find_nodes(demo_nodes_file)
        found_by_path = find_nodes(Path(demo_nodes_file))

        ensure_tests_in_python_path()
        from static import demo_nodes

        found_by_module = find_nodes(demo_nodes)

        names_by_string = [o.__name__ for o in found_by_string]
        names_by_path = [o.__name__ for o in found_by_path]
        names_by_module = [o.__name__ for o in found_by_module]

        self.assertListEqual(
            names_by_path,
            names_by_string,
            msg=f"You should find the same thing regardless of source "
            f"representation; by path got {found_by_path} and by string got "
            f"{found_by_string}",
        )
        self.assertListEqual(
            names_by_string,
            names_by_module,
            msg=f"You should find the same thing regardless of source "
            f"representation; by string got {found_by_string} and by module got "
            f"{found_by_module}",
        )
        self.assertListEqual(
            names_by_string,
            [
                demo_nodes.AddPlusOne.__name__,
                demo_nodes.AddThree.__name__,
                demo_nodes.Dynamic.__name__,
                demo_nodes.OptionallyAdd.__name__,
            ],
            # Report the list that is actually being compared
            msg=f"Should match a hand-selected expectation list that ignores the "
            f"private and non-local nodes. If you update the demo nodes this may "
            f"fail and need to be trivially updated. Got {found_by_string}",
        )


if __name__ == "__main__":
    unittest.main()