From 57516b659d61ca15055ca68f804439fd1909b7d1 Mon Sep 17 00:00:00 2001 From: iopapamanoglou Date: Fri, 23 Aug 2024 04:01:09 +0200 Subject: [PATCH] Implement field setup --- new_holders_flat.py | 64 +++-- src/faebryk/core/link.py | 3 +- src/faebryk/core/moduleinterface.py | 3 + src/faebryk/core/node.py | 232 +++++++++++++++--- src/faebryk/core/trait.py | 4 +- .../exporters/pcb/kicad/transformer.py | 17 +- src/faebryk/library/Diode.py | 1 + src/faebryk/library/Footprint.py | 7 +- src/faebryk/library/FootprintTrait.py | 17 -- src/faebryk/library/has_linked_pad.py | 4 +- src/faebryk/library/is_esphome_bus.py | 2 +- src/faebryk/libs/app/checks.py | 2 +- src/faebryk/libs/app/pcb.py | 2 +- src/faebryk/libs/picker/picker.py | 4 +- test/core/test_parameters.py | 2 +- 15 files changed, 264 insertions(+), 100 deletions(-) delete mode 100644 src/faebryk/library/FootprintTrait.py diff --git a/new_holders_flat.py b/new_holders_flat.py index b87bf3c8..561eb2b3 100644 --- a/new_holders_flat.py +++ b/new_holders_flat.py @@ -1,15 +1,13 @@ from dataclasses import field +import typer + from faebryk.core.module import Module from faebryk.core.node import d_field, if_list, rt_field -from faebryk.core.util import as_unit from faebryk.library.can_bridge_defined import can_bridge_defined from faebryk.library.Electrical import Electrical from faebryk.library.has_designator_prefix import has_designator_prefix from faebryk.library.has_designator_prefix_defined import has_designator_prefix_defined -from faebryk.library.has_simple_value_representation_based_on_param import ( - has_simple_value_representation_based_on_param, -) from faebryk.library.TBD import TBD from faebryk.libs.units import Quantity from faebryk.libs.util import times @@ -27,6 +25,8 @@ class Diode2(Module): anode: Electrical cathode: Electrical + XXXXX: Electrical = d_field(Electrical) + # static trait designator_prefix: has_designator_prefix = d_field( lambda: has_designator_prefix_defined("D") @@ -37,43 +37,55 @@ class Diode2(Module): def bridge(self): return can_bridge_defined(self.anode, self.cathode) - def __finit__(self): - print("Called Diode __finit__") + def __preinit__(self): + print("Called Diode __preinit__") # anonymous dynamic trait - self.add( - has_simple_value_representation_based_on_param( - self.forward_voltage, - lambda p: as_unit(p, "V"), - ) # type: ignore - ) + # self.add( + # has_simple_value_representation_based_on_param( + # self.forward_voltage, + # lambda p: as_unit(p, "V"), + # ) # type: ignore + # ) class LED2(Diode2): color: TBD[float] - def __finit__(self): - print("Called LED __finit__") + def __preinit__(self): + print("Called LED __preinit__") class LED2_NOINT(LED2, init=False): - def __finit__(self): - print("Called LED_NOINT __finit__") + def __preinit__(self): + print("Called LED_NOINT __preinit__") class LED2_WITHEXTRAT_IFS(LED2): extra: list[Electrical] = field(default_factory=lambda: times(2, Electrical)) extra2: list[Electrical] = if_list(Electrical, 2) - def __finit__(self): - print("Called LED_WITHEXTRAT_IFS __finit__") + @rt_field + def bridge(self): + return can_bridge_defined(self.extra2[0], self.extra2[1]) + + def __preinit__(self): + print("Called LED_WITHEXTRAT_IFS __preinit__") + + +def main(): + print("Diode init ----") + _D = Diode2() + print("LED init ----") + _L = LED2() + print("LEDNOINIT init ----") + L2 = LED2_NOINT() + print("LEDEXTRA init ----") + L3 = LED2_WITHEXTRAT_IFS() + + L3.cathode.connect(L2.cathode) + + assert L3.cathode.is_connected_to(L2.cathode) -print("Diode init ----") -D = Diode2() -print("LED init ----") -L = LED2() -print("LEDNOINIT init ----") -L2 = LED2_NOINT() -print("LEDEXTRA init ----") -L3 = LED2_WITHEXTRAT_IFS() +typer.run(main) diff --git a/src/faebryk/core/link.py b/src/faebryk/core/link.py index 94108d70..d8920dee 100644 --- a/src/faebryk/core/link.py +++ b/src/faebryk/core/link.py @@ -50,8 +50,9 @@ def get_connections(self) -> list["GraphInterface"]: class LinkParent(Link): def __init__(self, interfaces: list["GraphInterface"]) -> None: super().__init__() + from faebryk.core.graphinterface import GraphInterfaceHierarchical - assert all([isinstance(i, "GraphInterfaceHierarchical") for i in interfaces]) + assert all([isinstance(i, GraphInterfaceHierarchical) for i in interfaces]) # TODO rethink invariant assert len(interfaces) == 2 assert len([i for i in interfaces if i.is_parent]) == 1 # type: ignore diff --git a/src/faebryk/core/moduleinterface.py b/src/faebryk/core/moduleinterface.py index bbdd24a6..765ecf84 100644 --- a/src/faebryk/core/moduleinterface.py +++ b/src/faebryk/core/moduleinterface.py @@ -21,6 +21,7 @@ _TLinkDirectShallow, ) from faebryk.core.node import Node +from faebryk.core.trait import Trait from faebryk.libs.util import print_stack logger = logging.getLogger(__name__) @@ -90,6 +91,8 @@ class GraphInterfaceModuleConnection(GraphInterface): ... class ModuleInterface(Node): + class TraitT(Trait["ModuleInterface"]): ... + specializes: GraphInterface specialized: GraphInterface connected: GraphInterfaceModuleConnection diff --git a/src/faebryk/core/node.py b/src/faebryk/core/node.py index d09d1fa0..7252d4c0 100644 --- a/src/faebryk/core/node.py +++ b/src/faebryk/core/node.py @@ -1,11 +1,9 @@ # This file is part of the faebryk project # SPDX-License-Identifier: MIT import logging -from dataclasses import Field, field, fields, is_dataclass from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Type +from typing import TYPE_CHECKING, Any, Callable, Type, get_args, get_origin -from attr import dataclass from deprecated import deprecated from faebryk.core.core import ID_REPR, FaebrykLibObject @@ -15,7 +13,7 @@ GraphInterfaceSelf, ) from faebryk.core.link import LinkNamedParent, LinkSibling -from faebryk.libs.util import KeyErrorNotFound, find, try_avoid_endless_recursion +from faebryk.libs.util import KeyErrorNotFound, find, times, try_avoid_endless_recursion if TYPE_CHECKING: from faebryk.core.trait import Trait, TraitImpl @@ -36,34 +34,46 @@ class FieldContainerError(FieldError): def if_list[T](if_type: type[T], n: int) -> list[T]: - return field(default_factory=lambda: [if_type() for _ in range(n)]) + return d_field(lambda: times(n, if_type)) -class rt_field[T](property): +class f_field: + pass + + +class rt_field[T](property, f_field): def __init__(self, fget: Callable[[T], Any]) -> None: super().__init__() self.func = fget - def _construct(self, obj: T, holder: type): - self.constructed = self.func(holder, obj) + def _construct(self, obj: T): + self.constructed = self.func(obj) + return self.constructed def __get__(self, instance: Any, owner: type | None = None) -> Any: - return self.constructed() + return self.constructed + + +class _d_field[T](f_field): + def __init__(self, default_factory: Callable[[], T]) -> None: + self.type = None + self.default_factory = default_factory -def d_field(default_factory: Callable[[], Any], **kwargs): - return field(default_factory=default_factory, **kwargs) +def d_field[T](default_factory: Callable[[], T]) -> T: + return _d_field(default_factory) # type: ignore # ----------------------------------------------------------------------------- +# @dataclass(init=False, kw_only=True) class Node(FaebrykLibObject): runtime_anon: list["Node"] runtime: dict[str, "Node"] specialized: list["Node"] - self_gif: GraphInterface + self_gif: GraphInterfaceSelf children: GraphInterfaceHierarchical = d_field( lambda: GraphInterfaceHierarchical(is_parent=True) ) @@ -74,7 +84,8 @@ class Node(FaebrykLibObject): _init: bool = False def __hash__(self) -> int: - raise NotImplementedError() + # TODO proper hash + return hash(id(self)) def add( self, @@ -107,42 +118,189 @@ def add( self._handle_add_node(name, obj) def __init_subclass__(cls, *, init: bool = True) -> None: - print("Called Node __subclass__", "-" * 20, cls.__qualname__) super().__init_subclass__() - cls._init = init - for name, obj in chain( - [(f.name, f.type) for f in fields(cls)] if is_dataclass(cls) else [], - *[ - base.__annotations__.items() - for base in cls.__mro__ - if hasattr(base, "__annotations__") - ], - [ - (name, f) - for name, f in vars(cls).items() - if isinstance(f, (rt_field, Field)) - ], - ): - if name.startswith("_"): - continue - print(f"{cls.__qualname__}.{name} = {obj}, {type(obj)}") + def _setup_fields(self, cls): + def all_vars(cls): + return {k: v for c in reversed(cls.__mro__) for k, v in vars(c).items()} + + def all_anno(cls): + return { + k: v + for c in reversed(cls.__mro__) + if hasattr(c, "__annotations__") + for k, v in c.__annotations__.items() + } + + LL_Types = (Node, GraphInterface) + + annos = all_anno(cls) + vars_ = all_vars(cls) + for name, obj in vars_.items(): + if isinstance(obj, _d_field): + obj.type = annos[name] + + def is_node_field(obj): + def is_genalias_node(obj): + origin = get_origin(obj) + assert origin is not None + + if issubclass(origin, LL_Types): + return True + + if issubclass(origin, (list, dict)): + arg = get_args(obj)[-1] + return is_node_field(arg) + + if isinstance(obj, LL_Types): + raise FieldError("Node instances not allowed") + + if isinstance(obj, str): + return obj in [L.__name__ for L in LL_Types] + + if isinstance(obj, type): + return issubclass(obj, LL_Types) + + if isinstance(obj, _d_field): + t = obj.type + if isinstance(t, type): + return issubclass(t, LL_Types) + + if get_origin(t): + return is_genalias_node(t) + + if get_origin(obj): + return is_genalias_node(obj) + + if isinstance(obj, rt_field): + return True + + return False + + clsfields_unf = { + name: obj + for name, obj in chain( + [(name, f) for name, f in annos.items()], + [(name, f) for name, f in vars_.items() if isinstance(f, f_field)], + ) + if not name.startswith("_") + } + + clsfields = { + name: obj for name, obj in clsfields_unf.items() if is_node_field(obj) + } + + # for name, obj in clsfields_unf.items(): + # if isinstance(obj, _d_field): + # obj = obj.type + # filtered = name not in clsfields + # filtered_str = " FILTERED" if filtered else "" + # print( + # f"{cls.__qualname__+"."+name+filtered_str:<60} = {str(obj):<70} " + # "| {type(obj)}" + # ) + + objects: dict[str, Node | GraphInterface] = {} + + def append(name, inst): + if isinstance(inst, LL_Types): + objects[name] = inst + elif isinstance(inst, list): + for i, obj in enumerate(inst): + objects[f"{name}[{i}]"] = obj + elif isinstance(inst, dict): + for k, obj in inst.items(): + objects[f"{name}[{k}]"] = obj + + return inst + + def setup_field(name, obj): + def setup_gen_alias(name, obj): + origin = get_origin(obj) + assert origin + if isinstance(origin, type): + setattr(self, name, append(name, origin())) + return + raise NotImplementedError(origin) + + if isinstance(obj, str): + raise NotImplementedError() + + if get_origin(obj): + setup_gen_alias(name, obj) + return + + if isinstance(obj, _d_field): + t = obj.type + + if isinstance(obj, _d_field): + inst = append(name, obj.default_factory()) + setattr(self, name, inst) + return + + if isinstance(t, type): + setattr(self, name, append(name, t())) + return - # NOTES: - # - first construct than call handle (for eliminating hazards) + if get_origin(t): + setup_gen_alias(name, t) + return + + raise NotImplementedError() + + if isinstance(obj, type): + setattr(self, name, append(name, obj())) + return + + if isinstance(obj, rt_field): + append(name, obj._construct(self)) + return + + raise NotImplementedError() + + for name, obj in clsfields.items(): + setup_field(name, obj) + + return objects, clsfields def __init__(self) -> None: - print("Called Node init", "-" * 20) + cls = type(self) + # print(f"Called Node init {cls.__qualname__:<20} {'-' * 80}") + + # check if accidentally added a node instance instead of field + node_instances = [f for f in vars(cls).values() if isinstance(f, Node)] + if node_instances: + raise FieldError(f"Node instances not allowed: {node_instances}") + + # Construct Fields + objects, _ = self._setup_fields(cls) + + # Add Fields to Node + for name, obj in sorted( + objects.items(), key=lambda x: isinstance(x[1], GraphInterfaceSelf) + ): + if isinstance(obj, GraphInterface): + self._handle_add_gif(name, obj) + elif isinstance(obj, Node): + self._handle_add_node(name, obj) + else: + assert False + + # Call 2-stage constructors if self._init: for base in reversed(type(self).mro()): - if hasattr(base, "__finit__"): - base.__finit__(self) + if hasattr(base, "__preinit__"): + base.__preinit__(self) + for base in reversed(type(self).mro()): + if hasattr(base, "__postinit__"): + base.__postinit__(self) def _handle_add_gif(self, name: str, gif: GraphInterface): gif.node = self gif.name = name - gif.connect(self.self_gif, linkcls=LinkSibling) + if not isinstance(gif, GraphInterfaceSelf): + gif.connect(self.self_gif, linkcls=LinkSibling) def _handle_add_node(self, name: str, node: "Node"): assert not ( diff --git a/src/faebryk/core/trait.py b/src/faebryk/core/trait.py index b88eff13..e629ffd3 100644 --- a/src/faebryk/core/trait.py +++ b/src/faebryk/core/trait.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -class Trait[T: Node]: +class Trait[T: Node](Node): @classmethod def impl(cls: type["Trait"]): class _Impl[T_: Node](TraitImpl[T_], cls): ... @@ -23,7 +23,7 @@ class _Impl[T_: Node](TraitImpl[T_], cls): ... U = TypeVar("U", bound="FaebrykLibObject") -class TraitImpl[U: Node](Node, ABC): +class TraitImpl[U: Node](ABC): _trait: type[Trait[U]] def __finit__(self) -> None: diff --git a/src/faebryk/exporters/pcb/kicad/transformer.py b/src/faebryk/exporters/pcb/kicad/transformer.py index 7a9b6862..cb313668 100644 --- a/src/faebryk/exporters/pcb/kicad/transformer.py +++ b/src/faebryk/exporters/pcb/kicad/transformer.py @@ -14,14 +14,15 @@ from shapely import Polygon from typing_extensions import deprecated -from faebryk.core.core import ( - Graph, - Module, - ModuleInterface.TraitT, - Module.TraitT, - Node, +from faebryk.core.graphinterface import Graph +from faebryk.core.module import Module +from faebryk.core.moduleinterface import ModuleInterface +from faebryk.core.node import Node +from faebryk.core.util import ( + get_all_nodes_with_trait, + get_all_nodes_with_traits, + get_children, ) -from faebryk.core.util import get_all_nodes_with_trait, get_all_nodes_with_traits from faebryk.library.Electrical import Electrical from faebryk.library.Footprint import ( Footprint as FFootprint, @@ -259,7 +260,7 @@ def attach(self): node.add_trait(self.has_linked_kicad_footprint_defined(fp, self)) pin_names = g_fp.get_trait(has_kicad_footprint).get_pin_names() - for fpad in g_fp.IFs.get_all(): + for fpad in get_children(g_fp, direct_only=True, types=ModuleInterface): pads = [ pad for pad in fp.pads diff --git a/src/faebryk/library/Diode.py b/src/faebryk/library/Diode.py index 75ae7144..b8650d65 100644 --- a/src/faebryk/library/Diode.py +++ b/src/faebryk/library/Diode.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT from faebryk.core.module import Module +from faebryk.core.parameter import Parameter from faebryk.core.util import as_unit from faebryk.library.can_bridge_defined import can_bridge_defined from faebryk.library.Electrical import Electrical diff --git a/src/faebryk/library/Footprint.py b/src/faebryk/library/Footprint.py index ed0454e2..20eb2449 100644 --- a/src/faebryk/library/Footprint.py +++ b/src/faebryk/library/Footprint.py @@ -2,10 +2,15 @@ # SPDX-License-Identifier: MIT -from faebryk.core.module import Module, ModuleInterface, Node +from faebryk.core.module import Module +from faebryk.core.moduleinterface import ModuleInterface +from faebryk.core.node import Node +from faebryk.core.trait import Trait class Footprint(Module): + class TraitT(Trait["Footprint"]): ... + def __init__(self) -> None: super().__init__() diff --git a/src/faebryk/library/FootprintTrait.py b/src/faebryk/library/FootprintTrait.py deleted file mode 100644 index c640c8cd..00000000 --- a/src/faebryk/library/FootprintTrait.py +++ /dev/null @@ -1,17 +0,0 @@ -# This file is part of the faebryk project -# SPDX-License-Identifier: MIT - -from typing import Generic, TypeVar - -from faebryk.core.core import ( - _Module.TraitT, -) -from faebryk.library.Footprint import Footprint - -TF = TypeVar("TF", bound="Footprint") - - -class _FootprintTrait(Generic[TF], _Module.TraitT[TF]): ... - - -class FootprintTrait(_FootprintTrait["Footprint"]): ... diff --git a/src/faebryk/library/has_linked_pad.py b/src/faebryk/library/has_linked_pad.py index 408bef5e..74fcb3a5 100644 --- a/src/faebryk/library/has_linked_pad.py +++ b/src/faebryk/library/has_linked_pad.py @@ -3,9 +3,7 @@ from abc import abstractmethod -from faebryk.core.core import ( - ModuleInterface.TraitT, -) +from faebryk.core.moduleinterface import ModuleInterface from faebryk.library.Pad import Pad diff --git a/src/faebryk/library/is_esphome_bus.py b/src/faebryk/library/is_esphome_bus.py index aca39db3..d42104a2 100644 --- a/src/faebryk/library/is_esphome_bus.py +++ b/src/faebryk/library/is_esphome_bus.py @@ -3,7 +3,7 @@ from abc import abstractmethod -from faebryk.core.moduleinterface import ModuleInterface, ModuleInterface.TraitT +from faebryk.core.moduleinterface import ModuleInterface from faebryk.libs.util import find diff --git a/src/faebryk/libs/app/checks.py b/src/faebryk/libs/app/checks.py index 76957107..44439aaf 100644 --- a/src/faebryk/libs/app/checks.py +++ b/src/faebryk/libs/app/checks.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: MIT -from faebryk.core.module import Module from faebryk.core.graph import Graph +from faebryk.core.module import Module from faebryk.libs.app.erc import simple_erc diff --git a/src/faebryk/libs/app/pcb.py b/src/faebryk/libs/app/pcb.py index 27880e46..a90e7b44 100644 --- a/src/faebryk/libs/app/pcb.py +++ b/src/faebryk/libs/app/pcb.py @@ -7,8 +7,8 @@ from pathlib import Path from typing import Any, Callable -from faebryk.core.module import Module from faebryk.core.graph import Graph +from faebryk.core.module import Module from faebryk.core.util import get_node_tree, iter_tree_by_depth from faebryk.exporters.pcb.kicad.transformer import PCB_Transformer from faebryk.exporters.pcb.routing.util import apply_route_in_pcb diff --git a/src/faebryk/libs/picker/picker.py b/src/faebryk/libs/picker/picker.py index 20691cbf..b7218045 100644 --- a/src/faebryk/libs/picker/picker.py +++ b/src/faebryk/libs/picker/picker.py @@ -12,7 +12,9 @@ from rich.progress import Progress -from faebryk.core.module import Module, ModuleInterface, Module.TraitT, Parameter +from faebryk.core.module import Module +from faebryk.core.moduleinterface import ModuleInterface +from faebryk.core.parameter import Parameter from faebryk.core.util import ( get_all_modules, get_children, diff --git a/test/core/test_parameters.py b/test/core/test_parameters.py index 253abae2..8c1c2977 100644 --- a/test/core/test_parameters.py +++ b/test/core/test_parameters.py @@ -6,8 +6,8 @@ from operator import add from typing import TypeVar -from faebryk.core.module import Module, Parameter from faebryk.core.core import logger as core_logger +from faebryk.core.module import Module, Parameter from faebryk.core.util import specialize_module from faebryk.library.ANY import ANY from faebryk.library.Constant import Constant