diff --git a/src/faebryk/core/moduleinterface.py b/src/faebryk/core/moduleinterface.py index e2ad8858..9f950a82 100644 --- a/src/faebryk/core/moduleinterface.py +++ b/src/faebryk/core/moduleinterface.py @@ -131,7 +131,7 @@ def __init__(self): def __preinit__(self) -> None: ... @staticmethod - def _get_connected(gif: GraphInterface): + def _get_connected(gif: GraphInterface, clss: bool): assert isinstance(gif.node, ModuleInterface) connections = gif.edges.items() @@ -139,36 +139,29 @@ def _get_connected(gif: GraphInterface): assert len(connections) == len({c[0] for c in connections}) return { - cast_assert(ModuleInterface, s.node): link + cast_assert(ModuleInterface, s.node): (link if not clss else type(link)) for s, link in connections if s.node is not gif.node } - def get_connected(self): - return self._get_connected(self.connected) + def get_connected(self, clss: bool = False): + return self._get_connected(self.connected, clss) - def get_specialized(self): - return self._get_connected(self.specialized) + def get_specialized(self, clss: bool = False): + return self._get_connected(self.specialized, clss) - def get_specializes(self): - return self._get_connected(self.specializes) + def get_specializes(self, clss: bool = False): + return self._get_connected(self.specializes, clss) @staticmethod def _cross_connect( - s_group_: dict["ModuleInterface", type[Link] | Link], - d_group_: dict["ModuleInterface", type[Link] | Link], + s_group: dict["ModuleInterface", type[Link]], + d_group: dict["ModuleInterface", type[Link]], linkcls: type[Link], hint=None, ): if logger.isEnabledFor(logging.DEBUG) and hint is not None: - logger.debug(f"Connect {hint} {s_group_} -> {d_group_}") - - s_group: dict["ModuleInterface", type[Link]] = { - k: type(v) if not isinstance(v, type) else v for k, v in s_group_.items() - } - d_group: dict["ModuleInterface", type[Link]] = { - k: type(v) if not isinstance(v, type) else v for k, v in d_group_.items() - } + logger.debug(f"Connect {hint} {s_group} -> {d_group}") for s, slink in s_group.items(): linkclss = {slink, linkcls} @@ -203,13 +196,21 @@ def _connect_siblings_and_connections( logger.debug(f"MIF connection: {self} to {other}") # Connect to all connections - s_con = self.get_connected() | {self: linkcls} - d_con = other.get_connected() | {other: linkcls} + s_con = self.get_connected(clss=True) | {self: linkcls} + d_con = other.get_connected(clss=True) | {other: linkcls} ModuleInterface._cross_connect(s_con, d_con, linkcls, "connections") # Connect to all siblings - s_sib = self.get_specialized() | self.get_specializes() | {self: linkcls} - d_sib = other.get_specialized() | other.get_specializes() | {other: linkcls} + s_sib = ( + self.get_specialized(clss=True) + | self.get_specializes(clss=True) + | {self: linkcls} + ) + d_sib = ( + other.get_specialized(clss=True) + | other.get_specializes(clss=True) + | {other: linkcls} + ) ModuleInterface._cross_connect(s_sib, d_sib, linkcls, "siblings") return self