Skip to content

Commit

Permalink
Merge pull request #4 from jams2/more-fd-goals
Browse files Browse the repository at this point in the history
Add `ltfd`
  • Loading branch information
jams2 authored Apr 23, 2023
2 parents 7a702ad + 1a8d213 commit 78b233a
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 23 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- `ltfd` goal

### Changed
- Moved finite domain goal constructors into fd.py
- Exit early from FD goals if any var has no domain
- When exiting early from FD goals, make sure a constraint is added to the store

## [0.3.0] - 2023-04-10

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dev = [
"ruff ~= 0.0.256",
]
testing = ["pytest ~= 7.2.2", "pytest-profiling ~= 1.7.0"]
build = ["hatch ~= 1.7.0"]

[build-system]
requires = ["hatchling"]
Expand Down
79 changes: 56 additions & 23 deletions src/microkanren/fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,35 @@ def infd(values: tuple[Value], domain, /) -> GoalProto:
return goal_from_constraint(infdc)


def ltfd(u: Value, v: Value) -> GoalProto:
return goal_from_constraint(ltfdc(u, v))


def ltfdc(u: Value, v: Value) -> ConstraintFunction:
def _ltfdc(state: State) -> State | None:
_u = walk(u, state.sub)
_v = walk(v, state.sub)
dom_u = state.get_domain(_u) if isinstance(_u, Var) else make_domain(_u)
dom_v = state.get_domain(_v) if isinstance(_v, Var) else make_domain(_v)

next_state = state.set(
constraints=extend_constraint_store(
Constraint(ltfdc, [_u, _v]), state.constraints
)
)
if not dom_u or not dom_v:
return next_state

max_v = max(dom_v)
min_u = min(dom_u)
return compose_constraints(
process_domain(_u, make_domain(*(i for i in dom_u if i < max_v))),
process_domain(_v, make_domain(*(i for i in dom_v if i > min_u))),
)(next_state)

return _ltfdc


def ltefd(u: Value, v: Value) -> GoalProto:
return goal_from_constraint(ltefdc(u, v))

Expand All @@ -68,19 +97,21 @@ def _ltefdc(state: State) -> State | None:
_v = walk(v, state.sub)
dom_u = state.get_domain(_u) if isinstance(_u, Var) else make_domain(_u)
dom_v = state.get_domain(_v) if isinstance(_v, Var) else make_domain(_v)

next_state = state.set(
constraints=extend_constraint_store(
Constraint(ltefdc, [_u, _v]), state.constraints
)
)
if dom_u and dom_v:
max_v = max(dom_v)
min_u = min(dom_u)
return compose_constraints(
process_domain(_u, make_domain(*(i for i in dom_u if i <= max_v))),
process_domain(_v, make_domain(*(i for i in dom_v if i >= min_u))),
)(next_state)
return state
if not dom_u or not dom_v:
return next_state

max_v = max(dom_v)
min_u = min(dom_u)
return compose_constraints(
process_domain(_u, make_domain(*(i for i in dom_u if i <= max_v))),
process_domain(_v, make_domain(*(i for i in dom_v if i >= min_u))),
)(next_state)

return _ltefdc

Expand All @@ -97,26 +128,28 @@ def _plusfdc(state: State) -> State | None:
dom_u = state.get_domain(_u) if isinstance(_u, Var) else make_domain(_u)
dom_v = state.get_domain(_v) if isinstance(_v, Var) else make_domain(_v)
dom_w = state.get_domain(_w) if isinstance(_w, Var) else make_domain(_w)

next_state = state.set(
constraints=extend_constraint_store(
Constraint(plusfdc, [_u, _v, _w]), state.constraints
)
)
if dom_u and dom_v and dom_w:
min_u = min(dom_u)
max_u = max(dom_u)
min_v = min(dom_v)
max_v = max(dom_v)
min_w = min(dom_w)
max_w = max(dom_w)
return compose_constraints(
process_domain(_w, mkrange(min_u + min_v, max_u + max_v)),
compose_constraints(
process_domain(_u, mkrange(min_w - max_v, max_w - min_v)),
process_domain(_v, mkrange(min_w - max_u, max_w - min_u)),
),
)(next_state)
return state
if not all((dom_u, dom_v, dom_w)):
return next_state

min_u = min(dom_u)
max_u = max(dom_u)
min_v = min(dom_v)
max_v = max(dom_v)
min_w = min(dom_w)
max_w = max(dom_w)
return compose_constraints(
process_domain(_w, mkrange(min_u + min_v, max_u + max_v)),
compose_constraints(
process_domain(_u, mkrange(min_w - max_v, max_w - min_v)),
process_domain(_v, mkrange(min_w - max_u, max_w - min_u)),
),
)(next_state)

return _plusfdc

Expand Down
17 changes: 17 additions & 0 deletions tests/test_fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
enforce_constraints_fd,
infd,
ltefd,
ltfd,
make_domain,
mkrange,
neqfd,
Expand Down Expand Up @@ -93,6 +94,22 @@ def test_ltefd(self, a, b, expected_x):
for x, y in result:
assert x <= y

@pytest.mark.parametrize(
("a", "b", "expected_x"),
[
(make_domain(1, 2, 3, 4), make_domain(2, 3), make_domain(1, 2)),
(make_domain(1, 2, 3), make_domain(4, 5), make_domain(1, 2, 3)),
(make_domain(4, 5), make_domain(1, 2), make_domain()),
(make_domain(3, 4), make_domain(2, 3, 4, 5), make_domain(3, 4)),
(make_domain(1, 2, 3, 4), make_domain(1, 2, 3, 4), make_domain(1, 2, 3)),
],
)
def test_ltfd(self, a, b, expected_x):
result = run_all(lambda x, y: domfd(x, a) & domfd(y, b) & ltfd(x, y))
assert {x[0] for x in result} == expected_x
for x, y in result:
assert x < y

def test_neq_with_domfd(self):
"""
If neq(x, n), then n cannot be in the domain of x.
Expand Down

0 comments on commit 78b233a

Please sign in to comment.