diff --git a/src/faebryk/core/moduleinterface.py b/src/faebryk/core/moduleinterface.py index 8c70ced3..ae3f0c7c 100644 --- a/src/faebryk/core/moduleinterface.py +++ b/src/faebryk/core/moduleinterface.py @@ -22,7 +22,7 @@ ) from faebryk.core.node import Node from faebryk.core.trait import Trait -from faebryk.libs.util import print_stack +from faebryk.libs.util import once, print_stack logger = logging.getLogger(__name__) @@ -118,6 +118,7 @@ class TraitT(Trait): ... # TODO rename @classmethod + @once def LinkDirectShallow(cls): """ Make link that only connects up but not down @@ -130,13 +131,11 @@ class _LinkDirectShallowMif( LinkDirectShallow(lambda link, gif: test(gif.node)) ): ... - return _LinkDirectShallowMif + print("Make shallow for", cls) - _LinkDirectShallow: type[_TLinkDirectShallow] | None = None + return _LinkDirectShallowMif - def __preinit__(self) -> None: - if not type(self)._LinkDirectShallow: - type(self)._LinkDirectShallow = type(self).LinkDirectShallow() + def __preinit__(self) -> None: ... def _connect_siblings_and_connections( self, other: "ModuleInterface", linkcls: type[Link] @@ -335,7 +334,7 @@ def connect_via( intf.connect(other, linkcls=linkcls) def connect_shallow(self, other: Self) -> Self: - return self.connect(other, linkcls=type(self)._LinkDirectShallow) + return self.connect(other, linkcls=type(self).LinkDirectShallow()) def is_connected_to(self, other: "ModuleInterface"): return self.connected.is_connected(other.connected) diff --git a/src/faebryk/libs/util.py b/src/faebryk/libs/util.py index 8f68a717..10266b89 100644 --- a/src/faebryk/libs/util.py +++ b/src/faebryk/libs/util.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT import asyncio +import functools import inspect import logging from abc import abstractmethod @@ -797,3 +798,20 @@ def __() -> T: return __ return _ + + +def once[T, **P](f: Callable[P, T]) -> Callable[P, T]: + class _once: + def __init__(self) -> None: + self.cache = {} + + def __call__(self, *args: P.args, **kwds: P.kwargs) -> Any: + lookup = (args, tuple(kwds.items())) + if lookup in self.cache: + return self.cache[lookup] + + result = f(*args, **kwds) + self.cache[lookup] = result + return result + + return _once() diff --git a/test/core/test_hierarchy_connect.py b/test/core/test_hierarchy_connect.py index 559d3596..ed9e78b9 100644 --- a/test/core/test_hierarchy_connect.py +++ b/test/core/test_hierarchy_connect.py @@ -79,6 +79,11 @@ def test_chains(self): def test_bridge(self): self_ = self + # U1 ---> _________B________ ---> U2 + # TX IL ===> OL TX + # S --> I -> S S -> O --> S + # R -------- R ----- R -------- R + class Buffer(Module): ins = L.if_list(2, F.Electrical) outs = L.if_list(2, F.Electrical) @@ -171,12 +176,18 @@ def _assert_link(mif1: ModuleInterface, mif2: ModuleInterface, link=None): ) _assert_link(buf.ins_l[0].reference, buf.outs_l[0].reference) _assert_link(buf.outs_l[1].reference, buf.ins_l[0].reference) - _assert_link(bus1.rx.reference, bus2.rx.reference, LinkDirect) + # connect through up + _assert_link(bus1.tx, buf.ins_l[0], LinkDirect) + _assert_link(bus2.tx, buf.outs_l[0], LinkDirect) + + # connect shallow + _assert_link(buf.ins_l[0], buf.outs_l[0], _TLinkDirectShallow) + # Check that the two buffer sides are connected logically - _assert_link(bus1.rx, bus2.rx) _assert_link(bus1.tx, bus2.tx) + _assert_link(bus1.rx, bus2.rx) _assert_link(bus1, bus2) def test_specialize(self): @@ -222,7 +233,4 @@ class _Link(LinkDirectShallow(lambda link, gif: True)): ... if __name__ == "__main__": - # unittest.main() - import typer - - typer.run(TestHierarchy().test_bridge) + unittest.main() diff --git a/test/libs/util.py b/test/libs/util.py index 97c3fb37..445c4bf6 100644 --- a/test/libs/util.py +++ b/test/libs/util.py @@ -5,7 +5,7 @@ from itertools import combinations from faebryk.libs.logging import setup_basic_logging -from faebryk.libs.util import SharedReference, zip_non_locked +from faebryk.libs.util import SharedReference, once, zip_non_locked class TestUtil(unittest.TestCase): @@ -54,6 +54,44 @@ def all_equal(*args: SharedReference): all_equal(r1, r2, r3, r4, r5) + def test_once(self): + global ran + ran = False + + @once + def do(val: int): + global ran + ran = True + return val + + self.assertFalse(ran) + + self.assertEqual(do(5), 5) + self.assertTrue(ran) + ran = False + + self.assertEqual(do(5), 5) + self.assertFalse(ran) + + self.assertEqual(do(6), 6) + self.assertTrue(ran) + ran = False + + class A: + @classmethod + @once + def do(cls): + global ran + ran = True + return cls + + self.assertEqual(A.do(), A) + self.assertTrue(ran) + ran = False + + self.assertEqual(A.do(), A) + self.assertFalse(ran) + if __name__ == "__main__": setup_basic_logging()