Skip to content

Commit

Permalink
[patch] Add a utility to the API for finding node classes (#439)
Browse files Browse the repository at this point in the history
* [patch] Add a utility to the API for finding node classes

Per @JNmpi's request pyiron/pyironFlow@679faf3#r146023624

* Remove finding from the API

So we can iterate on it while still making weaker semantic version upgrades

* Correctly verify local definition

And refactor the boolean flag to be in line with the other two

* Add tests

* Semi-expose the method in the API (there but private)

* Fix relative pathing

* Format black

---------

Co-authored-by: pyiron-runner <[email protected]>
  • Loading branch information
liamhuber and pyiron-runner authored Sep 4, 2024
1 parent 903b8ea commit 427733b
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pyiron_workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,6 @@
available_backends,
TypeNotFoundError,
)
from pyiron_workflow.find import (
find_nodes as _find_nodes, # Not formally in API -- don't rely on interface
)
57 changes: 57 additions & 0 deletions pyiron_workflow/find.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import inspect
import importlib.util
from pathlib import Path
import sys
from types import ModuleType

from pyiron_workflow.node import Node


def _get_subclasses(
source: str | Path | ModuleType,
base_class: type,
get_private: bool = False,
get_abstract: bool = False,
get_imports_too: bool = False,
):
if isinstance(source, (str, Path)):
source = Path(source)
if source.is_file():
# Load the module from the file
module_name = source.stem
spec = importlib.util.spec_from_file_location(module_name, str(source))
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
else:
raise ValueError("File path does not point to a valid file")
elif inspect.ismodule(source):
module = source
else:
raise ValueError("Input must be a module or a valid file path")

return [
obj
for name, obj in inspect.getmembers(module, inspect.isclass)
if (
issubclass(obj, base_class)
and (get_private or not name.startswith("_"))
and (get_abstract or not inspect.isabstract(obj))
and (get_imports_too or _locally_defined(obj, module))
)
]


def _locally_defined(obj, module):
obj_module_name = obj.__module__
obj_module = importlib.import_module(obj_module_name)
return obj_module.__file__ == module.__file__


def find_nodes(source: str | Path | ModuleType) -> list[type[Node]]:
"""
Get a list of all public, non-abstract nodes defined in the source.
"""
return _get_subclasses(source, Node)
7 changes: 7 additions & 0 deletions tests/static/demo_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional

from pyiron_workflow import Workflow
from pyiron_workflow.nodes.standard import Add as NotDefinedLocally


@Workflow.wrap.as_function_node("sum")
Expand Down Expand Up @@ -32,3 +33,9 @@ def dynamic(x):


Dynamic = Workflow.wrap.as_function_node(dynamic)


@Workflow.wrap.as_function_node("y")
def _APrivateNode(x):
"""A node, but named to indicate it is private"""
return x + 1
59 changes: 59 additions & 0 deletions tests/unit/test_find.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from pathlib import Path
import unittest

from pyiron_workflow._tests import ensure_tests_in_python_path
from pyiron_workflow.find import find_nodes


class TestFind(unittest.TestCase):
def test_find_nodes(self):
"""
We compare names instead of direct `is` comparisons with the imported objects
because the class factories are being forced to create new classes on repeated
import, i.e. we don't leverage classfactory's ability to make the dynamic
classes be the same object.
This is because users might _intentionally_ be re-calling the factories, e.g.
with new output labels, and we then _want_ new classes to get generated.
There is probably a workaround that lets us have our cake and eat it to (i.e.
only generate new classes when they are strictly needed), but we don't have it
now.
"""
demo_nodes_file = str(
Path(__file__).parent.joinpath("..", "static", "demo_nodes.py").resolve()
)
found_by_string = find_nodes(demo_nodes_file)
path = Path(demo_nodes_file)
found_by_path = find_nodes(path)

ensure_tests_in_python_path()
from static import demo_nodes
found_by_module = find_nodes(demo_nodes)

self.assertListEqual(
[o.__name__ for o in found_by_path],
[o.__name__ for o in found_by_string],
msg=f"You should find the same thing regardless of source representation;"
f"by path got {found_by_path} and by string got {found_by_string}"
)
self.assertListEqual(
[o.__name__ for o in found_by_string],
[o.__name__ for o in found_by_module],
msg=f"You should find the same thing regardless of source representation;"
f"by string got {found_by_string} and by module got {found_by_module}"
)
self.assertListEqual(
[o.__name__ for o in found_by_string],
[
demo_nodes.AddPlusOne.__name__,
demo_nodes.AddThree.__name__,
demo_nodes.Dynamic.__name__,
demo_nodes.OptionallyAdd.__name__
],
msg=f"Should match a hand-selected expectation list that ignores the "
f"private and non-local nodes. If you update the demo nodes this may "
f"fail and need to be trivially updated. Got {found_by_module}"
)


if __name__ == '__main__':
unittest.main()

0 comments on commit 427733b

Please sign in to comment.