From f0e34d5e03895bd5e20198c2f7436e028ab8b965 Mon Sep 17 00:00:00 2001 From: NoR8quoh1r <20768237+NoR8quoh1r@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:38:12 +0100 Subject: [PATCH] expression simplification --- src/faebryk/core/defaultsolver.py | 384 +++++++++++++++++++++++------- src/faebryk/core/solver.py | 2 + src/faebryk/libs/sets.py | 4 +- test/core/test_parameters.py | 13 +- 4 files changed, 306 insertions(+), 97 deletions(-) diff --git a/src/faebryk/core/defaultsolver.py b/src/faebryk/core/defaultsolver.py index f2463e11..80e8c253 100644 --- a/src/faebryk/core/defaultsolver.py +++ b/src/faebryk/core/defaultsolver.py @@ -1,8 +1,8 @@ # This file is part of the faebryk project # SPDX-License-Identifier: MIT -from collections.abc import Iterable import logging +from collections.abc import Iterable from statistics import median from typing import Any, cast @@ -13,6 +13,7 @@ Add, Arithmetic, Constrainable, + Divide, Expression, Is, Multiply, @@ -24,6 +25,7 @@ ) from faebryk.core.solver import Solver from faebryk.libs.sets import Ranges +from faebryk.libs.units import dimensionless from faebryk.libs.util import EquivalenceClasses logger = logging.getLogger(__name__) @@ -131,7 +133,6 @@ def resolve_alias_classes( f"\n {len(dependency_classes)} dependency classes" "\n" ) - logger.info("Phase 1 Solving: Alias classes") logger.info(infostr) repr_map: dict[ParameterOperatable, ParameterOperatable] = {} @@ -213,7 +214,6 @@ def try_replace(o: ParameterOperatable.All): operands = [try_replace(o) for o in expr.operands] new_expr = create_new_expr(expr, *operands) - logger.info(f"{expr}[{expr.operands}] ->\n {new_expr}[{new_expr.operands}]") repr_map[expr] = new_expr return repr_map, dirty @@ -231,26 +231,42 @@ def copy_param(p: Parameter) -> Parameter: ) -def copy_pop( +def copy_operand_recursively( o: ParameterOperatable.All, repr_map: dict[ParameterOperatable, ParameterOperatable] ) -> ParameterOperatable.All: if o in repr_map: return repr_map[o] if isinstance(o, Expression): - return create_new_expr( - o, - *( - repr_map[op] if op in repr_map else copy_pop(op, repr_map) - for op in o.operands - ), - ) + new_ops = [] + for op in o.operands: + new_op = copy_operand_recursively(op, repr_map) + if isinstance(op, ParameterOperatable): + repr_map[op] = new_op + new_ops.append(new_op) + expr = create_new_expr(o, *new_ops) + repr_map[o] = expr + return expr elif isinstance(o, Parameter): - return copy_param(o) + param = copy_param(o) + repr_map[o] = param + return param else: return o -def compress_associative_expressions( +def is_replacable( + repr_map: dict[ParameterOperatable, ParameterOperatable], + e: Expression, + parent_expr: Expression, +) -> bool: + if e in repr_map: # overly restrictive: equivalent replacement would be ok + return False + if e.get_operations() != {parent_expr}: + return False + return True + + +def compress_associative_add_mul( G: Graph, ) -> tuple[dict[ParameterOperatable, ParameterOperatable], bool]: dirty = False @@ -261,6 +277,7 @@ def compress_associative_expressions( } repr_map: dict[ParameterOperatable, ParameterOperatable] = {} + removed = set() # (A + B) + C # X -> Y @@ -273,16 +290,18 @@ def flatten_operands_of_ops_with_same_type[T: Add | Multiply]( ) -> tuple[list[T], bool]: dirty = False operands = e.operands - noncomp, compressable = partition(lambda o: type(o) is type(e), operands) + noncomp, compressible = partition( + lambda o: type(o) is type(e) and is_replacable(repr_map, o, e), operands + ) out = [] - for c in compressable: + for c in compressible: dirty = True - if c in repr_map: - out.append(repr_map[c]) - else: - sub_out, sub_dirty = flatten_operands_of_ops_with_same_type(c) - dirty |= sub_dirty - out += sub_out + removed.add(c) + sub_out, sub_dirty = flatten_operands_of_ops_with_same_type(c) + dirty |= sub_dirty + out += sub_out + if len(out) > 0: + logger.info(f"FLATTENED {type(e).__name__} {e} -> {out}") return out + list(noncomp), dirty for expr in cast( @@ -290,33 +309,122 @@ def flatten_operands_of_ops_with_same_type[T: Add | Multiply]( ParameterOperatable.sort_by_depth(parent_add_muls, ascending=True), ): operands, sub_dirty = flatten_operands_of_ops_with_same_type(expr) - dirty |= sub_dirty - # copy - for o in operands: - if isinstance(o, ParameterOperatable): - repr_map[o] = copy_pop(o, repr_map) - - # make new compressed expr with (copied) operands - new_expr = create_new_expr( - expr, - *( - repr_map[o] if o in repr_map else copy_pop(o, repr_map) - for o in operands - ), + if sub_dirty: + dirty = True + copy_operands = [copy_operand_recursively(o, repr_map) for o in operands] + + new_expr = create_new_expr( + expr, + *copy_operands, + ) + repr_map[expr] = new_expr + + # copy other param ops + other_param_op = ParameterOperatable.sort_by_depth( + ( + p + for p in G.nodes_of_type(ParameterOperatable) + if p not in repr_map and p not in removed + ), + ascending=True, + ) + for o in other_param_op: + copy_operand_recursively(o, repr_map) + + return repr_map, dirty + + +def compress_associative_sub( + G: Graph, +) -> tuple[dict[ParameterOperatable, ParameterOperatable], bool]: + logger.info("Compressing Subtracts") + dirty = False + subs = cast(set[Subtract], G.nodes_of_type(Subtract)) + # get out deepest expr in compressable tree + parent_subs = { + e for e in subs if type(e) not in {type(n) for n in e.get_operations()} + } + + removed = set() + repr_map: dict[ParameterOperatable, ParameterOperatable] = {} + + def flatten_sub( + e: Subtract, + ) -> tuple[ + ParameterOperatable.All, + list[ParameterOperatable.All], + list[ParameterOperatable.All], + bool, + ]: + const_subtrahend = ( + [] if isinstance(e.operands[1], ParameterOperatable) else [e.operands[1]] ) - repr_map[expr] = new_expr + nonconst_subtrahend = [] if const_subtrahend else [e.operands[1]] + if isinstance(e.operands[0], Subtract) and is_replacable( + repr_map, e.operands[0], e + ): + removed.add(e.operands[0]) + minuend, const_subtrahends, nonconst_subtrahends, _ = flatten_sub( + e.operands[0] + ) + return ( + minuend, + const_subtrahends + const_subtrahend, + nonconst_subtrahends + nonconst_subtrahend, + True, + ) + else: + return e.operands[0], const_subtrahend, nonconst_subtrahend, False + + for expr in cast( + Iterable[Subtract], + ParameterOperatable.sort_by_depth(parent_subs, ascending=True), + ): + minuend, const_subtrahends, nonconst_subtrahends, sub_dirty = flatten_sub(expr) + if ( + isinstance(minuend, Add) + and is_replacable(repr_map, minuend, expr) + and len(const_subtrahends) > 0 + ): + copy_minuend = Add( + *(copy_operand_recursively(s, repr_map) for s in minuend.operands), + *(-1 * c for c in const_subtrahends), + ) + repr_map[expr] = copy_minuend + const_subtrahends = [] + sub_dirty = True + elif sub_dirty: + copy_minuend = copy_operand_recursively(minuend, repr_map) + if sub_dirty: + dirty = True + copy_subtrahends = [ + copy_operand_recursively(s, repr_map) + for s in nonconst_subtrahends + const_subtrahends + ] + if len(copy_subtrahends) > 0: + new_expr = Subtract( + copy_minuend, + Add(*copy_subtrahends), + ) + else: + new_expr = copy_minuend + removed.add(expr) + repr_map[expr] = new_expr + logger.info(f"REPRMAP {expr} -> {new_expr}") # copy other param ops other_param_op = ParameterOperatable.sort_by_depth( ( p for p in G.nodes_of_type(ParameterOperatable) - if p not in repr_map and p not in add_muls + if p not in repr_map and p not in removed ), ascending=True, ) - remaining_param_op = {p: copy_pop(p, repr_map) for p in other_param_op} - repr_map.update(remaining_param_op) + for o in other_param_op: + copy_o = copy_operand_recursively(o, repr_map) + logger.info(f"REMAINING {o} -> {copy_o}") + repr_map[o] = copy_o return repr_map, dirty @@ -328,20 +436,25 @@ def compress_arithmetic_expressions( arith_exprs = cast(set[Arithmetic], G.nodes_of_type(Arithmetic)) repr_map: dict[ParameterOperatable, ParameterOperatable] = {} + removed = set() for expr in cast( Iterable[Arithmetic], ParameterOperatable.sort_by_depth(arith_exprs, ascending=True), ): + if expr in repr_map or expr in removed: + continue + operands = expr.operands const_ops, nonconst_ops = partition( lambda o: isinstance(o, ParameterOperatable), operands ) + non_replacable_nonconst_ops, replacable_nonconst_ops = partition( + lambda o: o not in repr_map, nonconst_ops + ) multiplicity = {} - has_multiplicity = False - for n in nonconst_ops: + for n in replacable_nonconst_ops: if n in multiplicity: - has_multiplicity = True multiplicity[n] += 1 else: multiplicity[n] = 1 @@ -357,19 +470,32 @@ def compress_arithmetic_expressions( const_sum = [] except StopIteration: const_sum = [] - nonconst_prod = { - n: Multiply(n, m) if m > 1 else copy_pop(n, repr_map) - for n, m in multiplicity.items() - } - new_operands = (*nonconst_prod.values(), *const_sum) - if len(new_operands) > 1: - new_expr = Add(*new_operands) - elif len(new_operands) == 1: - new_expr = new_operands[0] - else: - raise ValueError("No operands, should not happen") - repr_map.update(nonconst_prod) - repr_map[expr] = new_expr + if any(m > 1 for m in multiplicity.values()): + dirty = True + if dirty: + copied = { + n: copy_operand_recursively(n, repr_map) for n in multiplicity + } + nonconst_prod = [ + Multiply(copied[n], m) if m > 1 else copied[n] + for n, m in multiplicity.items() + ] + new_operands = [ + *nonconst_prod, + *const_sum, + *( + copy_operand_recursively(o, repr_map) + for o in non_replacable_nonconst_ops + ), + ] + if len(new_operands) > 1: + new_expr = Add(*new_operands) + elif len(new_operands) == 1: + new_expr = new_operands[0] + removed.add(expr) + else: + raise ValueError("No operands, should not happen") + repr_map[expr] = new_expr elif isinstance(expr, Multiply): try: @@ -377,54 +503,100 @@ def compress_arithmetic_expressions( for c in const_ops: dirty = True const_prod[0] *= c - if const_prod[0] == 1 * expr.units: # TODO make work with all the types + if ( + const_prod[0] == 1 * dimensionless + ): # TODO make work with all the types dirty = True const_prod = [] except StopIteration: const_prod = [] if ( - len(const_prod) == 1 and const_prod[0] == 0 * expr.units + len(const_prod) == 1 and const_prod[0].magnitude == 0 ): # TODO make work with all the types dirty = True repr_map[expr] = 0 * expr.units else: - nonconst_prod = { - n: Power(n, m) if m > 1 else copy_pop(n, repr_map) - for n, m in multiplicity.items() - } - if has_multiplicity: + if any(m > 1 for m in multiplicity.values()): dirty = True - new_operands = (*nonconst_prod.values(), *const_prod) - if len(new_operands) > 1: - new_expr = Multiply(*new_operands) - elif len(new_operands) == 1: - new_expr = new_operands[0] - else: - raise ValueError("No operands, should not happen") - repr_map.update(nonconst_prod) - repr_map[expr] = new_expr + if dirty: + copied = { + n: copy_operand_recursively(n, repr_map) for n in multiplicity + } + nonconst_power = [ + Power(copied[n], m) if m > 1 else copied[n] + for n, m in multiplicity.items() + ] + new_operands = [ + *nonconst_power, + *const_prod, + *( + copy_operand_recursively(o, repr_map) + for o in non_replacable_nonconst_ops + ), + ] + if len(new_operands) > 1: + new_expr = Multiply(*new_operands) + elif len(new_operands) == 1: + new_expr = new_operands[0] + removed.add(expr) + else: + raise ValueError("No operands, should not happen") + repr_map[expr] = new_expr elif isinstance(expr, Subtract): - if expr.operands[0] is expr.operands[1]: + if sum(1 for _ in const_ops) == 2: + dirty = True + repr_map[expr] = expr.operands[0] - expr.operands[1] + removed.add(expr) + elif expr.operands[0] is expr.operands[1]: dirty = True repr_map[expr] = 0 * expr.units - elif len(const_ops) == 2: + removed.add(expr) + elif expr.operands[1] == 0 * expr.operands[1].units: dirty = True - repr_map[expr] = expr.operands[0] - expr.operands[1] + repr_map[expr.operands[0]] = repr_map.get( + expr.operands[0], + copy_operand_recursively(expr.operands[0], repr_map), + ) + repr_map[expr] = repr_map[expr.operands[0]] + removed.add(expr) else: - repr_map[expr] = copy_pop(expr, repr_map) + repr_map[expr] = copy_operand_recursively(expr, repr_map) + elif isinstance(expr, Divide): + if sum(1 for _ in const_ops) == 2: + if not expr.operands[1].magnitude == 0: + dirty = True + repr_map[expr] = expr.operands[0] / expr.operands[1] + removed.add(expr) + else: + # no valid solution but might not matter e.g. [phi(a,b,...) OR a/0 == b] + repr_map[expr] = copy_operand_recursively(expr, repr_map) + elif expr.operands[1] is expr.operands[0]: + dirty = True + repr_map[expr] = 1 * dimensionless + removed.add(expr) + elif expr.operands[1] == 1 * expr.operands[1].units: + dirty = True + repr_map[expr.operands[0]] = repr_map.get( + expr.operands[0], + copy_operand_recursively(expr.operands[0], repr_map), + ) + repr_map[expr] = repr_map[expr.operands[0]] + removed.add(expr) + else: + repr_map[expr] = copy_operand_recursively(expr, repr_map) else: - repr_map[expr] = copy_pop(expr, repr_map) + repr_map[expr] = copy_operand_recursively(expr, repr_map) other_param_op = ParameterOperatable.sort_by_depth( ( p for p in G.nodes_of_type(ParameterOperatable) - if p not in repr_map and p not in arith_exprs + if p not in repr_map and p not in removed ), ascending=True, ) - remaining_param_op = {p: copy_pop(p, repr_map) for p in other_param_op} - repr_map.update(remaining_param_op) + for o in other_param_op: + copy_operand_recursively(o, repr_map) return { k: v for k, v in repr_map.items() if isinstance(v, ParameterOperatable) @@ -467,45 +639,79 @@ def phase_one_no_guess_solving(self, g: Graph) -> None: while dirty: iter += 1 logger.info(f"Iteration {iter}") + logger.info("Phase 1 Solving: Alias classes") repr_map = {} for g in graphs: alias_repr_map, alias_dirty = resolve_alias_classes(g) repr_map.update(alias_repr_map) + for s, d in repr_map.items(): + if isinstance(d, Expression): + if isinstance(s, Expression): + logger.info(f"{s}[{s.operands}] -> {d}[{d.operands}]") + else: + logger.info(f"{s} -> {d}[{d.operands}]") + else: + logger.info(f"{s} -> {d}") graphs = {p.get_graph() for p in repr_map.values()} - for g in graphs: - logger.info(f"G: {g}") logger.info(f"{len(graphs)} new graphs") # TODO assert all new graphs - logger.info("Phase 2 Solving: Associative expressions") + logger.info("Phase 2a Solving: Add/Mul associative expressions") repr_map = {} for g in graphs: - assoc_repr_map, assoc_dirty = compress_associative_expressions(g) - repr_map.update(assoc_repr_map) + assoc_add_mul_repr_map, assoc_add_mul_dirty = ( + compress_associative_add_mul(g) + ) + repr_map.update(assoc_add_mul_repr_map) for s, d in repr_map.items(): - if isinstance(s, Expression): - logger.info(f"{s}[{s.operands}] -> {d}[{d.operands}]") + if isinstance(d, Expression): + if isinstance(s, Expression): + logger.info(f"{s}[{s.operands}] -> {d}[{d.operands}]") + else: + logger.info(f"{s} -> {d}[{d.operands}]") else: logger.info(f"{s} -> {d}") graphs = {p.get_graph() for p in repr_map.values()} logger.info(f"{len(graphs)} new graphs") # TODO assert all new graphs + logger.info("Phase 2b Solving: Subtract associative expressions") + repr_map = {} + for g in graphs: + assoc_sub_repr_map, assoc_sub_dirty = compress_associative_sub(g) + repr_map.update(assoc_sub_repr_map) + for s, d in repr_map.items(): + if isinstance(d, Expression): + if isinstance(s, Expression): + logger.info( + f"{s}[{s.operands}] -> {d}[{d.operands} | G {d.get_graph()!r}]" + ) + else: + logger.info(f"{s} -> {d}[{d.operands} | G {d.get_graph()!r}]") + else: + logger.info(f"{s} -> {d} | G {d.get_graph()!r}") + graphs = {p.get_graph() for p in repr_map.values()} + logger.info(f"{len(graphs)} new graphs") + # TODO assert all new graphs + logger.info("Phase 3 Solving: Arithmetic expressions") repr_map = {} for g in graphs: arith_repr_map, arith_dirty = compress_arithmetic_expressions(g) repr_map.update(arith_repr_map) for s, d in repr_map.items(): - if isinstance(s, Expression): - logger.info(f"{s}[{s.operands}] -> {d}[{d.operands}] | G: {id(g)}") + if isinstance(d, Expression): + if isinstance(s, Expression): + logger.info(f"{s}[{s.operands}] -> {d}[{d.operands}]") + else: + logger.info(f"{s} -> {d}[{d.operands}]") else: - logger.info(f"{s} -> {d} | G: {id(g)}") + logger.info(f"{s} -> {d}") graphs = {p.get_graph() for p in repr_map.values()} logger.info(f"{len(graphs)} new graphs") # TODO assert all new graphs - dirty = alias_dirty or assoc_dirty or arith_dirty + dirty = alias_dirty or assoc_add_mul_dirty or assoc_sub_dirty or arith_dirty def get_any_single( self, diff --git a/src/faebryk/core/solver.py b/src/faebryk/core/solver.py index d572ec4f..9b616034 100644 --- a/src/faebryk/core/solver.py +++ b/src/faebryk/core/solver.py @@ -23,6 +23,8 @@ class SolverError(Exception): ... class TimeoutError(SolverError): ... + class DivisionByZeroError(SolverError): ... + @dataclass class SolveResult: timed_out: bool diff --git a/src/faebryk/libs/sets.py b/src/faebryk/libs/sets.py index 0f91c5f7..88c5c175 100644 --- a/src/faebryk/libs/sets.py +++ b/src/faebryk/libs/sets.py @@ -731,12 +731,12 @@ def __hash__(self) -> int: def __repr__(self) -> str: if self.units.is_compatible_with(dimensionless): inner = ", ".join(f"[{r._min}, {r._max}]" for r in self._ranges.ranges) - return f"_RangeUnion({inner})" + return f"Ranges({inner})" inner = ", ".join( f"[{self.base_to_units(r._min)}, {self.base_to_units(r._max)}]" for r in self._ranges.ranges ) - return f"_RangeUnion({inner} | {self.units})" + return f"Ranges({inner} | {self.units})" class Ranges(NonIterableRanges[QuantityT], Iterable[Range[QuantityT]]): diff --git a/test/core/test_parameters.py b/test/core/test_parameters.py index dd7bca76..636344ae 100644 --- a/test/core/test_parameters.py +++ b/test/core/test_parameters.py @@ -54,17 +54,17 @@ class App(Module): solver.phase_one_no_guess_solving(voltage1.get_graph()) -def test_assoc_compress(): +def test_simplify(): class App(Module): ops = L.list_field(10, lambda: Parameter(units=dimensionless)) app = App() - # (((((((((A + B + 1) + C + 2) * D * 3) * E * 4) * F * 5) * G * (A - A)) + H + 7) + I + 8) + J + 9) < 11 - # => (H + I + J + 24) < 11 + # (((((((((((A + B + 1) + C + 2) * D * 3) * E * 4) * F * 5) * G * (A - A)) + H + 7) + I + 8) + J + 9) - 3) - 4) < 11 + # => (H + I + J + 17) < 11 constants = [c * dimensionless for c in range(0, 10)] constants[5] = app.ops[0] - app.ops[0] - # constants[9] = Ranges(Range(0 * dimensionless, 1 * dimensionless)) + constants[9] = Ranges(Range(0 * dimensionless, 1 * dimensionless)) acc = app.ops[0] for i, p in enumerate(app.ops[1:3]): acc += p + constants[i] @@ -73,7 +73,8 @@ class App(Module): for i, p in enumerate(app.ops[7:]): acc += p + constants[i + 7] - (acc < 11).constrain() + acc = (acc - 3 * dimensionless) - 4 * dimensionless + (acc < 11 * dimensionless).constrain() G = acc.get_graph() solver = DefaultSolver() @@ -142,7 +143,7 @@ def test_visualize_inspect_app(): # if run in jupyter notebook import sys - func = test_assoc_compress + func = test_simplify if "ipykernel" in sys.modules: func()