diff --git a/src/faebryk/core/graphinterface.py b/src/faebryk/core/graphinterface.py index 9e836af8..a64b09cf 100644 --- a/src/faebryk/core/graphinterface.py +++ b/src/faebryk/core/graphinterface.py @@ -7,10 +7,18 @@ from faebryk.core.core import ID_REPR, FaebrykLibObject from faebryk.core.graph_backends.default import GraphImpl -from faebryk.core.link import Link, LinkDirect, LinkNamedParent +from faebryk.core.link import ( + Link, + LinkDirect, + LinkNamedParent, + LinkPointer, + LinkSibling, +) from faebryk.libs.util import ( + KeyErrorNotFound, NotNone, exceptions_to_log, + find, try_avoid_endless_recursion, ) @@ -106,7 +114,7 @@ def is_connected(self, other: "GraphInterface"): # Less graph-specific stuff # TODO make link trait to initialize from list - def connect(self, other: Self, linkcls=None) -> Self: + def connect(self, other: "GraphInterface", linkcls=None) -> Self: assert other is not self if linkcls is None: @@ -191,3 +199,22 @@ def disconnect_parent(self): class GraphInterfaceSelf(GraphInterface): ... + + +class GraphInterfaceReference[T: "Node"](GraphInterface): + """Represents a reference to a node object""" + + class UnboundError(Exception): + """Cannot resolve unbound reference""" + + def get_referenced_gif(self) -> GraphInterfaceSelf: + try: + return find( + self.get_links_by_type(LinkPointer), + lambda link: not isinstance(link, LinkSibling), + ).pointee + except KeyErrorNotFound as ex: + raise GraphInterfaceReference.UnboundError from ex + + def get_reference(self) -> T: + return self.get_referenced_gif().node diff --git a/src/faebryk/core/link.py b/src/faebryk/core/link.py index deee44a1..7ba9dfd4 100644 --- a/src/faebryk/core/link.py +++ b/src/faebryk/core/link.py @@ -5,11 +5,16 @@ from typing import TYPE_CHECKING, Callable from faebryk.core.core import LINK_TB, FaebrykLibObject +from faebryk.libs.util import is_type_pair logger = logging.getLogger(__name__) if TYPE_CHECKING: - from faebryk.core.graphinterface import GraphInterface, GraphInterfaceHierarchical + from faebryk.core.graphinterface import ( + GraphInterface, + GraphInterfaceHierarchical, + GraphInterfaceSelf, + ) class Link(FaebrykLibObject): @@ -38,13 +43,31 @@ def __repr__(self) -> str: return f"{type(self).__name__}()" -class LinkSibling(Link): - def __init__(self, interfaces: list["GraphInterface"]) -> None: +class LinkPointer(Link): + """A Link that points towards a self-gif""" + + def __init__( + self, + interfaces: list["GraphInterfaceSelf | GraphInterface"], + ) -> None: + from faebryk.core.graphinterface import GraphInterface, GraphInterfaceSelf + super().__init__() - self.interfaces = interfaces + assert len(interfaces) == 2 + + pair = is_type_pair( + interfaces[0], interfaces[1], GraphInterfaceSelf, GraphInterface + ) + if not pair: + raise TypeError("Interfaces must be one self-gif and one other-gif") + self.pointee, self.pointer = pair def get_connections(self) -> list["GraphInterface"]: - return self.interfaces + return [self.pointee, self.pointer] + + +class LinkSibling(LinkPointer): + """A link represents a connection between a self-gif and a gif in the same node""" class LinkParent(Link): @@ -83,9 +106,13 @@ def curried(interfaces: list["GraphInterface"]): class LinkDirect(Link): + """Represents a symmetrical link between two interfaces of the same type""" + def __init__(self, interfaces: list["GraphInterface"]) -> None: super().__init__() - assert len(set(map(type, interfaces))) == 1 + assert ( + len(set(map(type, interfaces))) == 1 + ), "Interfaces must be of the same type" self.interfaces = interfaces def get_connections(self) -> list["GraphInterface"]: diff --git a/src/faebryk/core/node.py b/src/faebryk/core/node.py index fd186ddc..bd24c26d 100644 --- a/src/faebryk/core/node.py +++ b/src/faebryk/core/node.py @@ -1,6 +1,8 @@ # This file is part of the faebryk project # SPDX-License-Identifier: MIT import logging +from abc import abstractmethod +from dataclasses import InitVar as dataclass_InitVar from itertools import chain from typing import ( TYPE_CHECKING, @@ -30,6 +32,7 @@ PostInitCaller, Tree, cast_assert, + debugging, find, times, try_avoid_endless_recursion, @@ -65,13 +68,34 @@ class fab_field: pass -class rt_field[T, O](property, fab_field): +class constructed_field[T: "Node", O: "Node"](property, fab_field): + """ + Field which is constructed after the node is created. + The constructor gets one argument: the node instance. + + The constructor should return the constructed faebryk object or None. + If a faebryk object is returned, it will be added to the node. + """ + + @abstractmethod + def __construct__(self, obj: T) -> O | None: + pass + + +class rt_field[T, O](constructed_field): + """ + rt_fields (runtime_fields) are the last fields excecuted before the + __preinit__ and __postinit__ functions are called. + It gives the function passed to it access to the node instance. + This is useful to do construction that depends on parameters passed by __init__. + """ + def __init__(self, fget: Callable[[T], O]) -> None: super().__init__() self.func = fget self.lookup: dict[T, O] = {} - def _construct(self, obj: T): + def __construct__(self, obj: T): constructed = self.func(obj) # TODO find a better way for this # in python 3.13 name support @@ -145,6 +169,14 @@ def __init__(self, node: "Node", other: "Node", *args: object) -> None: class NodeNoParent(NodeException): ... +class InitVar(dataclass_InitVar): + """ + This is a type-marker which instructs the Node constructor to ignore the field. + + Inspired by dataclasses.InitVar, which it inherits from. + """ + + # ----------------------------------------------------------------------------- @@ -167,7 +199,7 @@ def __hash__(self) -> int: # TODO proper hash return hash(id(self)) - def add[T: Node]( + def add[T: Node | GraphInterface]( self, obj: T, name: str | None = None, @@ -176,9 +208,10 @@ def add[T: Node]( assert obj is not None if container is None: - container = self.runtime_anon if name: container = self.runtime + else: + container = self.runtime_anon try: container_name = find(vars(self).items(), lambda x: x[1] is container)[0] @@ -197,7 +230,11 @@ def add[T: Node]( container.append(obj) name = f"{container_name}[{len(container) - 1}]" - self._handle_add_node(name, obj) + if isinstance(obj, GraphInterface): + self._handle_add_gif(name, obj) + else: + self._handle_add_node(name, obj) + return obj def add_to_container[T: Node]( @@ -230,8 +267,64 @@ def all_anno(cls): LL_Types = (Node, GraphInterface) - annos = all_anno(cls) - vars_ = all_vars(cls) + vars_ = { + name: obj + for name, obj in all_vars(cls).items() + # private fields are always ignored + if not name.startswith("_") + # only consider fab_fields + and isinstance(obj, fab_field) + } + annos = { + name: obj + for name, obj in all_anno(cls).items() + # private fields are always ignored + if not name.startswith("_") + # explicitly ignore InitVars + and not isinstance(obj, InitVar) + # variables take precedence over annos + and name not in vars_ + } + + # ensure no field annotations are a property + # If properties are constructed as instance fields, their + # getters and setters aren't called when assigning to them. + # + # This means we won't actually construct the underlying graph properly. + # It's pretty insidious because it's completely non-obvious that we're + # missing these graph connections. + # TODO: make this an exception group instead + for name, obj in annos.items(): + if (origin := get_origin(obj)) is not None: + # you can't truly subclass properties because they're a descriptor + # type, so instead we check if the origin is a property via our fields + if issubclass(origin, constructed_field): + raise FieldError( + f"{name} is a property, which cannot be created from a field " + "annotation. Please instantiate the field directly." + ) + + # FIXME: something's fucked up in the new version of this, + # but I can't for the life of me figure out what + clsfields_unf_new = dict(chain(annos.items(), vars_.items())) + clsfields_unf_old = { + name: obj + for name, obj in chain( + ( + (name, obj) + for name, obj in all_anno(cls).items() + if not isinstance(obj, InitVar) + ), + ( + (name, f) + for name, f in all_vars(cls).items() + if isinstance(f, fab_field) + ), + ) + if not name.startswith("_") + } + assert clsfields_unf_old == clsfields_unf_new + clsfields_unf = clsfields_unf_old def is_node_field(obj): def is_genalias_node(obj): @@ -260,20 +353,11 @@ def is_genalias_node(obj): if get_origin(obj): return is_genalias_node(obj) - if isinstance(obj, rt_field): + if isinstance(obj, constructed_field): return True return False - clsfields_unf = { - name: obj - for name, obj in chain( - [(name, f) for name, f in annos.items()], - [(name, f) for name, f in vars_.items() if isinstance(f, fab_field)], - ) - if not name.startswith("_") - } - nonfabfields, fabfields = partition( lambda x: is_node_field(x[1]), clsfields_unf.items() ) @@ -305,7 +389,9 @@ def handle_add(name, obj): elif isinstance(obj, Node): self._handle_add_node(name, obj) else: - assert False + raise TypeError( + f"Cannot handle adding field {name=} of type {type(obj)}" + ) def append(name, inst): if isinstance(inst, LL_Types): @@ -321,21 +407,15 @@ def append(name, inst): return inst def _setup_field(name, obj): - def setup_gen_alias(name, obj): - origin = get_origin(obj) - assert origin + if isinstance(obj, str): + raise NotImplementedError() + + if (origin := get_origin(obj)) is not None: if isinstance(origin, type): setattr(self, name, append(name, origin())) return raise NotImplementedError(origin) - if isinstance(obj, str): - raise NotImplementedError() - - if get_origin(obj): - setup_gen_alias(name, obj) - return - if isinstance(obj, _d_field): setattr(self, name, append(name, obj.default_factory())) return @@ -344,8 +424,9 @@ def setup_gen_alias(name, obj): setattr(self, name, append(name, obj())) return - if isinstance(obj, rt_field): - append(name, obj._construct(self)) + if isinstance(obj, constructed_field): + if (constructed := obj.__construct__(self)) is not None: + append(name, constructed) return raise NotImplementedError() @@ -354,13 +435,19 @@ def setup_field(name, obj): try: _setup_field(name, obj) except Exception as e: + # this is a bit of a hack to provide complete context to debuggers + # for underlying field construction errors + if debugging(): + raise raise FieldConstructionError( self, name, f'An exception occurred while constructing field "{name}"', ) from e - nonrt, rt = partition(lambda x: isinstance(x[1], rt_field), clsfields.items()) + nonrt, rt = partition( + lambda x: isinstance(x[1], constructed_field), clsfields.items() + ) for name, obj in nonrt: setup_field(name, obj) diff --git a/src/faebryk/core/reference.py b/src/faebryk/core/reference.py new file mode 100644 index 00000000..b05d7bd2 --- /dev/null +++ b/src/faebryk/core/reference.py @@ -0,0 +1,59 @@ +from collections import defaultdict + +from faebryk.core.graphinterface import GraphInterfaceReference +from faebryk.core.link import LinkPointer +from faebryk.core.node import Node, constructed_field + + +class Reference[O: Node](constructed_field): + """ + Create a simple reference to other nodes that are properly encoded in the graph. + """ + + class UnboundError(Exception): + """Cannot resolve unbound reference""" + + def __init__(self, out_type: type[O] | None = None): + self.gifs: dict[Node, GraphInterfaceReference] = defaultdict( + GraphInterfaceReference + ) + self.is_set: set[Node] = set() + + def get(instance: Node) -> O: + try: + return self.gifs[instance].get_reference() + except GraphInterfaceReference.UnboundError as ex: + raise Reference.UnboundError from ex + + def set_(instance: Node, value: O): + if instance in self.is_set: + # TypeError is also raised when attempting to assign + # to an immutable (eg. tuple) + raise TypeError( + f"{self.__class__.__name__} already set and are immutable" + ) + self.is_set.add(instance) + + if out_type is not None and not isinstance(value, out_type): + raise TypeError(f"Expected {out_type} got {type(value)}") + + # attach our gif to what we're referring to + self.gifs[instance].connect(value.self_gif, LinkPointer) + + property.__init__(self, get, set_) + + def __construct__(self, obj: Node) -> None: + # add our gif to our instance object + obj.add(self.gifs[obj]) + + # don't attach anything additional to the Node during field setup + return None + + +def reference[O: Node](out_type: type[O] | None = None) -> O | Reference: + """ + Create a simple reference to other nodes properly encoded in the graph. + + This final wrapper is primarily to fudge the typing. + """ + return Reference(out_type=out_type) diff --git a/src/faebryk/library/_F.py b/src/faebryk/library/_F.py index e5b77714..843264bf 100644 --- a/src/faebryk/library/_F.py +++ b/src/faebryk/library/_F.py @@ -42,6 +42,7 @@ from faebryk.library.has_picker import has_picker from faebryk.library.has_pcb_layout import has_pcb_layout from faebryk.library.has_pcb_routing_strategy import has_pcb_routing_strategy +from faebryk.library.has_reference import has_reference from faebryk.library.has_resistance import has_resistance from faebryk.library.has_single_connection import has_single_connection from faebryk.library.is_representable_by_single_value import is_representable_by_single_value diff --git a/src/faebryk/library/has_reference.py b/src/faebryk/library/has_reference.py new file mode 100644 index 00000000..24c8fc11 --- /dev/null +++ b/src/faebryk/library/has_reference.py @@ -0,0 +1,13 @@ +from faebryk.core.node import Node +from faebryk.core.reference import Reference +from faebryk.core.trait import Trait + + +class has_reference[T: Node](Trait): + """Trait-attached reference""" + + reference: T = Reference() + + def __init__(self, reference: T): + super().__init__() + self.reference = reference diff --git a/src/faebryk/libs/library/L.py b/src/faebryk/libs/library/L.py index 3a51110b..a795c50c 100644 --- a/src/faebryk/libs/library/L.py +++ b/src/faebryk/libs/library/L.py @@ -5,6 +5,7 @@ from faebryk.core.module import Module # noqa: F401 from faebryk.core.node import ( # noqa: F401 + InitVar, Node, d_field, f_field, diff --git a/src/faebryk/libs/util.py b/src/faebryk/libs/util.py index 425f2b3c..fe2c491c 100644 --- a/src/faebryk/libs/util.py +++ b/src/faebryk/libs/util.py @@ -940,3 +940,14 @@ def exceptions_to_log( ) if not mute: raise + + +def debugging() -> bool: + """ + Check if a debugger is connected. + """ + try: + import debugpy + except (ImportError, ModuleNotFoundError): + return False + return debugpy.is_client_connected() diff --git a/test/library/test_basic.py b/test/library/test_basic.py index aa08e031..c31eb986 100644 --- a/test/library/test_basic.py +++ b/test/library/test_basic.py @@ -1,58 +1,67 @@ # This file is part of the faebryk project # SPDX-License-Identifier: MIT -import unittest +import inspect + +import pytest from faebryk.core.core import Namespace from faebryk.core.node import Node +from faebryk.core.trait import Trait, TraitImpl from faebryk.libs.library import L +try: + import faebryk.library._F as F +except ImportError: + F = None + + +def test_load_library(): + assert F is not None, "Failed to load library" + + +@pytest.mark.skipif(F is None, reason="Library not loaded") +@pytest.mark.parametrize("name, module", list(vars(F).items())) +def test_symbol_types(name: str, module): + # private symbols get a pass + if name.startswith("_"): + return + + # skip once wrappers + # allow once wrappers for type generators + if getattr(module, "_is_once_wrapper", False): + return + + # otherwise, only allow Node or Namespace class objects + assert isinstance(module, type) and issubclass(module, (Node, Namespace)) + + +@pytest.mark.skipif(F is None, reason="Library not loaded") +@pytest.mark.parametrize( + "name, module", + [ + (name, module) + for name, module in vars(F).items() + if not ( + name.startswith("_") + or not isinstance(module, type) + or not issubclass(module, Node) + or (issubclass(module, Trait) and not issubclass(module, TraitImpl)) + ) + ], +) +def test_init_args(name: str, module): + """Make sure we can instantiate all classes without error""" + + # check if constructor has no args & no varargs + init_signature = inspect.signature(module.__init__) + if len(init_signature.parameters) > 1 or any( + param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + for param in init_signature.parameters.values() + ): + pytest.skip("Skipped module with init args because we can't instantiate it") -class TestBasicLibrary(unittest.TestCase): - def test_load_library(self): - import faebryk.library._F # noqa: F401 - - def test_symbol_types(self): - import faebryk.library._F as F - - symbols = { - k: v - for k, v in vars(F).items() - if not k.startswith("_") - and (not isinstance(v, type) or not issubclass(v, (Node, Namespace))) - # allow once wrappers for type generators - and not getattr(v, "_is_once_wrapper", False) - } - self.assertFalse(symbols, f"Found unexpected symbols: {symbols}") - - def test_imports(self): - import faebryk.library._F as F - from faebryk.core.trait import Trait, TraitImpl - - # get all symbols in F - symbols = { - k: v - for k, v in vars(F).items() - if not k.startswith("_") - and isinstance(v, type) - and issubclass(v, Node) - # check if constructor has no args & no varargs - and ( - v.__init__.__code__.co_argcount == 1 - and not v.__init__.__code__.co_flags & 0x04 - ) - # no trait base - and (not issubclass(v, Trait) or issubclass(v, TraitImpl)) - } - - for k, v in symbols.items(): - try: - v() - except L.AbstractclassError: - pass - except Exception as e: - self.fail(f"Failed to instantiate {k}: {e}") - - -if __name__ == "__main__": - unittest.main() + try: + module() + except L.AbstractclassError: + pytest.skip("Skipped abstract class") diff --git a/test/library/test_reference.py b/test/library/test_reference.py new file mode 100644 index 00000000..efb76f2b --- /dev/null +++ b/test/library/test_reference.py @@ -0,0 +1,137 @@ +import pytest + +from faebryk.core.node import FieldError, Node +from faebryk.core.reference import Reference + + +def test_points_to_correct_node(): + class A(Node): + pass + + class B(Node): + x = Reference(A) + + a = A() + b = B() + b.x = a + assert b.x is a + + +def test_immutable(): + class A(Node): + pass + + class B(Node): + x = Reference(A) + + b = B() + a = A() + b.x = a + + with pytest.raises(TypeError): + b.x = A() + + +def test_unset(): + class A(Node): + pass + + class B(Node): + x = Reference(A) + + b = B() + with pytest.raises(Reference.UnboundError): + b.x + + +def test_wrong_type(): + class A(Node): + pass + + class B(Node): + x = Reference(A) + + b = B() + with pytest.raises(TypeError): + b.x = 1 + + +def test_set_value_before_constuction(): + class A(Node): + pass + + class B(Node): + x = Reference(A) + + def __init__(self, x): + self.x = x + + a = A() + b = B(a) + assert b.x is a + + +def test_get_value_before_constuction(): + class A(Node): + pass + + class B(Node): + x = Reference(A) + y = Reference(A) + + def __init__(self, x): + self.x = x + self.y = self.x + + a = A() + b = B(a) + assert b.y is a + + +def test_typed_construction_doesnt_work(): + class B(Node): + x: Reference + + # check using the property directly that everything is working + with pytest.raises(AttributeError): + B.x + + +def test_typed_construction_protection(): + """ + Ensure references aren't constructed as a field + + If properties are constructed as instance fields, their + getters and setters aren't called when assigning to them. + + This means we won't actually construct the underlying graph properly. + It's pretty insidious because it's completely non-obvious that we're + missing these graph connections. + """ + + class A(Node): + pass + + class B(Node): + x: Reference[A] + + with pytest.raises(FieldError): + B() + + +def test_underlying_property_explicitly(): + class A(Node): + pass + + class B(Node): + x = Reference(A) + + a = A() + b = B() + b.x = a + + # check using the property directly that everything is working + assert B.x.gifs[b].get_reference() is a + + # check that the property is set + assert b.x is a