diff --git a/src/faebryk/core/graphinterface.py b/src/faebryk/core/graphinterface.py index 9e836af8..9a9ee7c9 100644 --- a/src/faebryk/core/graphinterface.py +++ b/src/faebryk/core/graphinterface.py @@ -7,10 +7,17 @@ from faebryk.core.core import ID_REPR, FaebrykLibObject from faebryk.core.graph_backends.default import GraphImpl -from faebryk.core.link import Link, LinkDirect, LinkNamedParent +from faebryk.core.link import ( + Link, + LinkDirect, + LinkNamedParent, + LinkPointer, + LinkSibling, +) from faebryk.libs.util import ( NotNone, exceptions_to_log, + find, try_avoid_endless_recursion, ) @@ -191,3 +198,16 @@ def disconnect_parent(self): class GraphInterfaceSelf(GraphInterface): ... + + +class GraphInterfaceReference[T: "Node"](GraphInterface): + """Represents a reference to a node object""" + + def get_referenced_gif(self) -> GraphInterfaceSelf: + return find( + self.get_links_by_type(LinkPointer), + lambda link: not isinstance(link, LinkSibling), + ).self_gif + + def get_reference(self) -> T: + return self.get_referenced_gif().node diff --git a/src/faebryk/core/link.py b/src/faebryk/core/link.py index deee44a1..3b5e58fd 100644 --- a/src/faebryk/core/link.py +++ b/src/faebryk/core/link.py @@ -9,7 +9,11 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: - from faebryk.core.graphinterface import GraphInterface, GraphInterfaceHierarchical + from faebryk.core.graphinterface import ( + GraphInterface, + GraphInterfaceHierarchical, + GraphInterfaceSelf, + ) class Link(FaebrykLibObject): @@ -38,13 +42,36 @@ def __repr__(self) -> str: return f"{type(self).__name__}()" -class LinkSibling(Link): - def __init__(self, interfaces: list["GraphInterface"]) -> None: +class LinkPointer(Link): + """A Link that points towards a self-gif""" + + def __init__( + self, + interfaces: list["GraphInterfaceSelf | GraphInterface"], + ) -> None: + from faebryk.core.graphinterface import GraphInterfaceSelf + super().__init__() - self.interfaces = interfaces + assert len(interfaces) == 2 + if isinstance(interfaces[0], GraphInterfaceSelf) and not isinstance( + interfaces[1], GraphInterfaceSelf + ): + self.self_gif = interfaces[0] + self.other_gif = interfaces[1] + elif isinstance(interfaces[1], GraphInterfaceSelf) and not isinstance( + interfaces[0], GraphInterfaceSelf + ): + self.self_gif = interfaces[1] + self.other_gif = interfaces[0] + else: + raise TypeError("Interfaces must be one self-gif and one other-gif") def get_connections(self) -> list["GraphInterface"]: - return self.interfaces + return [self.self_gif, self.other_gif] + + +class LinkSibling(LinkPointer): + """A link represents a connection between a self-gif and a gif in the same node""" class LinkParent(Link): @@ -83,9 +110,13 @@ def curried(interfaces: list["GraphInterface"]): class LinkDirect(Link): + """Represents a symmetrical link between two interfaces of the same type""" + def __init__(self, interfaces: list["GraphInterface"]) -> None: super().__init__() - assert len(set(map(type, interfaces))) == 1 + assert ( + len(set(map(type, interfaces))) == 1 + ), "Interfaces must be of the same type" self.interfaces = interfaces def get_connections(self) -> list["GraphInterface"]: diff --git a/src/faebryk/core/node.py b/src/faebryk/core/node.py index 03438051..5bbf8e05 100644 --- a/src/faebryk/core/node.py +++ b/src/faebryk/core/node.py @@ -1,6 +1,8 @@ # This file is part of the faebryk project # SPDX-License-Identifier: MIT import logging +from abc import abstractmethod +from dataclasses import InitVar as dataclass_InitVar from itertools import chain from typing import ( TYPE_CHECKING, @@ -65,13 +67,21 @@ class fab_field: pass -class rt_field[T, O](property, fab_field): +class _rt_field[T, O](property, fab_field): + """TODO: what does an rt_field represent? what does "rt" stand for?""" + + @abstractmethod + def __construct__(self, obj: T) -> O: + pass + + +class rt_field[T, O](_rt_field): def __init__(self, fget: Callable[[T], O]) -> None: super().__init__() self.func = fget self.lookup: dict[T, O] = {} - def _construct(self, obj: T): + def __construct__(self, obj: T): constructed = self.func(obj) # TODO find a better way for this # in python 3.13 name support @@ -145,6 +155,14 @@ def __init__(self, node: "Node", other: "Node", *args: object) -> None: class NodeNoParent(NodeException): ... +class InitVar(dataclass_InitVar): + """ + This is a type-marker which instructs the Node constructor to ignore the field. + + Inspired by dataclasses.InitVar, which it inherits from. + """ + + # ----------------------------------------------------------------------------- @@ -167,7 +185,7 @@ def __hash__(self) -> int: # TODO proper hash return hash(id(self)) - def add[T: Node]( + def add[T: Node | GraphInterface]( self, obj: T, name: str | None = None, @@ -176,9 +194,10 @@ def add[T: Node]( assert obj is not None if container is None: - container = self.runtime_anon if name: container = self.runtime + else: + container = self.runtime_anon try: container_name = find(vars(self).items(), lambda x: x[1] is container)[0] @@ -197,7 +216,11 @@ def add[T: Node]( container.append(obj) name = f"{container_name}[{len(container) - 1}]" - self._handle_add_node(name, obj) + if isinstance(obj, GraphInterface): + self._handle_add_gif(name, obj) + else: + self._handle_add_node(name, obj) + return obj def add_to_container[T: Node]( @@ -259,7 +282,7 @@ def is_genalias_node(obj): if get_origin(obj): return is_genalias_node(obj) - if isinstance(obj, rt_field): + if isinstance(obj, _rt_field): return True return False @@ -267,8 +290,12 @@ def is_genalias_node(obj): clsfields_unf = { name: obj for name, obj in chain( - [(name, f) for name, f in annos.items()], - [(name, f) for name, f in vars_.items() if isinstance(f, fab_field)], + ( + (name, obj) + for name, obj in annos.items() + if not isinstance(obj, InitVar) + ), + ((name, f) for name, f in vars_.items() if isinstance(f, fab_field)), ) if not name.startswith("_") } @@ -277,6 +304,12 @@ def is_genalias_node(obj): name: obj for name, obj in clsfields_unf.items() if is_node_field(obj) } + return clsfields + + def _setup_fields(self, cls): + clsfields = self.__faebryk_fields__() + LL_Types = (Node, GraphInterface) + # for name, obj in clsfields_unf.items(): # if isinstance(obj, _d_field): # obj = obj.type @@ -337,23 +370,23 @@ def setup_gen_alias(name, obj): setattr(self, name, append(name, obj())) return - if isinstance(obj, rt_field): - append(name, obj._construct(self)) + if isinstance(obj, _rt_field): + append(name, obj.__construct__(self)) return raise NotImplementedError() def setup_field(name, obj): - try: - _setup_field(name, obj) - except Exception as e: - raise FieldConstructionError( - self, - name, - f'An exception occurred while constructing field "{name}"', - ) from e - - nonrt, rt = partition(lambda x: isinstance(x[1], rt_field), clsfields.items()) + # try: + _setup_field(name, obj) + # except Exception as e: + # raise FieldConstructionError( + # self, + # name, + # f'An exception occurred while constructing field "{name}"', + # ) from e + + nonrt, rt = partition(lambda x: isinstance(x[1], _rt_field), clsfields.items()) for name, obj in nonrt: setup_field(name, obj) @@ -709,3 +742,48 @@ def zip_children_by_name_with[N: Node]( @staticmethod def with_names[N: Node](nodes: Iterable[N]) -> dict[str, N]: return {n.get_name(): n for n in nodes} + + +def magic_pointer[O: Node](out_type: type[O] | None = None) -> O: + """ + magic_pointer let's you create simple references to other nodes + that are properly encoded in the graph. + + This final wrapper is primarily to fudge the typing. + """ + from faebryk.core.graphinterface import GraphInterfaceReference + from faebryk.core.link import LinkPointer + + class magic_pointer(property): + """ + magic_pointer let's you create simple references to other nodes + that are properly encoded in the graph. + """ + + def __init__(self): + self.gifs_by_start: dict[Node, GraphInterfaceReference] = {} + + def get(instance: Node) -> O: + try: + my_gif = self.gifs_by_start[instance] + except KeyError: + raise AttributeError("magic_pointer not set to anything") + return my_gif.get_reference() + + def set(instance: Node, value: O): + if instance in self.gifs_by_start: + raise TypeError( + f"{self.__class__.__name__} already set and are immutable" + ) + + if out_type is not None and not isinstance(value, out_type): + raise TypeError(f"Expected {out_type} got {type(value)}") + + self.gifs_by_start[instance] = instance.add(GraphInterfaceReference()) + + gif = self.gifs_by_start[instance] + gif.connect(value.self_gif, LinkPointer) + + property.__init__(self, get, set) + + return magic_pointer() diff --git a/test/core/node/test_magic_pointer.py b/test/core/node/test_magic_pointer.py new file mode 100644 index 00000000..52c5b815 --- /dev/null +++ b/test/core/node/test_magic_pointer.py @@ -0,0 +1,56 @@ +import pytest + +from faebryk.core.node import InitVar, Node, magic_pointer + + +def test_points_to_correct_node(): + class A(Node): + pass + + class B(Node): + x: InitVar[A] = magic_pointer(A) + + a = A() + b = B() + b.x = a + assert b.x is a + + +def test_immutable(): + class A(Node): + pass + + class B(Node): + x: InitVar[A] = magic_pointer(A) + + b = B() + a = A() + b.x = a + assert b.x is a + + with pytest.raises(TypeError): + b.x = A() + + +def test_unset(): + class A(Node): + pass + + class B(Node): + x: InitVar[A] = magic_pointer(A) + + b = B() + with pytest.raises(AttributeError): + b.x + + +def test_wrong_type(): + class A(Node): + pass + + class B(Node): + x: InitVar[A] = magic_pointer(A) + + b = B() + with pytest.raises(TypeError): + b.x = 1