Skip to content
This repository has been archived by the owner on Dec 10, 2024. It is now read-only.

Commit

Permalink
Remove trait generic
Browse files Browse the repository at this point in the history
  • Loading branch information
iopapamanoglou committed Aug 23, 2024
1 parent 57516b6 commit d75d024
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 31 deletions.
28 changes: 17 additions & 11 deletions new_holders_flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@

from faebryk.core.module import Module
from faebryk.core.node import d_field, if_list, rt_field
from faebryk.core.util import as_unit
from faebryk.library.can_bridge_defined import can_bridge_defined
from faebryk.library.Electrical import Electrical
from faebryk.library.has_designator_prefix import has_designator_prefix
from faebryk.library.has_designator_prefix_defined import has_designator_prefix_defined
from faebryk.library.has_simple_value_representation import (
has_simple_value_representation,
)
from faebryk.library.has_simple_value_representation_based_on_param import (
has_simple_value_representation_based_on_param,
)
from faebryk.library.TBD import TBD
from faebryk.libs.units import Quantity
from faebryk.libs.units import P, Quantity
from faebryk.libs.util import times

# -----------------------------------------------------------------------------
Expand All @@ -25,10 +31,8 @@ class Diode2(Module):
anode: Electrical
cathode: Electrical

XXXXX: Electrical = d_field(Electrical)

# static trait
designator_prefix: has_designator_prefix = d_field(
designator_prefix: has_designator_prefix_defined = d_field(
lambda: has_designator_prefix_defined("D")
)

Expand All @@ -41,12 +45,12 @@ def __preinit__(self):
print("Called Diode __preinit__")

# anonymous dynamic trait
# self.add(
# has_simple_value_representation_based_on_param(
# self.forward_voltage,
# lambda p: as_unit(p, "V"),
# ) # type: ignore
# )
self.add(
has_simple_value_representation_based_on_param(
self.forward_voltage,
lambda p: as_unit(p, "V"),
)
)


class LED2(Diode2):
Expand Down Expand Up @@ -86,6 +90,8 @@ def main():
L3.cathode.connect(L2.cathode)

assert L3.cathode.is_connected_to(L2.cathode)
L3.forward_voltage.merge(5 * P.V)
L3.get_trait(has_simple_value_representation).get_value()


typer.run(main)
2 changes: 1 addition & 1 deletion src/faebryk/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class Module(Node):
class TraitT(Trait["Module"]): ...
class TraitT(Trait): ...

specializes: GraphInterface
specialized: GraphInterface
Expand Down
2 changes: 1 addition & 1 deletion src/faebryk/core/moduleinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class GraphInterfaceModuleConnection(GraphInterface): ...


class ModuleInterface(Node):
class TraitT(Trait["ModuleInterface"]): ...
class TraitT(Trait): ...

specializes: GraphInterface
specialized: GraphInterface
Expand Down
12 changes: 9 additions & 3 deletions src/faebryk/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def add(
container = self.runtime

try:
container_name = find(vars(self).items(), lambda x: x[1] == container)[0]
container_name = find(vars(self).items(), lambda x: x[1] is container)[0]
except KeyErrorNotFound:
raise FieldContainerError("Container not in fields")

Expand Down Expand Up @@ -296,6 +296,9 @@ def __init__(self) -> None:
if hasattr(base, "__postinit__"):
base.__postinit__(self)

def __preinit__(self): ...
def __postinit__(self): ...

def _handle_add_gif(self, name: str, gif: GraphInterface):
gif.node = self
gif.name = name
Expand Down Expand Up @@ -361,13 +364,14 @@ def add_trait[_TImpl: "TraitImpl"](self, trait: _TImpl) -> _TImpl:
return trait

def _find(self, trait, only_implemented: bool):
from faebryk.core.trait import TraitImpl
from faebryk.core.util import get_children

traits = get_children(self, direct_only=True, types=TraitImpl)
impls = get_children(self, direct_only=True, types=TraitImpl)

return [
impl
for impl in traits
for impl in impls
if impl.implements(trait)
and (impl.is_implemented() or not only_implemented)
]
Expand All @@ -385,6 +389,8 @@ def has_trait(self, trait) -> bool:
return len(self._find(trait, only_implemented=True)) > 0

def get_trait[V: "Trait"](self, trait: Type[V]) -> V:
from faebryk.core.trait import TraitImpl

assert not issubclass(
trait, TraitImpl
), "You need to specify the trait, not an impl"
Expand Down
2 changes: 1 addition & 1 deletion src/faebryk/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Parameter[PV](Node):
type LIT = PV | set[PV] | tuple[PV, PV]
type LIT_OR_PARAM = LIT | "Parameter[PV]"

class TraitT(Trait["Parameter"]): ...
class TraitT(Trait): ...

narrowed_by: GraphInterface
narrows: GraphInterface
Expand Down
23 changes: 9 additions & 14 deletions src/faebryk/core/trait.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,26 @@
# SPDX-License-Identifier: MIT
import logging
from abc import ABC
from typing import TypeVar

from deprecated import deprecated

from faebryk.core.core import FaebrykLibObject
from faebryk.core.node import Node

logger = logging.getLogger(__name__)


class Trait[T: Node](Node):
class Trait(Node):
@classmethod
def impl(cls: type["Trait"]):
class _Impl[T_: Node](TraitImpl[T_], cls): ...
def impl[T: "Trait"](cls: type[T]):
class _Impl(TraitImpl, cls): ...

return _Impl[T]
return _Impl


U = TypeVar("U", bound="FaebrykLibObject")
class TraitImpl(ABC):
_trait: type[Trait]


class TraitImpl[U: Node](ABC):
_trait: type[Trait[U]]

def __finit__(self) -> None:
def __preinit__(self) -> None:
found = False
bases = type(self).__bases__
while not found:
Expand Down Expand Up @@ -56,14 +51,14 @@ def remove_obj(self):
self._obj = None

@property
def obj(self) -> U:
def obj(self) -> Node:
p = self.get_parent()
if not p:
raise Exception("trait is not linked to node")
return p[0]

@deprecated("Use obj property")
def get_obj(self) -> U:
def get_obj(self) -> Node:
return self.obj

def cmp(self, other: "TraitImpl") -> tuple[bool, "TraitImpl"]:
Expand Down

0 comments on commit d75d024

Please sign in to comment.