Skip to content

Commit

Permalink
Merge pull request #5 from jams2/feature/immutables-map
Browse files Browse the repository at this point in the history
Feature/immutables map
  • Loading branch information
jams2 committed Apr 23, 2023
2 parents 78b233a + cd84ad9 commit d31630d
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 35 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Moved finite domain goal constructors into fd.py
- Exit early from FD goals if any var has no domain
- When exiting early from FD goals, make sure a constraint is added to the store
- Use `immutables` map instead of pyrsistent map (better performance)
- Change `neq` signature from `neq((a, b), *rest_pairs)` to `neq(a, b, /, *rest)`
- Make `Constraint` frozen so it's hashable
- Store `Constraint` operands as tuples rather than lists/sets

## [0.3.0] - 2023-04-10

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ classifiers = [
dependencies = [
"pyrsistent ~= 0.19",
"fastcons ~= 0.3.0",
"immutables ~= 0.19",
]
license = {file = "LICENSE"}

Expand Down
66 changes: 44 additions & 22 deletions src/microkanren/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

from collections.abc import Callable, Generator
from dataclasses import dataclass
from functools import reduce, update_wrapper, wraps
from functools import reduce, wraps
from typing import Any, Optional, Protocol, TypeAlias, TypeVar

import immutables
from fastcons import cons, nil
from pyrsistent import PClass, field, pmap
from pyrsistent.typing import PMap
from pyrsistent import PClass, field

from microkanren.utils import identity

Expand All @@ -38,19 +38,19 @@ def __repr__(self):
Value: TypeAlias = (
Var | int | str | bool | tuple["Value", ...] | list["Value"] | cons | nil
)
Substitution: TypeAlias = PMap[Var, Value]
Substitution: TypeAlias = immutables.Map[Var, Value]
NeqStore: TypeAlias = list[list[tuple[Var, Value]]]
DomainStore: TypeAlias = PMap[Var, set[int]]
DomainStore: TypeAlias = immutables.Map[Var, set[int]]
ConstraintFunction: TypeAlias = Callable[["State"], Optional["State"]]
ConstraintStore: TypeAlias = list["Constraint"]


def empty_sub() -> Substitution:
return pmap()
return immutables.Map()


def empty_domain_store() -> DomainStore:
return pmap()
return immutables.Map()


def empty_constraint_store() -> ConstraintStore:
Expand Down Expand Up @@ -81,10 +81,10 @@ def __call__(self, *args: Any) -> ConstraintFunction:
...


@dataclass(slots=True)
@dataclass(slots=True, frozen=True)
class Constraint:
func: ConstraintProto
operands: list[Value]
operands: tuple[Value]

def __call__(self, state: State) -> State | None:
return self.func(*self.operands)(state)
Expand All @@ -105,7 +105,6 @@ def __call__(self, *args: Value) -> GoalProto:

class Goal:
def __init__(self, goal: GoalProto):
update_wrapper(self, goal)
self.goal = goal

def __call__(self, state: State) -> Stream:
Expand Down Expand Up @@ -251,8 +250,8 @@ def delay(g: GoalProto) -> GoalProto:

def disj(g: GoalProto, *goals: GoalProto) -> GoalProto:
if goals == ():
return delay(g)
return reduce(_disj, (delay(goal) for goal in goals), delay(g))
return g
return reduce(_disj, (goal for goal in goals), g)


def _disj(g1: GoalProto, g2: GoalProto) -> GoalProto:
Expand All @@ -264,8 +263,8 @@ def __disj(state: State) -> Stream:

def conj(g: GoalProto, *goals: GoalProto) -> GoalProto:
if goals == ():
return delay(g)
return reduce(_conj, (delay(goal) for goal in goals), delay(g))
return g
return reduce(_conj, (goal for goal in goals), g)


def _conj(g1: GoalProto, g2: GoalProto) -> GoalProto:
Expand Down Expand Up @@ -308,6 +307,26 @@ def unify_all(
return None


def pairs(xs):
_xs = iter(xs)
while True:
try:
a = next(_xs)
except StopIteration:
break
try:
b = next(_xs)
yield (a, b)
except StopIteration:
raise ValueError("got sequence with uneven length")


def unpairs(xs):
for a, b in xs:
yield a
yield b


def maybe_unify(
pair: tuple[Value, Value], sub: Substitution | None
) -> Substitution | None:
Expand All @@ -325,23 +344,22 @@ def _flipped(x, y):
return _flipped


def neq(*pairs) -> GoalProto:
return goal_from_constraint(neqc(pairs))
def neq(u, v, /, *rest) -> GoalProto:
return goal_from_constraint(neqc(u, v, *rest))


def neqc(pairs: tuple[tuple[Value, Value], ...]) -> ConstraintFunction:
def neqc(u, v, *rest) -> ConstraintFunction:
def _neqc(state: State) -> State | None:
(u, v), *rest = pairs
new_sub = reduce(flip(maybe_unify), rest, unify(u, v, state.sub))
new_sub = reduce(flip(maybe_unify), pairs(rest), unify(u, v, state.sub))
if new_sub is None:
return state
elif new_sub == state.sub:
return None
prefix = get_sub_prefix(new_sub, state.sub)
remaining_pairs = list(prefix.items())
remaining_pairs = tuple(prefix.items())
return state.set(
constraints=extend_constraint_store(
Constraint(neqc, [remaining_pairs]), state.constraints
Constraint(neqc, tuple(unpairs(remaining_pairs))), state.constraints
)
)

Expand Down Expand Up @@ -440,7 +458,11 @@ def map_sum(goal_constructor: Callable[[A], GoalProto], xs: list[A]) -> GoalProt


def get_sub_prefix(new_sub: Substitution, old_sub: Substitution) -> Substitution:
return pmap({k: v for k, v in new_sub.items() if k not in old_sub})
mutation = new_sub.mutate()
for k in new_sub:
if k in old_sub:
del mutation[k]
return mutation.finish()


def fresh(fp: Callable) -> GoalProto:
Expand Down
26 changes: 15 additions & 11 deletions src/microkanren/fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _ltfdc(state: State) -> State | None:

next_state = state.set(
constraints=extend_constraint_store(
Constraint(ltfdc, [_u, _v]), state.constraints
Constraint(ltfdc, (_u, _v)), state.constraints
)
)
if not dom_u or not dom_v:
Expand Down Expand Up @@ -100,7 +100,7 @@ def _ltefdc(state: State) -> State | None:

next_state = state.set(
constraints=extend_constraint_store(
Constraint(ltefdc, [_u, _v]), state.constraints
Constraint(ltefdc, (_u, _v)), state.constraints
)
)
if not dom_u or not dom_v:
Expand Down Expand Up @@ -131,7 +131,7 @@ def _plusfdc(state: State) -> State | None:

next_state = state.set(
constraints=extend_constraint_store(
Constraint(plusfdc, [_u, _v, _w]), state.constraints
Constraint(plusfdc, (_u, _v, _w)), state.constraints
)
)
if not all((dom_u, dom_v, dom_w)):
Expand Down Expand Up @@ -167,7 +167,7 @@ def _neqfdc(state: State) -> State | None:
if dom_u is None or dom_v is None:
return state.set(
constraints=extend_constraint_store(
Constraint(neqfdc, [_u, _v]), state.constraints
Constraint(neqfdc, (_u, _v)), state.constraints
)
)
elif len(dom_u) == 1 and len(dom_v) == 1 and dom_u == dom_v:
Expand All @@ -177,7 +177,7 @@ def _neqfdc(state: State) -> State | None:

next_state = state.set(
constraints=extend_constraint_store(
Constraint(neqfdc, [_u, _v]), state.constraints
Constraint(neqfdc, (_u, _v)), state.constraints
)
)
if len(dom_u) == 1:
Expand All @@ -197,8 +197,8 @@ def alldifffd(*vs: Value) -> GoalProto:
def alldifffdc(*vs: Value) -> ConstraintFunction:
def _alldifffdc(state: State) -> State | None:
unresolved, values = partition(lambda v: isinstance(v, Var), vs)
unresolved = list(unresolved)
values = list(values)
unresolved = tuple(unresolved)
values = tuple(values)
values_domain = make_domain(*values)
if len(values) == len(values_domain):
return alldifffdc_resolve(unresolved, values_domain)(state)
Expand All @@ -207,10 +207,12 @@ def _alldifffdc(state: State) -> State | None:
return _alldifffdc


def alldifffdc_resolve(unresolved: list[Var], values: set[Value]) -> ConstraintFunction:
def alldifffdc_resolve(
unresolved: tuple[Var], values: set[Value]
) -> ConstraintFunction:
def _alldifffdc_resolve(state: State) -> State | None:
nonlocal values
values = values.copy()
values = set(values)
remains_unresolved = []
for var in unresolved:
v = walk(var, state.sub)
Expand All @@ -223,7 +225,9 @@ def _alldifffdc_resolve(state: State) -> State | None:

next_state = state.set(
constraints=extend_constraint_store(
Constraint(alldifffdc_resolve, [remains_unresolved, values]),
Constraint(
alldifffdc_resolve, (tuple(remains_unresolved), tuple(values))
),
state.constraints,
)
)
Expand Down Expand Up @@ -298,7 +302,7 @@ def process_prefix_fd(
(x, v), *_ = prefix.items()
t = compose_constraints(
run_constraints([x], constraints),
process_prefix_fd(prefix.remove(x), constraints),
process_prefix_fd(prefix.delete(x), constraints),
)

def _process_prefix_fd(state: State):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_neq_with_domfd(self):
"""
If neq(x, n), then n cannot be in the domain of x.
"""
result = run_all(lambda x: domfd(x, make_domain(1, 2, 3)) & neq((x, 2)))
result = run_all(lambda x: domfd(x, make_domain(1, 2, 3)) & neq(x, 2))
assert set(result) == {1, 3}

def test_neq_with_ltefd(self):
Expand All @@ -125,7 +125,7 @@ def test_neq_with_ltefd(self):
lambda x, y: domfd(x, make_domain(1, 2, 3))
& domfd(y, make_domain(1, 2))
& ltefd(x, y)
& neq((x, 1))
& neq(x, 1)
)
assert result == [(2, 2)]

Expand Down

0 comments on commit d31630d

Please sign in to comment.