diff --git a/src/faebryk/core/node.py b/src/faebryk/core/node.py index 03438051..b21ec37c 100644 --- a/src/faebryk/core/node.py +++ b/src/faebryk/core/node.py @@ -215,7 +215,8 @@ def __init_subclass__(cls, *, init: bool = True) -> None: super().__init_subclass__() cls._init = init - def _setup_fields(self, cls): + @classmethod + def __faebryk_fields__(cls) -> tuple[dict[str, Any], dict[str, Any]]: def all_vars(cls): return {k: v for c in reversed(cls.__mro__) for k, v in vars(c).items()} @@ -273,9 +274,15 @@ def is_genalias_node(obj): if not name.startswith("_") } - clsfields = { - name: obj for name, obj in clsfields_unf.items() if is_node_field(obj) - } + nonfabfields, fabfields = partition( + lambda x: is_node_field(x[1]), clsfields_unf.items() + ) + + return dict(fabfields), dict(nonfabfields) + + def _setup_fields(self, cls): + clsfields, _ = self.__faebryk_fields__() + LL_Types = (Node, GraphInterface) # for name, obj in clsfields_unf.items(): # if isinstance(obj, _d_field): @@ -422,9 +429,9 @@ def _handle_add_node(self, name: str, node: "Node"): from faebryk.core.trait import TraitImpl if isinstance(node, TraitImpl): - if self.has_trait(node._trait): + if self.has_trait(node.__trait__): if not node.handle_duplicate( - cast_assert(TraitImpl, self.get_trait(node._trait)), self + cast_assert(TraitImpl, self.get_trait(node.__trait__)), self ): return @@ -491,8 +498,12 @@ def pretty_params(self) -> str: def add_trait[_TImpl: "TraitImpl"](self, trait: _TImpl) -> _TImpl: return self.add(trait) - def _find[V: "Trait"](self, trait: type[V], only_implemented: bool) -> V | None: - from faebryk.core.trait import TraitImpl + def _find_trait_impl[V: "Trait | TraitImpl"](self, trait: type[V], only_implemented: bool) -> V | None: + from faebryk.core.trait import TraitImpl, TraitImplementationConfusedWithTrait + + if issubclass(trait, TraitImpl) and not trait.__trait__.__decless_trait__: + raise TraitImplementationConfusedWithTrait(self, trait) + trait = trait.__trait__ out = self.get_children( direct_only=True, @@ -500,29 +511,26 @@ def _find[V: "Trait"](self, trait: type[V], only_implemented: bool) -> V | None: f_filter=lambda impl: impl.implements(trait) and (impl.is_implemented() or not only_implemented), ) + assert len(out) <= 1 return cast_assert(trait, next(iter(out))) if out else None def del_trait(self, trait: type["Trait"]): - impl = self._find(trait, only_implemented=False) + impl = self._find_trait_impl(trait, only_implemented=False) if not impl: return self._remove_child(impl) - def try_get_trait[V: "Trait"](self, trait: Type[V]) -> V | None: - return self._find(trait, only_implemented=True) + def try_get_trait[V: "Trait | TraitImpl"](self, trait: Type[V]) -> V | None: + return self._find_trait_impl(trait, only_implemented=True) - def has_trait(self, trait: type["Trait"]) -> bool: + def has_trait(self, trait: type["Trait | TraitImpl"]) -> bool: return self.try_get_trait(trait) is not None - def get_trait[V: "Trait"](self, trait: Type[V]) -> V: - from faebryk.core.trait import TraitImpl, TraitNotFound - - assert not issubclass( - trait, TraitImpl - ), "You need to specify the trait, not an impl" + def get_trait[V: "Trait | TraitImpl"](self, trait: Type[V]) -> V: + from faebryk.core.trait import TraitNotFound - impl = self._find(trait, only_implemented=True) + impl = self.try_get_trait(trait) if not impl: raise TraitNotFound(self, trait) diff --git a/src/faebryk/core/trait.py b/src/faebryk/core/trait.py index d4fb40c9..069b5f2f 100644 --- a/src/faebryk/core/trait.py +++ b/src/faebryk/core/trait.py @@ -16,9 +16,19 @@ def __init__(self, node: Node, trait: type["Trait"], *args: object) -> None: self.trait = trait +class TraitImplementationConfusedWithTrait(NodeException): + def __init__(self, node: Node, trait: type["Trait"], *args: object) -> None: + super().__init__( + node, + *args, + "Implementation or trait was used where the other was expected.", + ) + self.trait = trait + + class TraitAlreadyExists(NodeException): def __init__(self, node: Node, trait: "TraitImpl", *args: object) -> None: - trait_type = trait._trait + trait_type = trait.__trait__ super().__init__( node, *args, @@ -34,6 +44,9 @@ def __init__(self, node: Node, *args: object) -> None: class Trait(Node): + __trait__: type["Trait"] + __decless_trait__ = False + @classmethod def impl[T: "Trait"](cls: type[T]): class _Impl(TraitImpl, cls): ... @@ -46,9 +59,23 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls) + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.__trait__ = cls + + @classmethod + def decless(cls): + class _Trait(cls): + __decless_trait__ = True + + return _Trait.impl() + + +Trait.__trait__ = Trait + class TraitImpl(Node): - _trait: type[Trait] + __trait__: type[Trait] def __preinit__(self) -> None: found = False @@ -56,7 +83,7 @@ def __preinit__(self) -> None: while not found: for base in bases: if not issubclass(base, TraitImpl) and issubclass(base, Trait): - self._trait = base + self.__trait__ = base found = True break bases = [ @@ -67,9 +94,9 @@ def __preinit__(self) -> None: ] assert len(bases) > 0 - assert isinstance(self._trait, type) - assert issubclass(self._trait, Trait) - assert self._trait is not TraitImpl + assert isinstance(self.__trait__, type) + assert issubclass(self.__trait__, Trait) + assert self.__trait__ is not TraitImpl @property def obj(self) -> Node: @@ -85,19 +112,19 @@ def cmp(self, other: "TraitImpl") -> tuple[bool, "TraitImpl"]: assert type(other), TraitImpl # If other same or more specific - if other.implements(self._trait): + if other.implements(self.__trait__): return True, other # If we are more specific - if self.implements(other._trait): + if self.implements(other.__trait__): return True, self return False, self - def implements(self, trait: type): - assert issubclass(trait, Trait) + def implements(self, trait: type[Trait]): + assert issubclass(trait.__trait__, Trait) - return issubclass(self._trait, trait) + return issubclass(self.__trait__, trait) # Overwriteable -------------------------------------------------------------------- @@ -112,7 +139,7 @@ def handle_duplicate(self, other: "TraitImpl", node: Node) -> bool: if candidate is not self: return False - node.del_trait(other._trait) + node.del_trait(other.__trait__) return True # raise TraitAlreadyExists(node, self) diff --git a/test/core/test_core.py b/test/core/test_core.py index 238a8a5f..0d3e784a 100644 --- a/test/core/test_core.py +++ b/test/core/test_core.py @@ -2,167 +2,12 @@ # SPDX-License-Identifier: MIT import unittest -from abc import abstractmethod -from typing import cast from faebryk.core.link import LinkDirect, LinkParent, LinkSibling -from faebryk.core.node import Node, NodeAlreadyBound -from faebryk.core.trait import ( - Trait, - TraitImpl, - TraitNotFound, - TraitUnbound, -) +from faebryk.core.node import Node from faebryk.libs.library import L -class TestTraits(unittest.TestCase): - def test_equality(self): - from faebryk.core.trait import Trait - - class _trait1(Trait): - pass - - class _trait1_1(_trait1): - pass - - class _trait2(Trait): - pass - - class impl1(_trait1.impl()): - pass - - class impl1_1(impl1): - pass - - class impl1_1_1(impl1_1): - pass - - class impl1_2(impl1): - pass - - class impl_1_1(_trait1_1.impl()): - pass - - class impl2(_trait2.impl()): - pass - - a = impl1() - - # different inst - self.assertNotEqual(impl1(), impl1()) - # same inst - self.assertEqual(a, a) - # same class - self.assertEqual(impl1, impl1) - # class & parent/child class - self.assertNotEqual(impl1_1, impl1) - # class & parallel class - self.assertNotEqual(impl2, impl1) - - def assertCmpTrue(one, two): - self.assertTrue(one.cmp(two)[0]) - - def assertCmpFalse(one, two): - self.assertFalse(one.cmp(two)[0]) - - # inst & class - assertCmpTrue(impl1(), impl1()) - assertCmpTrue(impl1_1(), impl1_1()) - # inst & parent class - assertCmpTrue(impl1_1(), impl1()) - # inst & child class - assertCmpTrue(impl1(), impl1_1()) - # inst & parallel class - assertCmpFalse(impl2(), impl1()) - assertCmpFalse(impl2(), impl1_1()) - # inst & double child class - assertCmpTrue(impl1_1_1(), impl1()) - # inst & double parent class - assertCmpTrue(impl1(), impl1_1_1()) - # inst & sister class - assertCmpTrue(impl1_2(), impl1_1()) - # inst & nephew class - assertCmpTrue(impl1_2(), impl1_1_1()) - - # Trait inheritance - assertCmpTrue(impl1(), impl_1_1()) - assertCmpTrue(impl_1_1(), impl1()) - - def test_obj_traits(self): - obj = Node() - - class trait1(Trait): - @abstractmethod - def do(self) -> int: - raise NotImplementedError - - class trait1impl(trait1.impl()): - def do(self) -> int: - return 1 - - class cfgtrait1(trait1impl): - def __init__(self, cfg) -> None: - super().__init__() - self.cfg = cfg - - def do(self) -> int: - return self.cfg - - class trait2(trait1): - pass - - class impl2(trait2.impl()): - pass - - # Test failure on getting non existent - self.assertFalse(obj.has_trait(trait1)) - self.assertRaises(TraitNotFound, lambda: obj.get_trait(trait1)) - - trait1_inst = trait1impl() - cfgtrait1_inst = cfgtrait1(5) - impl2_inst = impl2() - - # Test getting trait - obj.add(trait1_inst) - self.assertTrue(obj.has_trait(trait1)) - self.assertEqual(trait1_inst, obj.get_trait(trait1)) - self.assertEqual(trait1_inst.do(), obj.get_trait(trait1).do()) - - # Test double add - self.assertRaises(NodeAlreadyBound, lambda: obj.add(trait1_inst)) - - # Test replace - obj.add(cfgtrait1_inst) - self.assertEqual(cfgtrait1_inst, obj.get_trait(trait1)) - self.assertEqual(cfgtrait1_inst.do(), obj.get_trait(trait1).do()) - obj.add(trait1_inst) - self.assertEqual(trait1_inst, obj.get_trait(trait1)) - - # Test remove - obj.del_trait(trait2) - self.assertTrue(obj.has_trait(trait1)) - obj.del_trait(trait1) - self.assertFalse(obj.has_trait(trait1)) - - # Test get obj - self.assertRaises(TraitUnbound, lambda: trait1_inst.obj) - obj.add(trait1_inst) - _impl: TraitImpl = cast(TraitImpl, obj.get_trait(trait1)) - self.assertEqual(_impl.obj, obj) - obj.del_trait(trait1) - self.assertRaises(TraitUnbound, lambda: trait1_inst.obj) - - # Test specific override - obj.add(impl2_inst) - obj.add(trait1_inst) - self.assertEqual(impl2_inst, obj.get_trait(trait1)) - - # Test child delete - obj.del_trait(trait1) - self.assertFalse(obj.has_trait(trait1)) - - class TestGraph(unittest.TestCase): def test_gifs(self): from faebryk.core.graphinterface import GraphInterface as GIF diff --git a/test/core/test_traits.py b/test/core/test_traits.py new file mode 100644 index 00000000..bd4e8036 --- /dev/null +++ b/test/core/test_traits.py @@ -0,0 +1,210 @@ +from abc import abstractmethod +from typing import cast + +import pytest + +from faebryk.core.node import Node, NodeAlreadyBound +from faebryk.core.trait import ( + Trait, + TraitImpl, + TraitNotFound, + TraitUnbound, + TraitImplementationConfusedWithTrait +) + + +def test_trait_equality(): + from faebryk.core.trait import Trait + + class _trait1(Trait): + pass + + class _trait1_1(_trait1): + pass + + class _trait2(Trait): + pass + + class impl1(_trait1.impl()): + pass + + class impl1_1(impl1): + pass + + class impl1_1_1(impl1_1): + pass + + class impl1_2(impl1): + pass + + class impl_1_1(_trait1_1.impl()): + pass + + class impl2(_trait2.impl()): + pass + + a = impl1() + + # Test instance and class equality + assert impl1() != impl1() # different instances + assert a == a # same instance + assert impl1 == impl1 # same class + assert impl1_1 != impl1 # class & parent/child class + assert impl2 != impl1 # class & parallel class + + # Test trait implementation comparisons + assert impl1().cmp(impl1())[0] # inst & class + assert impl1_1().cmp(impl1_1())[0] # inst & class + assert impl1_1().cmp(impl1())[0] # inst & parent class + assert impl1().cmp(impl1_1())[0] # inst & child class + assert not impl2().cmp(impl1())[0] # inst & parallel class + assert not impl2().cmp(impl1_1())[0] # inst & parallel class + assert impl1_1_1().cmp(impl1())[0] # inst & double child class + assert impl1().cmp(impl1_1_1())[0] # inst & double parent class + assert impl1_2().cmp(impl1_1())[0] # inst & sister class + assert impl1_2().cmp(impl1_1_1())[0] # inst & nephew class + + # Test trait inheritance + assert impl1().cmp(impl_1_1())[0] + assert impl_1_1().cmp(impl1())[0] + + +def test_trait_basic_operations(): + obj = Node() + + class trait1(Trait): + @abstractmethod + def do(self) -> int: + raise NotImplementedError + + class trait1impl(trait1.impl()): + def do(self) -> int: + return 1 + + # Test failure on getting non-existent trait + assert not obj.has_trait(trait1) + with pytest.raises(TraitNotFound): + obj.get_trait(trait1) + + # Test adding and getting trait + trait1_inst = trait1impl() + obj.add(trait1_inst) + assert obj.has_trait(trait1) + assert trait1_inst == obj.get_trait(trait1) + assert trait1_inst.do() == obj.get_trait(trait1).do() + + # Test double add + with pytest.raises(NodeAlreadyBound): + obj.add(trait1_inst) + + # Test trait removal + obj.del_trait(trait1) + assert not obj.has_trait(trait1) + + +def test_trait_replacement(): + obj = Node() + + class trait1(Trait): + @abstractmethod + def do(self) -> int: + raise NotImplementedError + + class trait1impl(trait1.impl()): + def do(self) -> int: + return 1 + + class cfgtrait1(trait1impl): + def __init__(self, cfg) -> None: + super().__init__() + self.cfg = cfg + + def do(self) -> int: + return self.cfg + + trait1_inst = trait1impl() + cfgtrait1_inst = cfgtrait1(5) + + # Test trait replacement + obj.add(trait1_inst) + obj.add(cfgtrait1_inst) + assert cfgtrait1_inst == obj.get_trait(trait1) + assert cfgtrait1_inst.do() == obj.get_trait(trait1).do() + obj.add(trait1_inst) + assert trait1_inst == obj.get_trait(trait1) + + +def test_trait_object_binding(): + obj = Node() + + class trait1(Trait): + pass + + class trait1impl(trait1.impl()): + pass + + trait1_inst = trait1impl() + + # Test trait object binding + with pytest.raises(TraitUnbound): + trait1_inst.obj + obj.add(trait1_inst) + _impl: TraitImpl = cast(TraitImpl, obj.get_trait(trait1)) + assert _impl.obj == obj + obj.del_trait(trait1) + with pytest.raises(TraitUnbound): + trait1_inst.obj + + +def test_trait_inheritance_and_override(): + obj = Node() + + class trait1(Trait): + pass + + class trait2(trait1): + pass + + class impl1(trait1.impl()): + pass + + class impl2(trait2.impl()): + pass + + impl1_inst = impl1() + impl2_inst = impl2() + + # Test specific override + obj.add(impl2_inst) + obj.add(impl1_inst) + assert impl2_inst == obj.get_trait(trait1) + + # Test child delete + obj.del_trait(trait1) + assert not obj.has_trait(trait1) + + +def test_trait_impl_confusion(): + obj = Node() + + class trait1(Trait): + pass + + class trait1impl(trait1.impl()): + pass + + t1 = obj.add(trait1impl()) + with pytest.raises(TraitImplementationConfusedWithTrait): + obj.get_trait(trait1impl) + + assert obj.get_trait(trait1) == t1 + + +def test_trait_impl_exception(): + obj = Node() + + class trait1impl(Trait.decless()): + pass + + t1 = obj.add(trait1impl()) + assert obj.get_trait(trait1impl) is t1