diff --git a/CHANGELOG.md b/CHANGELOG.md index 62b71cb..48bf0c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Moved finite domain goal constructors into fd.py - Exit early from FD goals if any var has no domain - When exiting early from FD goals, make sure a constraint is added to the store +- Use `immutables` map instead of pyrsistent map (better performance) +- Change `neq` signature from `neq((a, b), *rest_pairs)` to `neq(a, b, /, *rest)` +- Make `Constraint` frozen so it's hashable +- Store `Constraint` operands as tuples rather than lists/sets ## [0.3.0] - 2023-04-10 diff --git a/pyproject.toml b/pyproject.toml index 5f94fdb..dc32490 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ classifiers = [ dependencies = [ "pyrsistent ~= 0.19", "fastcons ~= 0.3.0", + "immutables ~= 0.19", ] license = {file = "LICENSE"} diff --git a/src/microkanren/core.py b/src/microkanren/core.py index 82e43be..680885a 100644 --- a/src/microkanren/core.py +++ b/src/microkanren/core.py @@ -10,12 +10,12 @@ from collections.abc import Callable, Generator from dataclasses import dataclass -from functools import reduce, update_wrapper, wraps +from functools import reduce, wraps from typing import Any, Optional, Protocol, TypeAlias, TypeVar +import immutables from fastcons import cons, nil -from pyrsistent import PClass, field, pmap -from pyrsistent.typing import PMap +from pyrsistent import PClass, field from microkanren.utils import identity @@ -38,19 +38,19 @@ def __repr__(self): Value: TypeAlias = ( Var | int | str | bool | tuple["Value", ...] | list["Value"] | cons | nil ) -Substitution: TypeAlias = PMap[Var, Value] +Substitution: TypeAlias = immutables.Map[Var, Value] NeqStore: TypeAlias = list[list[tuple[Var, Value]]] -DomainStore: TypeAlias = PMap[Var, set[int]] +DomainStore: TypeAlias = immutables.Map[Var, set[int]] ConstraintFunction: TypeAlias = Callable[["State"], Optional["State"]] ConstraintStore: TypeAlias = list["Constraint"] def empty_sub() -> Substitution: - return pmap() + return immutables.Map() def empty_domain_store() -> DomainStore: - return pmap() + return immutables.Map() def empty_constraint_store() -> ConstraintStore: @@ -81,10 +81,10 @@ def __call__(self, *args: Any) -> ConstraintFunction: ... -@dataclass(slots=True) +@dataclass(slots=True, frozen=True) class Constraint: func: ConstraintProto - operands: list[Value] + operands: tuple[Value] def __call__(self, state: State) -> State | None: return self.func(*self.operands)(state) @@ -105,7 +105,6 @@ def __call__(self, *args: Value) -> GoalProto: class Goal: def __init__(self, goal: GoalProto): - update_wrapper(self, goal) self.goal = goal def __call__(self, state: State) -> Stream: @@ -251,8 +250,8 @@ def delay(g: GoalProto) -> GoalProto: def disj(g: GoalProto, *goals: GoalProto) -> GoalProto: if goals == (): - return delay(g) - return reduce(_disj, (delay(goal) for goal in goals), delay(g)) + return g + return reduce(_disj, (goal for goal in goals), g) def _disj(g1: GoalProto, g2: GoalProto) -> GoalProto: @@ -264,8 +263,8 @@ def __disj(state: State) -> Stream: def conj(g: GoalProto, *goals: GoalProto) -> GoalProto: if goals == (): - return delay(g) - return reduce(_conj, (delay(goal) for goal in goals), delay(g)) + return g + return reduce(_conj, (goal for goal in goals), g) def _conj(g1: GoalProto, g2: GoalProto) -> GoalProto: @@ -308,6 +307,26 @@ def unify_all( return None +def pairs(xs): + _xs = iter(xs) + while True: + try: + a = next(_xs) + except StopIteration: + break + try: + b = next(_xs) + yield (a, b) + except StopIteration: + raise ValueError("got sequence with uneven length") + + +def unpairs(xs): + for a, b in xs: + yield a + yield b + + def maybe_unify( pair: tuple[Value, Value], sub: Substitution | None ) -> Substitution | None: @@ -325,23 +344,22 @@ def _flipped(x, y): return _flipped -def neq(*pairs) -> GoalProto: - return goal_from_constraint(neqc(pairs)) +def neq(u, v, /, *rest) -> GoalProto: + return goal_from_constraint(neqc(u, v, *rest)) -def neqc(pairs: tuple[tuple[Value, Value], ...]) -> ConstraintFunction: +def neqc(u, v, *rest) -> ConstraintFunction: def _neqc(state: State) -> State | None: - (u, v), *rest = pairs - new_sub = reduce(flip(maybe_unify), rest, unify(u, v, state.sub)) + new_sub = reduce(flip(maybe_unify), pairs(rest), unify(u, v, state.sub)) if new_sub is None: return state elif new_sub == state.sub: return None prefix = get_sub_prefix(new_sub, state.sub) - remaining_pairs = list(prefix.items()) + remaining_pairs = tuple(prefix.items()) return state.set( constraints=extend_constraint_store( - Constraint(neqc, [remaining_pairs]), state.constraints + Constraint(neqc, tuple(unpairs(remaining_pairs))), state.constraints ) ) @@ -440,7 +458,11 @@ def map_sum(goal_constructor: Callable[[A], GoalProto], xs: list[A]) -> GoalProt def get_sub_prefix(new_sub: Substitution, old_sub: Substitution) -> Substitution: - return pmap({k: v for k, v in new_sub.items() if k not in old_sub}) + mutation = new_sub.mutate() + for k in new_sub: + if k in old_sub: + del mutation[k] + return mutation.finish() def fresh(fp: Callable) -> GoalProto: diff --git a/src/microkanren/fd.py b/src/microkanren/fd.py index 6e8d6a8..433e3fa 100644 --- a/src/microkanren/fd.py +++ b/src/microkanren/fd.py @@ -71,7 +71,7 @@ def _ltfdc(state: State) -> State | None: next_state = state.set( constraints=extend_constraint_store( - Constraint(ltfdc, [_u, _v]), state.constraints + Constraint(ltfdc, (_u, _v)), state.constraints ) ) if not dom_u or not dom_v: @@ -100,7 +100,7 @@ def _ltefdc(state: State) -> State | None: next_state = state.set( constraints=extend_constraint_store( - Constraint(ltefdc, [_u, _v]), state.constraints + Constraint(ltefdc, (_u, _v)), state.constraints ) ) if not dom_u or not dom_v: @@ -131,7 +131,7 @@ def _plusfdc(state: State) -> State | None: next_state = state.set( constraints=extend_constraint_store( - Constraint(plusfdc, [_u, _v, _w]), state.constraints + Constraint(plusfdc, (_u, _v, _w)), state.constraints ) ) if not all((dom_u, dom_v, dom_w)): @@ -167,7 +167,7 @@ def _neqfdc(state: State) -> State | None: if dom_u is None or dom_v is None: return state.set( constraints=extend_constraint_store( - Constraint(neqfdc, [_u, _v]), state.constraints + Constraint(neqfdc, (_u, _v)), state.constraints ) ) elif len(dom_u) == 1 and len(dom_v) == 1 and dom_u == dom_v: @@ -177,7 +177,7 @@ def _neqfdc(state: State) -> State | None: next_state = state.set( constraints=extend_constraint_store( - Constraint(neqfdc, [_u, _v]), state.constraints + Constraint(neqfdc, (_u, _v)), state.constraints ) ) if len(dom_u) == 1: @@ -197,8 +197,8 @@ def alldifffd(*vs: Value) -> GoalProto: def alldifffdc(*vs: Value) -> ConstraintFunction: def _alldifffdc(state: State) -> State | None: unresolved, values = partition(lambda v: isinstance(v, Var), vs) - unresolved = list(unresolved) - values = list(values) + unresolved = tuple(unresolved) + values = tuple(values) values_domain = make_domain(*values) if len(values) == len(values_domain): return alldifffdc_resolve(unresolved, values_domain)(state) @@ -207,10 +207,12 @@ def _alldifffdc(state: State) -> State | None: return _alldifffdc -def alldifffdc_resolve(unresolved: list[Var], values: set[Value]) -> ConstraintFunction: +def alldifffdc_resolve( + unresolved: tuple[Var], values: set[Value] +) -> ConstraintFunction: def _alldifffdc_resolve(state: State) -> State | None: nonlocal values - values = values.copy() + values = set(values) remains_unresolved = [] for var in unresolved: v = walk(var, state.sub) @@ -223,7 +225,9 @@ def _alldifffdc_resolve(state: State) -> State | None: next_state = state.set( constraints=extend_constraint_store( - Constraint(alldifffdc_resolve, [remains_unresolved, values]), + Constraint( + alldifffdc_resolve, (tuple(remains_unresolved), tuple(values)) + ), state.constraints, ) ) @@ -298,7 +302,7 @@ def process_prefix_fd( (x, v), *_ = prefix.items() t = compose_constraints( run_constraints([x], constraints), - process_prefix_fd(prefix.remove(x), constraints), + process_prefix_fd(prefix.delete(x), constraints), ) def _process_prefix_fd(state: State): diff --git a/tests/test_fd.py b/tests/test_fd.py index a93392d..f271a8b 100644 --- a/tests/test_fd.py +++ b/tests/test_fd.py @@ -114,7 +114,7 @@ def test_neq_with_domfd(self): """ If neq(x, n), then n cannot be in the domain of x. """ - result = run_all(lambda x: domfd(x, make_domain(1, 2, 3)) & neq((x, 2))) + result = run_all(lambda x: domfd(x, make_domain(1, 2, 3)) & neq(x, 2)) assert set(result) == {1, 3} def test_neq_with_ltefd(self): @@ -125,7 +125,7 @@ def test_neq_with_ltefd(self): lambda x, y: domfd(x, make_domain(1, 2, 3)) & domfd(y, make_domain(1, 2)) & ltefd(x, y) - & neq((x, 1)) + & neq(x, 1) ) assert result == [(2, 2)]