Skip to content
This repository has been archived by the owner on Dec 10, 2024. It is now read-only.

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
mawildoer committed Sep 19, 2024
1 parent 5a62bdd commit 22d702d
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 187 deletions.
46 changes: 27 additions & 19 deletions src/faebryk/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def __init_subclass__(cls, *, init: bool = True) -> None:
super().__init_subclass__()
cls._init = init

def _setup_fields(self, cls):
@classmethod
def __faebryk_fields__(cls) -> tuple[dict[str, Any], dict[str, Any]]:
def all_vars(cls):
return {k: v for c in reversed(cls.__mro__) for k, v in vars(c).items()}

Expand Down Expand Up @@ -273,9 +274,15 @@ def is_genalias_node(obj):
if not name.startswith("_")
}

clsfields = {
name: obj for name, obj in clsfields_unf.items() if is_node_field(obj)
}
nonfabfields, fabfields = partition(
lambda x: is_node_field(x[1]), clsfields_unf.items()
)

return dict(fabfields), dict(nonfabfields)

def _setup_fields(self, cls):
clsfields, _ = self.__faebryk_fields__()
LL_Types = (Node, GraphInterface)

# for name, obj in clsfields_unf.items():
# if isinstance(obj, _d_field):
Expand Down Expand Up @@ -422,9 +429,9 @@ def _handle_add_node(self, name: str, node: "Node"):
from faebryk.core.trait import TraitImpl

if isinstance(node, TraitImpl):
if self.has_trait(node._trait):
if self.has_trait(node.__trait__):
if not node.handle_duplicate(
cast_assert(TraitImpl, self.get_trait(node._trait)), self
cast_assert(TraitImpl, self.get_trait(node.__trait__)), self
):
return

Expand Down Expand Up @@ -491,38 +498,39 @@ def pretty_params(self) -> str:
def add_trait[_TImpl: "TraitImpl"](self, trait: _TImpl) -> _TImpl:
return self.add(trait)

def _find[V: "Trait"](self, trait: type[V], only_implemented: bool) -> V | None:
from faebryk.core.trait import TraitImpl
def _find_trait_impl[V: "Trait | TraitImpl"](self, trait: type[V], only_implemented: bool) -> V | None:

Check failure on line 501 in src/faebryk/core/node.py

View workflow job for this annotation

GitHub Actions / test

Ruff (E501)

src/faebryk/core/node.py:501:89: E501 Line too long (107 > 88)
from faebryk.core.trait import TraitImpl, TraitImplementationConfusedWithTrait

if issubclass(trait, TraitImpl) and not trait.__trait__.__decless_trait__:
raise TraitImplementationConfusedWithTrait(self, trait)
trait = trait.__trait__

out = self.get_children(
direct_only=True,
types=TraitImpl,
f_filter=lambda impl: impl.implements(trait)
and (impl.is_implemented() or not only_implemented),
)

assert len(out) <= 1
return cast_assert(trait, next(iter(out))) if out else None

def del_trait(self, trait: type["Trait"]):
impl = self._find(trait, only_implemented=False)
impl = self._find_trait_impl(trait, only_implemented=False)
if not impl:
return
self._remove_child(impl)

def try_get_trait[V: "Trait"](self, trait: Type[V]) -> V | None:
return self._find(trait, only_implemented=True)
def try_get_trait[V: "Trait | TraitImpl"](self, trait: Type[V]) -> V | None:
return self._find_trait_impl(trait, only_implemented=True)

def has_trait(self, trait: type["Trait"]) -> bool:
def has_trait(self, trait: type["Trait | TraitImpl"]) -> bool:
return self.try_get_trait(trait) is not None

def get_trait[V: "Trait"](self, trait: Type[V]) -> V:
from faebryk.core.trait import TraitImpl, TraitNotFound

assert not issubclass(
trait, TraitImpl
), "You need to specify the trait, not an impl"
def get_trait[V: "Trait | TraitImpl"](self, trait: Type[V]) -> V:
from faebryk.core.trait import TraitNotFound

impl = self._find(trait, only_implemented=True)
impl = self.try_get_trait(trait)
if not impl:
raise TraitNotFound(self, trait)

Expand Down
51 changes: 39 additions & 12 deletions src/faebryk/core/trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,19 @@ def __init__(self, node: Node, trait: type["Trait"], *args: object) -> None:
self.trait = trait


class TraitImplementationConfusedWithTrait(NodeException):
def __init__(self, node: Node, trait: type["Trait"], *args: object) -> None:
super().__init__(
node,
*args,
"Implementation or trait was used where the other was expected.",
)
self.trait = trait


class TraitAlreadyExists(NodeException):
def __init__(self, node: Node, trait: "TraitImpl", *args: object) -> None:
trait_type = trait._trait
trait_type = trait.__trait__
super().__init__(
node,
*args,
Expand All @@ -34,6 +44,9 @@ def __init__(self, node: Node, *args: object) -> None:


class Trait(Node):
__trait__: type["Trait"]
__decless_trait__ = False

@classmethod
def impl[T: "Trait"](cls: type[T]):
class _Impl(TraitImpl, cls): ...
Expand All @@ -46,17 +59,31 @@ def __new__(cls, *args, **kwargs):

return super().__new__(cls)

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.__trait__ = cls

@classmethod
def decless(cls):
class _Trait(cls):
__decless_trait__ = True

return _Trait.impl()


Trait.__trait__ = Trait


class TraitImpl(Node):
_trait: type[Trait]
__trait__: type[Trait]

def __preinit__(self) -> None:
found = False
bases = type(self).__bases__
while not found:
for base in bases:
if not issubclass(base, TraitImpl) and issubclass(base, Trait):
self._trait = base
self.__trait__ = base
found = True
break
bases = [
Expand All @@ -67,9 +94,9 @@ def __preinit__(self) -> None:
]
assert len(bases) > 0

assert isinstance(self._trait, type)
assert issubclass(self._trait, Trait)
assert self._trait is not TraitImpl
assert isinstance(self.__trait__, type)
assert issubclass(self.__trait__, Trait)
assert self.__trait__ is not TraitImpl

@property
def obj(self) -> Node:
Expand All @@ -85,19 +112,19 @@ def cmp(self, other: "TraitImpl") -> tuple[bool, "TraitImpl"]:
assert type(other), TraitImpl

# If other same or more specific
if other.implements(self._trait):
if other.implements(self.__trait__):
return True, other

# If we are more specific
if self.implements(other._trait):
if self.implements(other.__trait__):
return True, self

return False, self

def implements(self, trait: type):
assert issubclass(trait, Trait)
def implements(self, trait: type[Trait]):
assert issubclass(trait.__trait__, Trait)

return issubclass(self._trait, trait)
return issubclass(self.__trait__, trait)

# Overwriteable --------------------------------------------------------------------

Expand All @@ -112,7 +139,7 @@ def handle_duplicate(self, other: "TraitImpl", node: Node) -> bool:
if candidate is not self:
return False

node.del_trait(other._trait)
node.del_trait(other.__trait__)
return True

# raise TraitAlreadyExists(node, self)
Expand Down
157 changes: 1 addition & 156 deletions test/core/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,167 +2,12 @@
# SPDX-License-Identifier: MIT

import unittest
from abc import abstractmethod
from typing import cast

from faebryk.core.link import LinkDirect, LinkParent, LinkSibling
from faebryk.core.node import Node, NodeAlreadyBound
from faebryk.core.trait import (
Trait,
TraitImpl,
TraitNotFound,
TraitUnbound,
)
from faebryk.core.node import Node
from faebryk.libs.library import L


class TestTraits(unittest.TestCase):
def test_equality(self):
from faebryk.core.trait import Trait

class _trait1(Trait):
pass

class _trait1_1(_trait1):
pass

class _trait2(Trait):
pass

class impl1(_trait1.impl()):
pass

class impl1_1(impl1):
pass

class impl1_1_1(impl1_1):
pass

class impl1_2(impl1):
pass

class impl_1_1(_trait1_1.impl()):
pass

class impl2(_trait2.impl()):
pass

a = impl1()

# different inst
self.assertNotEqual(impl1(), impl1())
# same inst
self.assertEqual(a, a)
# same class
self.assertEqual(impl1, impl1)
# class & parent/child class
self.assertNotEqual(impl1_1, impl1)
# class & parallel class
self.assertNotEqual(impl2, impl1)

def assertCmpTrue(one, two):
self.assertTrue(one.cmp(two)[0])

def assertCmpFalse(one, two):
self.assertFalse(one.cmp(two)[0])

# inst & class
assertCmpTrue(impl1(), impl1())
assertCmpTrue(impl1_1(), impl1_1())
# inst & parent class
assertCmpTrue(impl1_1(), impl1())
# inst & child class
assertCmpTrue(impl1(), impl1_1())
# inst & parallel class
assertCmpFalse(impl2(), impl1())
assertCmpFalse(impl2(), impl1_1())
# inst & double child class
assertCmpTrue(impl1_1_1(), impl1())
# inst & double parent class
assertCmpTrue(impl1(), impl1_1_1())
# inst & sister class
assertCmpTrue(impl1_2(), impl1_1())
# inst & nephew class
assertCmpTrue(impl1_2(), impl1_1_1())

# Trait inheritance
assertCmpTrue(impl1(), impl_1_1())
assertCmpTrue(impl_1_1(), impl1())

def test_obj_traits(self):
obj = Node()

class trait1(Trait):
@abstractmethod
def do(self) -> int:
raise NotImplementedError

class trait1impl(trait1.impl()):
def do(self) -> int:
return 1

class cfgtrait1(trait1impl):
def __init__(self, cfg) -> None:
super().__init__()
self.cfg = cfg

def do(self) -> int:
return self.cfg

class trait2(trait1):
pass

class impl2(trait2.impl()):
pass

# Test failure on getting non existent
self.assertFalse(obj.has_trait(trait1))
self.assertRaises(TraitNotFound, lambda: obj.get_trait(trait1))

trait1_inst = trait1impl()
cfgtrait1_inst = cfgtrait1(5)
impl2_inst = impl2()

# Test getting trait
obj.add(trait1_inst)
self.assertTrue(obj.has_trait(trait1))
self.assertEqual(trait1_inst, obj.get_trait(trait1))
self.assertEqual(trait1_inst.do(), obj.get_trait(trait1).do())

# Test double add
self.assertRaises(NodeAlreadyBound, lambda: obj.add(trait1_inst))

# Test replace
obj.add(cfgtrait1_inst)
self.assertEqual(cfgtrait1_inst, obj.get_trait(trait1))
self.assertEqual(cfgtrait1_inst.do(), obj.get_trait(trait1).do())
obj.add(trait1_inst)
self.assertEqual(trait1_inst, obj.get_trait(trait1))

# Test remove
obj.del_trait(trait2)
self.assertTrue(obj.has_trait(trait1))
obj.del_trait(trait1)
self.assertFalse(obj.has_trait(trait1))

# Test get obj
self.assertRaises(TraitUnbound, lambda: trait1_inst.obj)
obj.add(trait1_inst)
_impl: TraitImpl = cast(TraitImpl, obj.get_trait(trait1))
self.assertEqual(_impl.obj, obj)
obj.del_trait(trait1)
self.assertRaises(TraitUnbound, lambda: trait1_inst.obj)

# Test specific override
obj.add(impl2_inst)
obj.add(trait1_inst)
self.assertEqual(impl2_inst, obj.get_trait(trait1))

# Test child delete
obj.del_trait(trait1)
self.assertFalse(obj.has_trait(trait1))


class TestGraph(unittest.TestCase):
def test_gifs(self):
from faebryk.core.graphinterface import GraphInterface as GIF
Expand Down
Loading

0 comments on commit 22d702d

Please sign in to comment.