Skip to content
This repository has been archived by the owner on Dec 10, 2024. It is now read-only.

Commit

Permalink
fix trait override
Browse files Browse the repository at this point in the history
  • Loading branch information
iopapamanoglou committed Aug 28, 2024
1 parent a4ec66d commit 0661a45
Show file tree
Hide file tree
Showing 13 changed files with 197 additions and 202 deletions.
2 changes: 1 addition & 1 deletion src/faebryk/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# only for typechecker

if TYPE_CHECKING:
from faebryk.core.core import Link
from faebryk.core.link import Link

# TODO create GraphView base class

Expand Down
2 changes: 1 addition & 1 deletion src/faebryk/core/graph_backends/graphgt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# only for typechecker

if TYPE_CHECKING:
from faebryk.core.core import Link
from faebryk.core.link import Link


class GraphGT[T](Graph[T, gt.Graph]):
Expand Down
2 changes: 1 addition & 1 deletion src/faebryk/core/graph_backends/graphnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# only for typechecker

if TYPE_CHECKING:
from faebryk.core.core import Link
from faebryk.core.link import Link


class GraphNX[T](Graph[T, nx.Graph]):
Expand Down
2 changes: 1 addition & 1 deletion src/faebryk/core/graph_backends/graphpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# only for typechecker

if TYPE_CHECKING:
from faebryk.core.core import Link
from faebryk.core.link import Link

type L = "Link"

Expand Down
20 changes: 15 additions & 5 deletions src/faebryk/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ def __init__(self, node: "Node", *args: object) -> None:
self.node = node


class NodeAlreadyBound(NodeException):
def __init__(self, node: "Node", other: "Node", *args: object) -> None:
super().__init__(
node,
*args,
f"Node {other} already bound to"
f" {other.get_parent()}, can't bind to {node}",
)


class Node(FaebrykLibObject, metaclass=PostInitCaller):
runtime_anon: list["Node"]
runtime: dict[str, "Node"]
Expand Down Expand Up @@ -369,17 +379,17 @@ def _handle_add_gif(self, name: str, gif: GraphInterface):
gif.connect(self.self_gif, linkcls=LinkSibling)

def _handle_add_node(self, name: str, node: "Node"):
assert not (
other_p := node.get_parent()
), f"{node} already has parent: {other_p}"
if node.get_parent():
raise NodeAlreadyBound(self, node)

from faebryk.core.trait import TraitImpl

if isinstance(node, TraitImpl):
if self.has_trait(node._trait):
node.handle_duplicate(
if not node.handle_duplicate(
cast_assert(TraitImpl, self.get_trait(node._trait)), self
)
):
return

node.parent.connect(self.children, LinkNamedParent.curry(name))
node._handle_added_to_parent()
Expand Down
18 changes: 15 additions & 3 deletions src/faebryk/core/trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def __init__(self, node: Node, trait: "TraitImpl", *args: object) -> None:
self.trait = trait


class TraitUnbound(NodeException):
def __init__(self, node: Node, *args: object) -> None:
super().__init__(node, *args, f"Trait {node} is not bound to a node")


class Trait(Node):
@classmethod
def impl[T: "Trait"](cls: type[T]):
Expand Down Expand Up @@ -64,7 +69,7 @@ def __preinit__(self) -> None:
def obj(self) -> Node:
p = self.get_parent()
if not p:
raise Exception("trait is not linked to node")
raise TraitUnbound(self)
return p[0]

def get_obj[T: Node](self, type: type[T]) -> T:
Expand Down Expand Up @@ -95,9 +100,16 @@ def _handle_added_to_parent(self):

def on_obj_set(self): ...

def handle_duplicate(self, other: "TraitImpl", node: Node):
def handle_duplicate(self, other: "TraitImpl", node: Node) -> bool:
assert other is not self
raise TraitAlreadyExists(node, self)
_, candidate = other.cmp(self)
if candidate is not self:
return False

node.del_trait(other._trait)
return True

# raise TraitAlreadyExists(node, self)

# override this to implement a dynamic trait
def is_implemented(self):
Expand Down
18 changes: 12 additions & 6 deletions test/core/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
from typing import cast

from faebryk.core.link import LinkDirect, LinkParent, LinkSibling
from faebryk.core.node import Node
from faebryk.core.trait import Trait, TraitImpl
from faebryk.core.node import Node, NodeAlreadyBound
from faebryk.core.trait import (
Trait,
TraitAlreadyExists,

Check failure on line 12 in test/core/test_core.py

View workflow job for this annotation

GitHub Actions / test

Ruff (F401)

test/core/test_core.py:12:5: F401 `faebryk.core.trait.TraitAlreadyExists` imported but unused
TraitImpl,
TraitNotFound,
TraitUnbound,
)


class TestTraits(unittest.TestCase):
Expand Down Expand Up @@ -111,7 +117,7 @@ class impl2(trait2.impl()):

# Test failure on getting non existent
self.assertFalse(obj.has_trait(trait1))
self.assertRaises(AssertionError, lambda: obj.get_trait(trait1))
self.assertRaises(TraitNotFound, lambda: obj.get_trait(trait1))

trait1_inst = trait1impl()
cfgtrait1_inst = cfgtrait1(5)
Expand All @@ -124,7 +130,7 @@ class impl2(trait2.impl()):
self.assertEqual(trait1_inst.do(), obj.get_trait(trait1).do())

# Test double add
self.assertRaises(AssertionError, lambda: obj.add_trait(trait1_inst))
self.assertRaises(NodeAlreadyBound, lambda: obj.add_trait(trait1_inst))

# Test replace
obj.add_trait(cfgtrait1_inst)
Expand All @@ -140,12 +146,12 @@ class impl2(trait2.impl()):
self.assertFalse(obj.has_trait(trait1))

# Test get obj
self.assertRaises(AssertionError, lambda: trait1_inst.obj)
self.assertRaises(TraitUnbound, lambda: trait1_inst.obj)
obj.add_trait(trait1_inst)
_impl: TraitImpl = cast(TraitImpl, obj.get_trait(trait1))
self.assertEqual(_impl.obj, obj)
obj.del_trait(trait1)
self.assertRaises(AssertionError, lambda: trait1_inst.obj)
self.assertRaises(TraitUnbound, lambda: trait1_inst.obj)

# Test specific override
obj.add_trait(impl2_inst)
Expand Down
82 changes: 39 additions & 43 deletions test/core/test_hierarchy_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@
import unittest
from itertools import chain

from faebryk.core.core import (
LinkDirect,
LinkDirectShallow,
Module,
ModuleInterface,
_TLinkDirectShallow,
)
from faebryk.core.core import logger as core_logger
from faebryk.core.link import LinkDirect, LinkDirectShallow, _TLinkDirectShallow
from faebryk.core.module import Module
from faebryk.core.moduleinterface import ModuleInterface
from faebryk.core.util import specialize_interface
from faebryk.library.Electrical import Electrical
from faebryk.library.ElectricLogic import ElectricLogic
Expand All @@ -38,18 +34,18 @@ class _IFs(super().IFS()):

self.IFs = _IFs(self)

bus_in = self.IFs.bus_in
bus_out = self.IFs.bus_out
bus_in = self.bus_in
bus_out = self.bus_out

bus_in.IFs.rx.IFs.signal.connect(bus_out.IFs.rx.IFs.signal)
bus_in.IFs.tx.IFs.signal.connect(bus_out.IFs.tx.IFs.signal)
bus_in.IFs.rx.IFs.reference.connect(bus_out.IFs.rx.IFs.reference)
bus_in.rx.signal.connect(bus_out.rx.signal)
bus_in.tx.signal.connect(bus_out.tx.signal)
bus_in.rx.reference.connect(bus_out.rx.reference)

app = UARTBuffer()

self.assertTrue(app.IFs.bus_in.IFs.rx.is_connected_to(app.IFs.bus_out.IFs.rx))
self.assertTrue(app.IFs.bus_in.IFs.tx.is_connected_to(app.IFs.bus_out.IFs.tx))
self.assertTrue(app.IFs.bus_in.is_connected_to(app.IFs.bus_out))
self.assertTrue(app.bus_in.rx.is_connected_to(app.bus_out.rx))
self.assertTrue(app.bus_in.tx.is_connected_to(app.bus_out.tx))
self.assertTrue(app.bus_in.is_connected_to(app.bus_out))

def test_chains(self):
mifs = times(3, ModuleInterface)
Expand All @@ -76,16 +72,16 @@ def test_chains(self):
self.assertTrue(mifs[0].is_connected_to(mifs[2]))
self.assertIsInstance(mifs[0].is_connected_to(mifs[2]), _TLinkDirectShallow)

self.assertTrue(mifs[1].IFs.signal.is_connected_to(mifs[2].IFs.signal))
self.assertTrue(mifs[1].IFs.reference.is_connected_to(mifs[2].IFs.reference))
self.assertFalse(mifs[0].IFs.signal.is_connected_to(mifs[1].IFs.signal))
self.assertFalse(mifs[0].IFs.reference.is_connected_to(mifs[1].IFs.reference))
self.assertFalse(mifs[0].IFs.signal.is_connected_to(mifs[2].IFs.signal))
self.assertFalse(mifs[0].IFs.reference.is_connected_to(mifs[2].IFs.reference))
self.assertTrue(mifs[1].signal.is_connected_to(mifs[2].signal))
self.assertTrue(mifs[1].reference.is_connected_to(mifs[2].reference))
self.assertFalse(mifs[0].signal.is_connected_to(mifs[1].signal))
self.assertFalse(mifs[0].reference.is_connected_to(mifs[1].reference))
self.assertFalse(mifs[0].signal.is_connected_to(mifs[2].signal))
self.assertFalse(mifs[0].reference.is_connected_to(mifs[2].reference))

# Test duplicate resolution
mifs[0].IFs.signal.connect(mifs[1].IFs.signal)
mifs[0].IFs.reference.connect(mifs[1].IFs.reference)
mifs[0].signal.connect(mifs[1].signal)
mifs[0].reference.connect(mifs[1].reference)
self.assertIsInstance(mifs[0].is_connected_to(mifs[1]), LinkDirect)
self.assertIsInstance(mifs[0].is_connected_to(mifs[2]), LinkDirect)

Expand All @@ -107,12 +103,12 @@ class _IFs(super().IFS()):
self.add_trait(has_single_electric_reference_defined(ref))

for el, lo in chain(
zip(self.IFs.ins, self.IFs.ins_l),
zip(self.IFs.outs, self.IFs.outs_l),
zip(self.ins, self.ins_l),
zip(self.outs, self.outs_l),
):
lo.IFs.signal.connect(el)
lo.signal.connect(el)

for l1, l2 in zip(self.IFs.ins_l, self.IFs.outs_l):
for l1, l2 in zip(self.ins_l, self.outs_l):
l1.connect_shallow(l2)

class UARTBuffer(Module):
Expand All @@ -131,14 +127,14 @@ class _IFs(super().IFS()):

ElectricLogic.connect_all_module_references(self)

bus1 = self.IFs.bus_in
bus2 = self.IFs.bus_out
buf = self.NODEs.buf
bus1 = self.bus_in
bus2 = self.bus_out
buf = self.buf

bus1.IFs.tx.IFs.signal.connect(buf.IFs.ins[0])
bus1.IFs.rx.IFs.signal.connect(buf.IFs.ins[1])
bus2.IFs.tx.IFs.signal.connect(buf.IFs.outs[0])
bus2.IFs.rx.IFs.signal.connect(buf.IFs.outs[1])
bus1.tx.signal.connect(buf.ins[0])
bus1.rx.signal.connect(buf.ins[1])
bus2.tx.signal.connect(buf.outs[0])
bus2.rx.signal.connect(buf.outs[1])

import faebryk.core.core as c

Expand All @@ -153,19 +149,19 @@ def _assert_no_link(mif1, mif2):
err = "\n" + print_stack(link.tb)
self.assertFalse(link, err)

bus1 = app.IFs.bus_in
bus2 = app.IFs.bus_out
buf = app.NODEs.buf
bus1 = app.bus_in
bus2 = app.bus_out
buf = app.buf

# Check that the two buffer sides are not connected electrically
_assert_no_link(buf.IFs.ins[0], buf.IFs.outs[0])
_assert_no_link(buf.IFs.ins[1], buf.IFs.outs[1])
_assert_no_link(bus1.IFs.rx.IFs.signal, bus2.IFs.rx.IFs.signal)
_assert_no_link(bus1.IFs.tx.IFs.signal, bus2.IFs.tx.IFs.signal)
_assert_no_link(buf.ins[0], buf.outs[0])
_assert_no_link(buf.ins[1], buf.outs[1])
_assert_no_link(bus1.rx.signal, bus2.rx.signal)
_assert_no_link(bus1.tx.signal, bus2.tx.signal)

# Check that the two buffer sides are connected logically
self.assertTrue(bus1.IFs.rx.is_connected_to(bus2.IFs.rx))
self.assertTrue(bus1.IFs.tx.is_connected_to(bus2.IFs.tx))
self.assertTrue(bus1.rx.is_connected_to(bus2.rx))
self.assertTrue(bus1.tx.is_connected_to(bus2.tx))
self.assertTrue(bus1.is_connected_to(bus2))

def test_specialize(self):
Expand Down
49 changes: 25 additions & 24 deletions test/core/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import TypeVar

from faebryk.core.core import logger as core_logger
from faebryk.core.module import Module, Parameter
from faebryk.core.module import Module
from faebryk.core.parameter import Parameter
from faebryk.core.util import specialize_module
from faebryk.library.ANY import ANY
from faebryk.library.Constant import Constant
Expand Down Expand Up @@ -271,26 +272,26 @@ class NODES(super().NODES()):

m = Modules()

UART_A = m.NODEs.UART_A
UART_B = m.NODEs.UART_B
UART_C = m.NODEs.UART_C
UART_A = m.UART_A
UART_B = m.UART_B
UART_C = m.UART_C

UART_A.connect(UART_B)

UART_A.PARAMs.baud.merge(Constant(9600 * P.baud))
UART_A.baud.merge(Constant(9600 * P.baud))

for uart in [UART_A, UART_B]:
self.assertEqual(
assertIsInstance(uart.PARAMs.baud.get_most_narrow(), Constant).value,
assertIsInstance(uart.baud.get_most_narrow(), Constant).value,
9600 * P.baud,
)

UART_C.PARAMs.baud.merge(Range(1200 * P.baud, 115200 * P.baud))
UART_C.baud.merge(Range(1200 * P.baud, 115200 * P.baud))
UART_A.connect(UART_C)

for uart in [UART_A, UART_B, UART_C]:
self.assertEqual(
assertIsInstance(uart.PARAMs.baud.get_most_narrow(), Constant).value,
assertIsInstance(uart.baud.get_most_narrow(), Constant).value,
9600 * P.baud,
)

Expand Down Expand Up @@ -343,34 +344,34 @@ class _NODES(Module.NODES()):

self.NODEs = _NODES(self)

self.NODEs.led.IFs.power.connect(self.NODEs.battery.IFs.power)
self.led.power.connect(self.battery.power)

# Parametrize
self.NODEs.led.NODEs.led.PARAMs.color.merge(F.LED.Color.YELLOW)
self.NODEs.led.NODEs.led.PARAMs.brightness.merge(
self.led.led.color.merge(F.LED.Color.YELLOW)
self.led.led.brightness.merge(
TypicalLuminousIntensity.APPLICATION_LED_INDICATOR_INSIDE.value.value
)

app = App()

bcell = specialize_module(app.NODEs.battery, F.ButtonCell())
bcell.PARAMs.voltage.merge(3 * P.V)
bcell.PARAMs.capacity.merge(Range.from_center(225 * P.mAh, 50 * P.mAh))
bcell.PARAMs.material.merge(F.ButtonCell.Material.Lithium)
bcell.PARAMs.size.merge(F.ButtonCell.Size.N_2032)
bcell.PARAMs.shape.merge(F.ButtonCell.Shape.Round)
bcell = specialize_module(app.battery, F.ButtonCell())
bcell.voltage.merge(3 * P.V)
bcell.capacity.merge(Range.from_center(225 * P.mAh, 50 * P.mAh))
bcell.material.merge(F.ButtonCell.Material.Lithium)
bcell.size.merge(F.ButtonCell.Size.N_2032)
bcell.shape.merge(F.ButtonCell.Shape.Round)

app.NODEs.led.NODEs.led.PARAMs.color.merge(F.LED.Color.YELLOW)
app.NODEs.led.NODEs.led.PARAMs.max_brightness.merge(500 * P.millicandela)
app.NODEs.led.NODEs.led.PARAMs.forward_voltage.merge(1.2 * P.V)
app.NODEs.led.NODEs.led.PARAMs.max_current.merge(20 * P.mA)
app.led.led.color.merge(F.LED.Color.YELLOW)
app.led.led.max_brightness.merge(500 * P.millicandela)
app.led.led.forward_voltage.merge(1.2 * P.V)
app.led.led.max_current.merge(20 * P.mA)

v = app.NODEs.battery.PARAMs.voltage
# vbcell = bcell.PARAMs.voltage
v = app.battery.voltage
# vbcell = bcell.voltage
# print(pretty_param_tree_top(v))
# print(pretty_param_tree_top(vbcell))
self.assertEqual(v.get_most_narrow(), 3 * P.V)
r = app.NODEs.led.NODEs.current_limiting_resistor.PARAMs.resistance
r = app.led.current_limiting_resistor.resistance
r = r.get_most_narrow()
self.assertIsInstance(r, Range, f"{type(r)}")

Expand Down
Loading

0 comments on commit 0661a45

Please sign in to comment.