generated from pyiron/pyiron_module_template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[patch] Add a utility to the API for finding node classes (#439)
* [patch] Add a utility to the API for finding node classes Per @JNmpi's request pyiron/pyironFlow@679faf3#r146023624 * Remove finding from the API So we can iterate on it while still making weaker semantic version upgrades * Correctly verify local definition And refactor the boolean flag to be in line with the other two * Add tests * Semi-expose the method in the API (there but private) * Fix relative pathing * Format black --------- Co-authored-by: pyiron-runner <[email protected]>
- Loading branch information
1 parent
903b8ea
commit 427733b
Showing
4 changed files
with
126 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from __future__ import annotations | ||
|
||
import inspect | ||
import importlib.util | ||
from pathlib import Path | ||
import sys | ||
from types import ModuleType | ||
|
||
from pyiron_workflow.node import Node | ||
|
||
|
||
def _get_subclasses( | ||
source: str | Path | ModuleType, | ||
base_class: type, | ||
get_private: bool = False, | ||
get_abstract: bool = False, | ||
get_imports_too: bool = False, | ||
): | ||
if isinstance(source, (str, Path)): | ||
source = Path(source) | ||
if source.is_file(): | ||
# Load the module from the file | ||
module_name = source.stem | ||
spec = importlib.util.spec_from_file_location(module_name, str(source)) | ||
module = importlib.util.module_from_spec(spec) | ||
sys.modules[module_name] = module | ||
spec.loader.exec_module(module) | ||
else: | ||
raise ValueError("File path does not point to a valid file") | ||
elif inspect.ismodule(source): | ||
module = source | ||
else: | ||
raise ValueError("Input must be a module or a valid file path") | ||
|
||
return [ | ||
obj | ||
for name, obj in inspect.getmembers(module, inspect.isclass) | ||
if ( | ||
issubclass(obj, base_class) | ||
and (get_private or not name.startswith("_")) | ||
and (get_abstract or not inspect.isabstract(obj)) | ||
and (get_imports_too or _locally_defined(obj, module)) | ||
) | ||
] | ||
|
||
|
||
def _locally_defined(obj, module): | ||
obj_module_name = obj.__module__ | ||
obj_module = importlib.import_module(obj_module_name) | ||
return obj_module.__file__ == module.__file__ | ||
|
||
|
||
def find_nodes(source: str | Path | ModuleType) -> list[type[Node]]: | ||
""" | ||
Get a list of all public, non-abstract nodes defined in the source. | ||
""" | ||
return _get_subclasses(source, Node) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from pathlib import Path | ||
import unittest | ||
|
||
from pyiron_workflow._tests import ensure_tests_in_python_path | ||
from pyiron_workflow.find import find_nodes | ||
|
||
|
||
class TestFind(unittest.TestCase): | ||
def test_find_nodes(self): | ||
""" | ||
We compare names instead of direct `is` comparisons with the imported objects | ||
because the class factories are being forced to create new classes on repeated | ||
import, i.e. we don't leverage classfactory's ability to make the dynamic | ||
classes be the same object. | ||
This is because users might _intentionally_ be re-calling the factories, e.g. | ||
with new output labels, and we then _want_ new classes to get generated. | ||
There is probably a workaround that lets us have our cake and eat it to (i.e. | ||
only generate new classes when they are strictly needed), but we don't have it | ||
now. | ||
""" | ||
demo_nodes_file = str( | ||
Path(__file__).parent.joinpath("..", "static", "demo_nodes.py").resolve() | ||
) | ||
found_by_string = find_nodes(demo_nodes_file) | ||
path = Path(demo_nodes_file) | ||
found_by_path = find_nodes(path) | ||
|
||
ensure_tests_in_python_path() | ||
from static import demo_nodes | ||
found_by_module = find_nodes(demo_nodes) | ||
|
||
self.assertListEqual( | ||
[o.__name__ for o in found_by_path], | ||
[o.__name__ for o in found_by_string], | ||
msg=f"You should find the same thing regardless of source representation;" | ||
f"by path got {found_by_path} and by string got {found_by_string}" | ||
) | ||
self.assertListEqual( | ||
[o.__name__ for o in found_by_string], | ||
[o.__name__ for o in found_by_module], | ||
msg=f"You should find the same thing regardless of source representation;" | ||
f"by string got {found_by_string} and by module got {found_by_module}" | ||
) | ||
self.assertListEqual( | ||
[o.__name__ for o in found_by_string], | ||
[ | ||
demo_nodes.AddPlusOne.__name__, | ||
demo_nodes.AddThree.__name__, | ||
demo_nodes.Dynamic.__name__, | ||
demo_nodes.OptionallyAdd.__name__ | ||
], | ||
msg=f"Should match a hand-selected expectation list that ignores the " | ||
f"private and non-local nodes. If you update the demo nodes this may " | ||
f"fail and need to be trivially updated. Got {found_by_module}" | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |