This repository has been archived by the owner on Dec 10, 2024. It is now read-only.

Commit

normalize params, remove some tautologies
NoR8quoh1r committed Nov 7, 2024
1 parent f0e34d5 commit 865996e
Showing 2 changed files with 207 additions and 84 deletions.
270 changes: 188 additions & 82 deletions src/faebryk/core/defaultsolver.py
@@ -16,16 +16,18 @@
Divide,
Expression,
Is,
Log,
Multiply,
Parameter,
ParameterOperatable,
Power,
Predicate,
Sqrt,
Subtract,
)
from faebryk.core.solver import Solver
from faebryk.libs.sets import Ranges
from faebryk.libs.units import dimensionless
from faebryk.libs.sets import Range, Ranges
from faebryk.libs.units import Quantity, dimensionless
from faebryk.libs.util import EquivalenceClasses

logger = logging.getLogger(__name__)
@@ -58,7 +60,7 @@ def get_params_for_expr(expr: Expression) -> set[Parameter]:

def get_constrained_predicates_involved_in(
p: Parameter | Expression,
) -> list[Predicate]:
) -> set[Predicate]:
# p.self -> p.operated_on -> e1.operates_on -> e1.self
dependants = p.bfs_node(
lambda path, _: isinstance(path[-1].node, ParameterOperatable)
@@ -77,7 +79,7 @@ def get_constrained_predicates_involved_in(
)
)
)
res = [p for p in dependants if isinstance(p, Predicate) and p.constrained]
res = {p for p in dependants if isinstance(p, Predicate) and p.constrained}
return res
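# Editor's illustrative sketch (not part of this commit): the helper above
# walks p.operated_on -> e.operates_on edges and keeps only the constrained
# Predicate nodes reachable from p. The same idea over a plain adjacency dict,
# using hypothetical toy names and no faebryk types:
from collections import deque


def toy_constrained_predicates(start, edges: dict, constrained: set) -> set:
    """BFS from `start` and collect every reachable node marked constrained."""
    seen, queue, found = {start}, deque([start]), set()
    while queue:
        node = queue.popleft()
        if node in constrained:
            found.add(node)
        for nxt in edges.get(node, ()):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return found


# "R1" feeds the predicate "R1 is R2", which also touches "R2":
toy_edges = {"R1": ["R1 is R2"], "R1 is R2": ["R2"]}
assert toy_constrained_predicates("R1", toy_edges, {"R1 is R2"}) == {"R1 is R2"}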


@@ -110,6 +112,95 @@ def create_new_expr(
return new_expr


def copy_param(p: Parameter) -> Parameter:
return Parameter(
units=p.units,
within=p.within,
domain=p.domain,
soft_set=p.soft_set,
guess=p.guess,
tolerance_guess=p.tolerance_guess,
likely_constrained=p.likely_constrained,
)


def copy_operand_recursively(
o: ParameterOperatable.All, repr_map: dict[ParameterOperatable, ParameterOperatable]
) -> ParameterOperatable.All:
if o in repr_map:
return repr_map[o]
if isinstance(o, Expression):
new_ops = []
for op in o.operands:
new_op = copy_operand_recursively(op, repr_map)
if isinstance(op, ParameterOperatable):
repr_map[op] = new_op
new_ops.append(new_op)
expr = create_new_expr(o, *new_ops)
repr_map[o] = expr
return expr
elif isinstance(o, Parameter):
param = copy_param(o)
repr_map[o] = param
return param
else:
return o
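# Editor's illustrative sketch (not part of this commit): copy_operand_recursively
# is a memoized deep copy; repr_map ensures a node referenced from several places
# is copied exactly once, so sharing in the expression DAG is preserved. A
# standalone toy version with a hypothetical frozen dataclass instead of
# faebryk's Expression:
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyExpr:
    op: str
    operands: tuple


def copy_toy(e, repr_map: dict):
    if e in repr_map:  # already copied -> reuse the copy
        return repr_map[e]
    if isinstance(e, ToyExpr):
        new = ToyExpr(e.op, tuple(copy_toy(x, repr_map) for x in e.operands))
        repr_map[e] = new
        return new
    return e  # plain literals are returned as-is


shared = ToyExpr("add", (1, 2))
root = ToyExpr("mul", (shared, shared))
copied = copy_toy(root, {})
assert copied.operands[0] is copied.operands[1]  # sharing survives the copy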


# normalization performed by this pass:
# - units -> base units, then treated as dimensionless
# - within -> re-attached via constrain_subset on the copied parameter
# - scalar operands -> single quantities in base units
def normalize_graph(G: Graph) -> dict[ParameterOperatable, ParameterOperatable]:
def set_to_base_units(s: Ranges | Range | None) -> Ranges | Range | None:
if s is None:
return None
if isinstance(s, Ranges):
return Ranges._from_ranges(s._ranges, dimensionless)
return Range._from_range(s._range, dimensionless)

def scalar_to_base_units(q: int | float | Quantity | None) -> Quantity | None:
if q is None:
return None
if isinstance(q, Quantity):
return q.to_base_units().magnitude * dimensionless
return q * dimensionless

param_ops = G.nodes_of_type(ParameterOperatable)

repr_map: dict[ParameterOperatable, ParameterOperatable] = {}

for po in cast(
Iterable[ParameterOperatable],
ParameterOperatable.sort_by_depth(param_ops, ascending=True),
):
if isinstance(po, Parameter):
new_param = Parameter(
units=dimensionless,
within=None,
domain=po.domain,
soft_set=set_to_base_units(po.soft_set),
guess=scalar_to_base_units(po.guess),
tolerance_guess=po.tolerance_guess,
likely_constrained=po.likely_constrained,
)
repr_map[po] = new_param
if po.within is not None:
new_param.constrain_subset(set_to_base_units(po.within))
elif isinstance(po, Expression):
new_ops = []
for op in po.operands:
if isinstance(op, ParameterOperatable):
assert op in repr_map
new_ops.append(repr_map[op])
elif isinstance(op, int | float | Quantity):
new_ops.append(scalar_to_base_units(op))
else:
new_ops.append(set_to_base_units(op))
repr_map[po] = create_new_expr(po, *new_ops)

return repr_map
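# Editor's illustrative sketch (not part of this commit): normalize_graph strips
# all units by rewriting every value in SI base units and re-attaching `within`
# as a constrain_subset. The unit arithmetic itself is plain base-unit
# conversion; the snippet below assumes a pint UnitRegistry, which matches the
# `.to_base_units().magnitude` call used above:
import pint

ureg = pint.UnitRegistry()


def to_base_magnitude(q):
    """Magnitude of q expressed in SI base units (what the solver then treats
    as a plain dimensionless number)."""
    return q.to_base_units().magnitude


assert to_base_magnitude(2 * ureg.kilometer) == 2000.0  # km -> m
assert to_base_magnitude(3 * ureg.kiloohm) == 3000.0    # kOhm -> kg*m^2/(A^2*s^3)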


def resolve_alias_classes(
G: Graph,
) -> tuple[dict[ParameterOperatable, ParameterOperatable], bool]:
@@ -219,39 +310,8 @@ def try_replace(o: ParameterOperatable.All):
return repr_map, dirty


def copy_param(p: Parameter) -> Parameter:
return Parameter(
units=p.units,
within=p.within,
domain=p.domain,
soft_set=p.soft_set,
guess=p.guess,
tolerance_guess=p.tolerance_guess,
likely_constrained=p.likely_constrained,
)


def copy_operand_recursively(
o: ParameterOperatable.All, repr_map: dict[ParameterOperatable, ParameterOperatable]
) -> ParameterOperatable.All:
if o in repr_map:
return repr_map[o]
if isinstance(o, Expression):
new_ops = []
for op in o.operands:
new_op = copy_operand_recursively(op, repr_map)
if isinstance(op, ParameterOperatable):
repr_map[op] = new_op
new_ops.append(new_op)
expr = create_new_expr(o, *new_ops)
repr_map[o] = expr
return expr
elif isinstance(o, Parameter):
param = copy_param(o)
repr_map[o] = param
return param
else:
return o


def is_replacable(
@@ -587,13 +647,15 @@ def compress_arithmetic_expressions(
else:
repr_map[expr] = copy_operand_recursively(expr, repr_map)

other_param_op = ParameterOperatable.sort_by_depth(
(
p
for p in G.nodes_of_type(ParameterOperatable)
if p not in repr_map and p not in removed
),
ascending=True,
other_param_op = (
ParameterOperatable.sort_by_depth( # TODO, do we need the sort here? same above
(
p
for p in G.nodes_of_type(ParameterOperatable)
if p not in repr_map and p not in removed
),
ascending=True,
)
)
for o in other_param_op:
copy_operand_recursively(o, repr_map)
@@ -603,10 +665,68 @@
}, dirty


def has_implicit_constraint(po: ParameterOperatable) -> bool:
if isinstance(po, Parameter | Add | Subtract | Multiply | Power): # TODO others
return False
if isinstance(po, Divide):
return True # implicit constraint: divisor not zero
if isinstance(po, Sqrt | Log):
return True # implicit constraint: non-negative
return True
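# Editor's illustrative sketch (not part of this commit): "implicit constraint"
# means the operation itself narrows its operands even when no Predicate was
# written, e.g. x / y silently requires y != 0, and sqrt/log restrict their
# argument's range. A toy dispatch over operation names (hypothetical, no
# faebryk types), including the conservative `return True` default for
# operations not handled yet:
_NO_IMPLICIT = {"parameter", "add", "subtract", "multiply", "power"}


def toy_has_implicit_constraint(op_kind: str) -> bool:
    return op_kind not in _NO_IMPLICIT  # divide, sqrt, log, unknown -> True


assert not toy_has_implicit_constraint("add")
assert toy_has_implicit_constraint("divide")         # divisor must not be zero
assert toy_has_implicit_constraint("exotic_new_op")  # conservative fallback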


def remove_obvious_tautologies(
G: Graph,
) -> tuple[dict[ParameterOperatable, ParameterOperatable], bool]:
removed = set()
dirty = False
for pred_is in ParameterOperatable.sort_by_depth(
G.nodes_of_type(Is), ascending=True
):

def known_unconstrained(po: ParameterOperatable) -> bool:
no_other_constraints = (
len(get_constrained_predicates_involved_in(po).difference({pred_is}))
== 0
)
return no_other_constraints and not has_implicit_constraint(po)

pred_is = cast(Is, pred_is)
if pred_is.operands[0] is pred_is.operands[1] and not known_unconstrained(
pred_is.operands[0]
):
removed.add(pred_is)
dirty = True
elif known_unconstrained(pred_is.operands[0]) or known_unconstrained(
pred_is.operands[1]
):
removed.add(pred_is)
dirty = True
repr_map = {}
for p in G.nodes_of_type(ParameterOperatable):
if p not in removed:
repr_map[p] = copy_operand_recursively(p, repr_map)
return repr_map, dirty
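# Editor's illustrative sketch (not part of this commit): the pass above drops
# an Is-predicate when it is `x is x` on an otherwise-constrained node, or when
# either operand has no other constraint pinning it down (implicit constraints
# are ignored in this toy version). Hypothetical standalone logic:
def toy_is_obvious_tautology(lhs, rhs, constrained_elsewhere: set) -> bool:
    def known_unconstrained(o) -> bool:
        return o not in constrained_elsewhere

    if lhs is rhs and not known_unconstrained(lhs):
        return True  # x is x, and x is already pinned by something else
    return known_unconstrained(lhs) or known_unconstrained(rhs)


a, b = object(), object()
assert toy_is_obvious_tautology(a, a, {a})         # self-alias
assert toy_is_obvious_tautology(a, b, {a})         # b carries no other constraint
assert not toy_is_obvious_tautology(a, b, {a, b})  # both pinned elsewhere: keep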


class DefaultSolver(Solver):
timeout: int = 1000

def phase_one_no_guess_solving(self, g: Graph) -> None:
def debug_print(repr_map: dict[ParameterOperatable, ParameterOperatable]):
for s, d in repr_map.items():
if isinstance(d, Expression):
if isinstance(s, Expression):
logger.info(
f"{s}[{s.operands}] -> {d}[{d.operands} | G {d.get_graph()!r}]"

Check failure on line 721 (GitHub Actions / pytest): Ruff src/faebryk/core/defaultsolver.py:721:89: E501 Line too long (91 > 88)
)
else:
logger.info(f"{s} -> {d}[{d.operands} | G {d.get_graph()!r}]")
else:
logger.info(f"{s} -> {d} | G {d.get_graph()!r}")
graphs = {p.get_graph() for p in repr_map.values()}
logger.info(f"{len(graphs)} graphs")

logger.info(f"Phase 1 Solving: No guesses {'-' * 80}")

# strategies
@@ -632,7 +752,12 @@ def phase_one_no_guess_solving(self, g: Graph) -> None:

# as long as progress iterate

graphs = {g}
logger.info("Phase 0 Solving: normalize graph")
repr_map = normalize_graph(g)
debug_print(repr_map)
graphs = {p.get_graph() for p in repr_map.values()}
# TODO assert all new graphs

dirty = True
iter = 0

@@ -644,16 +769,8 @@ def phase_one_no_guess_solving(self, g: Graph) -> None:
for g in graphs:
alias_repr_map, alias_dirty = resolve_alias_classes(g)
repr_map.update(alias_repr_map)
for s, d in repr_map.items():
if isinstance(d, Expression):
if isinstance(s, Expression):
logger.info(f"{s}[{s.operands}] -> {d}[{d.operands}]")
else:
logger.info(f"{s} -> {d}[{d.operands}]")
else:
logger.info(f"{s} -> {d}")
debug_print(repr_map)
graphs = {p.get_graph() for p in repr_map.values()}
logger.info(f"{len(graphs)} new graphs")
# TODO assert all new graphs

logger.info("Phase 2a Solving: Add/Mul associative expressions")
@@ -663,55 +780,44 @@ def phase_one_no_guess_solving(self, g: Graph) -> None:
compress_associative_add_mul(g)
)
repr_map.update(assoc_add_mul_repr_map)
for s, d in repr_map.items():
if isinstance(d, Expression):
if isinstance(s, Expression):
logger.info(f"{s}[{s.operands}] -> {d}[{d.operands}]")
else:
logger.info(f"{s} -> {d}[{d.operands}]")
else:
logger.info(f"{s} -> {d}")
debug_print(repr_map)
graphs = {p.get_graph() for p in repr_map.values()}
logger.info(f"{len(graphs)} new graphs")
# TODO assert all new graphs

logger.info("Phase 2b Solving: Subtract associative expressions")
repr_map = {}
for g in graphs:
assoc_sub_repr_map, assoc_sub_dirty = compress_associative_sub(g)
repr_map.update(assoc_sub_repr_map)
for s, d in repr_map.items():
if isinstance(d, Expression):
if isinstance(s, Expression):
logger.info(
f"{s}[{s.operands}] -> {d}[{d.operands} | G {d.get_graph()!r}]"
)
else:
logger.info(f"{s} -> {d}[{d.operands} | G {d.get_graph()!r}]")
else:
logger.info(f"{s} -> {d} | G {d.get_graph()!r}")
debug_print(repr_map)
graphs = {p.get_graph() for p in repr_map.values()}
logger.info(f"{len(graphs)} new graphs")
# TODO assert all new graphs

logger.info("Phase 3 Solving: Arithmetic expressions")
repr_map = {}
for g in graphs:
arith_repr_map, arith_dirty = compress_arithmetic_expressions(g)
repr_map.update(arith_repr_map)
for s, d in repr_map.items():
if isinstance(d, Expression):
if isinstance(s, Expression):
logger.info(f"{s}[{s.operands}] -> {d}[{d.operands}]")
else:
logger.info(f"{s} -> {d}[{d.operands}]")
else:
logger.info(f"{s} -> {d}")
debug_print(repr_map)
graphs = {p.get_graph() for p in repr_map.values()}
logger.info(f"{len(graphs)} new graphs")
# TODO assert all new graphs

dirty = alias_dirty or assoc_add_mul_dirty or assoc_sub_dirty or arith_dirty
logger.info("Phase 4 Solving: Remove obvious tautologies")
repr_map = {}
for g in graphs:
tautology_repr_map, tautology_dirty = remove_obvious_tautologies(g)
repr_map.update(tautology_repr_map)
debug_print(repr_map)
graphs = {p.get_graph() for p in repr_map.values()}
# TODO assert all new graphs

dirty = (
alias_dirty
or assoc_add_mul_dirty
or assoc_sub_dirty
or arith_dirty
or tautology_dirty
)
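# Editor's illustrative sketch (not part of this commit, shown unindented): every
# phase above returns a (repr_map, dirty) pair, and the surrounding loop keeps
# re-running the whole pipeline until no phase reports progress. A generic,
# faebryk-free fixed-point driver looks like this:
from typing import Callable

ToyPass = Callable[[list], tuple[list, bool]]


def run_to_fixed_point(
    state: list, passes: list[ToyPass], max_iters: int = 100
) -> list:
    for _ in range(max_iters):
        any_dirty = False
        for p in passes:
            state, dirty = p(state)
            any_dirty = any_dirty or dirty
        if not any_dirty:  # nothing changed in a full sweep -> done
            break
    return state


# e.g. repeatedly strip one trailing zero until the list stops changing:
def drop_trailing_zero(xs: list) -> tuple[list, bool]:
    return (xs[:-1], True) if xs and xs[-1] == 0 else (xs, False)


assert run_to_fixed_point([1, 2, 0, 0, 0], [drop_trailing_zero]) == [1, 2]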

def get_any_single(
self,
