diff --git a/src/faebryk/core/defaultsolver.py b/src/faebryk/core/defaultsolver.py index 801faebe..b334aa61 100644 --- a/src/faebryk/core/defaultsolver.py +++ b/src/faebryk/core/defaultsolver.py @@ -1,6 +1,7 @@ # This file is part of the faebryk project # SPDX-License-Identifier: MIT +from collections import defaultdict import logging from collections.abc import Iterable from statistics import median @@ -29,27 +30,60 @@ from faebryk.core.solver import Solver from faebryk.libs.sets import Range, Ranges from faebryk.libs.units import Quantity, dimensionless -from faebryk.libs.util import EquivalenceClasses +from faebryk.libs.util import EquivalenceClasses, unique logger = logging.getLogger(__name__) -def parameter_alias_classes(G: Graph) -> list[set[Parameter]]: +def debug_print(repr_map: dict[ParameterOperatable, ParameterOperatable]): + import sys + + if getattr(sys, "gettrace", lambda: None)(): + log = print + else: + log = logger.info + for s, d in repr_map.items(): + if isinstance(d, Expression): + if isinstance(s, Expression): + log(f"{s}[{s.operands}] -> {d}[{d.operands} | G {d.get_graph()!r}]") + else: + log(f"{s} -> {d}[{d.operands} | G {d.get_graph()!r}]") + else: + log(f"{s} -> {d} | G {d.get_graph()!r}") + graphs = unique(map(lambda p: p.get_graph(), repr_map.values()), lambda g: g()) + log(f"{len(graphs)} graphs") + + +def parameter_ops_alias_classes( + G: Graph, +) -> dict[ParameterOperatable, set[ParameterOperatable]]: # TODO just get passed - params = [ + param_ops = { p - for p in G.nodes_of_type(Parameter) + for p in G.nodes_of_type(ParameterOperatable) if get_constrained_predicates_involved_in(p) - ] - full_eq = EquivalenceClasses[Parameter](params) + }.difference(G.nodes_of_type(Predicate)) + full_eq = EquivalenceClasses[ParameterOperatable](param_ops) is_exprs = [e for e in G.nodes_of_type(Is) if e.constrained] for is_expr in is_exprs: - params_ops = [op for op in is_expr.operands if isinstance(op, Parameter)] - full_eq.add_eq(*params_ops) + full_eq.add_eq(*is_expr.operands) + + obvious_eq = defaultdict(list) + for p in param_ops: + obvious_eq[p.obviously_eq_hash()].append(p) + logger.info(f"obvious eq: {obvious_eq}") - return full_eq.get() + for candidates in obvious_eq.values(): + if len(candidates) > 1: + logger.debug(f"#obvious eq candidates: {len(candidates)}") + for i, p in enumerate(candidates): + for q in candidates[:i]: + if p.obviously_eq(q): + full_eq.add_eq(p, q) + break + return full_eq.classes def get_params_for_expr(expr: Expression) -> set[Parameter]: @@ -60,7 +94,7 @@ def get_params_for_expr(expr: Expression) -> set[Parameter]: def get_constrained_predicates_involved_in( - p: Parameter | Expression, + p: ParameterOperatable, ) -> set[Predicate]: # p.self -> p.operated_on -> e1.operates_on -> e1.self dependants = p.bfs_node( @@ -108,6 +142,9 @@ def create_new_expr( old_expr: Expression, *operands: ParameterOperatable.All ) -> Expression: new_expr = type(old_expr)(*operands) + for op in operands: + if isinstance(op, ParameterOperatable): + assert op.get_graph() == new_expr.get_graph() if isinstance(old_expr, Constrainable): cast(Constrainable, new_expr).constrained = old_expr.constrained return new_expr @@ -206,9 +243,9 @@ def resolve_alias_classes( G: Graph, ) -> tuple[dict[ParameterOperatable, ParameterOperatable], bool]: dirty = False - params = [ + params_ops = [ p - for p in G.nodes_of_type(Parameter) + for p in G.nodes_of_type(ParameterOperatable) if get_constrained_predicates_involved_in(p) ] exprs = G.nodes_of_type(Expression) @@ -216,11 +253,11 @@ def resolve_alias_classes( exprs.difference_update(predicates) exprs = {e for e in exprs if get_constrained_predicates_involved_in(e)} - p_alias_classes = parameter_alias_classes(G) + p_alias_classes = parameter_ops_alias_classes(G) dependency_classes = parameter_dependency_classes(G) infostr = ( - f"{len(params)} parameters" + f"{len(params_ops)} parametersoperable" f"\n {len(p_alias_classes)} alias classes" f"\n {len(dependency_classes)} dependency classes" "\n" @@ -230,61 +267,88 @@ def resolve_alias_classes( repr_map: dict[ParameterOperatable, ParameterOperatable] = {} # Make new param repre for alias classes - for alias_class in p_alias_classes: - # TODO short-cut if len() == 1 - if len(alias_class) > 1: - dirty = True + for param_op in ParameterOperatable.sort_by_depth(params_ops, ascending=True): + if param_op in repr_map or param_op not in p_alias_classes: + continue + + alias_class = p_alias_classes[param_op] + + # TODO short-cut if len() == 1 ? + param_alias_class = [p for p in alias_class if isinstance(p, Parameter)] + expr_alias_class = [p for p in alias_class if isinstance(p, Expression)] + # TODO non unit/numeric params, i.e. enums, bools # single unit unit_candidates = {p.units for p in alias_class} if len(unit_candidates) > 1: raise ValueError("Incompatible units in alias class") + if len(param_alias_class) > 0: + dirty |= len(param_alias_class) > 1 + + # single domain + domain_candidates = {p.domain for p in param_alias_class} + if len(domain_candidates) > 1: + raise ValueError("Incompatible domains in alias class") + + # intersect ranges + within_ranges = { + p.within for p in param_alias_class if p.within is not None + } + within = None + if within_ranges: + within = Ranges.op_intersect_ranges(*within_ranges) + + # heuristic: + # intersect soft sets + soft_sets = { + p.soft_set for p in param_alias_class if p.soft_set is not None + } + soft_set = None + if soft_sets: + soft_set = Ranges.op_intersect_ranges(*soft_sets) + + # heuristic: + # get median + guesses = {p.guess for p in param_alias_class if p.guess is not None} + guess = None + if guesses: + guess = median(guesses) # type: ignore + + # heuristic: + # max tolerance guess + tolerance_guesses = { + p.tolerance_guess + for p in param_alias_class + if p.tolerance_guess is not None + } + tolerance_guess = None + if tolerance_guesses: + tolerance_guess = max(tolerance_guesses) + + likely_constrained = any(p.likely_constrained for p in param_alias_class) + + representative = Parameter( + units=unit_candidates.pop(), + within=within, + soft_set=soft_set, + guess=guess, + tolerance_guess=tolerance_guess, + likely_constrained=likely_constrained, + ) + repr_map.update({p: representative for p in param_alias_class}) + elif len(expr_alias_class) > 1: + dirty = True + representative = Parameter(units=unit_candidates.pop()) - # single domain - domain_candidates = {p.domain for p in alias_class} - if len(domain_candidates) > 1: - raise ValueError("Incompatible domains in alias class") - - # intersect ranges - within_ranges = {p.within for p in alias_class if p.within is not None} - within = None - if within_ranges: - within = Ranges.op_intersect_ranges(*within_ranges) - - # heuristic: - # intersect soft sets - soft_sets = {p.soft_set for p in alias_class if p.soft_set is not None} - soft_set = None - if soft_sets: - soft_set = Ranges.op_intersect_ranges(*soft_sets) - - # heuristic: - # get median - guesses = {p.guess for p in alias_class if p.guess is not None} - guess = None - if guesses: - guess = median(guesses) # type: ignore - - # heuristic: - # max tolerance guess - tolerance_guesses = { - p.tolerance_guess for p in alias_class if p.tolerance_guess is not None - } - tolerance_guess = None - if tolerance_guesses: - tolerance_guess = max(tolerance_guesses) - - likely_constrained = any(p.likely_constrained for p in alias_class) - - representative = Parameter( - units=unit_candidates.pop(), - within=within, - soft_set=soft_set, - guess=guess, - tolerance_guess=tolerance_guess, - likely_constrained=likely_constrained, - ) - repr_map.update({p: representative for p in alias_class}) + if len(expr_alias_class) > 0: + for e in expr_alias_class: + copy_expr = copy_operand_recursively(e, repr_map) + repr_map[e] = ( + representative # copy_expr TODO make sure this makes sense + ) + # TODO, if it doesn't have implicit constraints and it's operands don't aren't constraint, we can get rid of it + assert isinstance(copy_expr, Constrainable) + copy_expr.alias_is(representative) # replace parameters in expressions and predicates for expr in cast( @@ -301,12 +365,13 @@ def try_replace(o: ParameterOperatable.All): # filter alias class Is if isinstance(expr, Is): - if all(isinstance(o, Parameter) for o in expr.operands): - continue + continue - operands = [try_replace(o) for o in expr.operands] - new_expr = create_new_expr(expr, *operands) - repr_map[expr] = new_expr + assert all( + o in repr_map or not isinstance(o, ParameterOperatable) + for o in expr.operands + ) + repr_map[expr] = copy_operand_recursively(expr, repr_map) return repr_map, dirty @@ -713,6 +778,8 @@ def compress_arithmetic_expressions( }, dirty +# TODO move to expression? +# TODO recursive? def has_implicit_constraint(po: ParameterOperatable) -> bool: if isinstance(po, Parameter | Add | Subtract | Multiply | Power): # TODO others return False @@ -761,20 +828,6 @@ class DefaultSolver(Solver): timeout: int = 1000 def phase_one_no_guess_solving(self, g: Graph) -> None: - def debug_print(repr_map: dict[ParameterOperatable, ParameterOperatable]): - for s, d in repr_map.items(): - if isinstance(d, Expression): - if isinstance(s, Expression): - logger.info( - f"{s}[{s.operands}] -> {d}[{d.operands} | G {d.get_graph()!r}]" - ) - else: - logger.info(f"{s} -> {d}[{d.operands} | G {d.get_graph()!r}]") - else: - logger.info(f"{s} -> {d} | G {d.get_graph()!r}") - graphs = {p.get_graph() for p in repr_map.values()} - logger.info(f"{len(graphs)} graphs") - logger.info(f"Phase 1 Solving: No guesses {'-' * 80}") # strategies @@ -803,13 +856,13 @@ def debug_print(repr_map: dict[ParameterOperatable, ParameterOperatable]): logger.info("Phase 0 Solving: normalize graph") repr_map = normalize_graph(g) debug_print(repr_map) - graphs = {p.get_graph() for p in repr_map.values()} + graphs = unique(map(lambda p: p.get_graph(), repr_map.values()), lambda g: g()) # TODO assert all new graphs dirty = True iter = 0 - while dirty: + while dirty and len(graphs) > 0: iter += 1 logger.info(f"Iteration {iter}") logger.info("Phase 1 Solving: Alias classes") @@ -818,7 +871,9 @@ def debug_print(repr_map: dict[ParameterOperatable, ParameterOperatable]): alias_repr_map, alias_dirty = resolve_alias_classes(g) repr_map.update(alias_repr_map) debug_print(repr_map) - graphs = {p.get_graph() for p in repr_map.values()} + graphs = unique( + map(lambda p: p.get_graph(), repr_map.values()), lambda g: g() + ) # TODO assert all new graphs logger.info("Phase 2a Solving: Add/Mul associative expressions") @@ -829,7 +884,22 @@ def debug_print(repr_map: dict[ParameterOperatable, ParameterOperatable]): ) repr_map.update(assoc_add_mul_repr_map) debug_print(repr_map) - graphs = {p.get_graph() for p in repr_map.values()} + graphs = unique( + map(lambda p: p.get_graph(), repr_map.values()), lambda g: g() + ) + # TODO assert all new graphs + + logger.info("Phase 2a Solving: Add/Mul associative expressions") + repr_map = {} + for g in graphs: + assoc_add_mul_repr_map, assoc_add_mul_dirty = ( + compress_associative_add_mul(g) + ) + repr_map.update(assoc_add_mul_repr_map) + debug_print(repr_map) + graphs = unique( + map(lambda p: p.get_graph(), repr_map.values()), lambda g: g() + ) # TODO assert all new graphs logger.info("Phase 2b Solving: Subtract associative expressions") @@ -838,7 +908,9 @@ def debug_print(repr_map: dict[ParameterOperatable, ParameterOperatable]): assoc_sub_repr_map, assoc_sub_dirty = compress_associative_sub(g) repr_map.update(assoc_sub_repr_map) debug_print(repr_map) - graphs = {p.get_graph() for p in repr_map.values()} + graphs = unique( + map(lambda p: p.get_graph(), repr_map.values()), lambda g: g() + ) # TODO assert all new graphs logger.info("Phase 3 Solving: Arithmetic expressions") @@ -847,7 +919,9 @@ def debug_print(repr_map: dict[ParameterOperatable, ParameterOperatable]): arith_repr_map, arith_dirty = compress_arithmetic_expressions(g) repr_map.update(arith_repr_map) debug_print(repr_map) - graphs = {p.get_graph() for p in repr_map.values()} + graphs = unique( + map(lambda p: p.get_graph(), repr_map.values()), lambda g: g() + ) # TODO assert all new graphs logger.info("Phase 4 Solving: Remove obvious tautologies") @@ -856,7 +930,9 @@ def debug_print(repr_map: dict[ParameterOperatable, ParameterOperatable]): tautology_repr_map, tautology_dirty = remove_obvious_tautologies(g) repr_map.update(tautology_repr_map) debug_print(repr_map) - graphs = {p.get_graph() for p in repr_map.values()} + graphs = unique( + map(lambda p: p.get_graph(), repr_map.values()), lambda g: g() + ) # TODO assert all new graphs logger.info("Phase 5 Solving: Subset of literals") @@ -865,7 +941,9 @@ def debug_print(repr_map: dict[ParameterOperatable, ParameterOperatable]): subset_repr_map, subset_dirty = subset_of_literal(g) repr_map.update(subset_repr_map) debug_print(repr_map) - graphs = {p.get_graph() for p in repr_map.values()} + graphs = unique( + map(lambda p: p.get_graph(), repr_map.values()), lambda g: g() + ) # TODO assert all new graphs dirty = ( diff --git a/src/faebryk/core/parameter.py b/src/faebryk/core/parameter.py index 3f0a9a0a..f43c42ef 100644 --- a/src/faebryk/core/parameter.py +++ b/src/faebryk/core/parameter.py @@ -5,7 +5,7 @@ from collections.abc import Iterable from enum import Enum, auto from types import NotImplementedType -from typing import Any, Callable, Self +from typing import Any, Callable, Self, override from faebryk.core.core import Namespace from faebryk.core.graphinterface import GraphInterface @@ -54,6 +54,36 @@ def key(e: ParameterOperatable): return sorted(exprs, key=key, reverse=not ascending) + def _is_constrains(self) -> list["Is"]: + return [ + i for i in self.operated_on.get_connected_nodes(types=Is) if i.constrained + ] + + def obviously_eq(self, other: "ParameterOperatable.All") -> bool: + if self == other: + return True + if other in self._is_constrains(): + return True + return False + + def obviously_eq_hash(self) -> int: + if hasattr(self, "__hash"): + return self.__hash + + ises = [i for i in self._is_constrains() if not isinstance(i, Expression)] + + def keyfn(i: Is): + if isinstance(i, Parameter): + return 1 << 63 + return hash(i) % (1 << 63) + + sorted_ises = sorted(ises, key=keyfn) + if len(sorted_ises) > 0: + self.__hash = hash(sorted_ises[0]) + else: + self.__hash = id(self) + return self.__hash + def operation_add(self, other: NumberLike): return Add(self, other) @@ -283,6 +313,18 @@ def if_then_else( # ) -> None: ... +def obviously_eq(a: ParameterOperatable.All, b: ParameterOperatable.All) -> bool: + if a == b: + return True + if isinstance(a, ParameterOperatable): + return a.obviously_eq(b) + elif isinstance(b, ParameterOperatable): + return b.obviously_eq(a) + return False + + +# TODO mixes two things, those that a constraining predicate can be called on, +# and the predicate, which can have it's constrained be set?? class Constrainable: type All = ParameterOperatable.All type Sets = ParameterOperatable.Sets @@ -339,7 +381,7 @@ class Expression(ParameterOperatable): def __init__(self, *operands: ParameterOperatable.All): super().__init__() - self.operands = operands + self.operands = tuple(operands) self.operatable_operands = { op for op in operands if isinstance(op, ParameterOperatable) } @@ -359,7 +401,24 @@ def depth(self) -> int: ) return self._depth + # TODO caching + @override + def obviously_eq(self, other: ParameterOperatable.All) -> bool: + if super().obviously_eq(other): + return True + if type(self) is type(other): + for s, o in zip(self.operands, other.operands): + if not obviously_eq(s, o): + return False + return True + return False + def obviously_eq_hash(self) -> int: + return hash((type(self), self.operands)) + + +# TODO are any expressions not constrainable? +# parameters are contstrainable, too, so all parameter-operatables are constrainable? @abstract class ConstrainableExpression(Expression, Constrainable): def __init__(self, *operands: ParameterOperatable.All): @@ -382,7 +441,6 @@ def __init__(self, *operands: ParameterOperatable.NumberLike): if isinstance(param, Parameter) ): raise ValueError("parameters must have domain Numbers or ESeries") - self.operands = operands @abstract @@ -395,10 +453,33 @@ def __init__(self, *operands): raise ValueError("All operands must have compatible units") +def _associative_obviously_eq(self: Expression, other: Expression) -> bool: + remaining = list(other.operands) + for op in self.operands: + for r in remaining: + if obviously_eq(op, r): + remaining.remove(r) + break + return not remaining + + class Add(Additive): def __init__(self, *operands): super().__init__(*operands) + # TODO caching + @override + def obviously_eq(self, other: ParameterOperatable.All) -> bool: + if ParameterOperatable.obviously_eq(self, other): + return True + if isinstance(other, Add): + return _associative_obviously_eq(self, other) + return False + + def obviously_eq_hash(self) -> int: + op_hash = sum(hash(op) for op in self.operands) + return hash((type(self), op_hash)) + class Subtract(Additive): def __init__(self, minuend, subtrahend): @@ -413,6 +494,19 @@ def __init__(self, *operands): for u in units[1:]: self.units = cast_assert(Unit, self.units * u) + # TODO caching + @override + def obviously_eq(self, other: ParameterOperatable.All) -> bool: + if ParameterOperatable.obviously_eq(self, other): + return True + if isinstance(other, Add): + return _associative_obviously_eq(self, other) + return False + + def obviously_eq_hash(self) -> int: + op_hash = sum(hash(op) for op in self.operands) + return hash((type(self), op_hash)) + class Divide(Arithmetic): def __init__(self, numerator, denominator): @@ -498,7 +592,6 @@ def __init__(self, *operands): if isinstance(param, Parameter) ): raise ValueError("parameters must have domain Boolean without a unit") - self.operands = operands class And(Logic): @@ -607,7 +700,6 @@ def __init__(self, left, right): r_units = HasUnit.get_units_or_dimensionless(right) if not l_units.is_compatible_with(r_units): raise ValueError("operands must have compatible units") - self.operands = [left, right] def __bool__(self): raise ValueError("Predicate cannot be converted to bool") diff --git a/test/core/test_parameters.py b/test/core/test_parameters.py index f8cd3f80..a7de2aa2 100644 --- a/test/core/test_parameters.py +++ b/test/core/test_parameters.py @@ -111,6 +111,21 @@ def test_subset_of_literal(): solver.phase_one_no_guess_solving(G) +def test_alias_classes(): + p0, p1, p2, p3, p4 = ( + Parameter(units=dimensionless, within=Range(0, i)) for i in range(5) + ) + p0.alias_is(p1) + addition = p2 + p3 + p1.alias_is(addition) + addition2 = p3 + p2 + p4.alias_is(addition2) + + G = p0.get_graph() + solver = DefaultSolver() + solver.phase_one_no_guess_solving(G) + + def test_solve_realworld(): app = F.RP2040() solver = DefaultSolver() @@ -173,7 +188,7 @@ def test_visualize_inspect_app(): # if run in jupyter notebook import sys - func = test_subset_of_literal + func = test_solve_realworld if "ipykernel" in sys.modules: func()