Skip to content

Commit

Permalink
fixed test
Browse files Browse the repository at this point in the history
  • Loading branch information
philzook58 committed Jan 23, 2025
1 parent 2bdd21a commit bab3137
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 12 deletions.
5 changes: 3 additions & 2 deletions examples/soft_found/lf/Lists.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@
}
],
"source": [
"kd.utils.eval_(fst(q))\n",
"kd.utils.eval_(snd(q))"
"import kdrag.reflect\n",
"kd.reflect.eval_(fst(q))\n",
"kd.reflect.eval_(snd(q))"
]
},
{
Expand Down
73 changes: 64 additions & 9 deletions kdrag/hypothesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
),
)

names = st.sampled_from("x y z".split())
# I think we'll get more interesting bugs with more name clashes rather than exploring weird names


def binop(children, op) -> st.SearchStrategy:
return st.tuples(children, children).map(op)
return st.tuples(children, children).map(lambda t: op(t[0], t[1]))


def binops(children) -> st.SearchStrategy:
Expand All @@ -27,23 +30,75 @@ def binops(children) -> st.SearchStrategy:
)


smt_bool_val: st.SearchStrategy[smt.BoolRef] = st.sampled_from(
[smt.BoolVal(True), smt.BoolVal(False)]
)

smt_int_val: st.SearchStrategy[smt.ArithRef] = st.integers().map(smt.IntVal)
smt_int_expr = st.recursive(
smt_int_val,
binops,
st.one_of(smt_int_val, names.map(smt.Int)),
lambda children: st.one_of(
binop(children, op.add),
binop(children, op.sub),
binop(children, op.mul),
binop(children, op.truediv),
st.deferred(
lambda: st.tuples(smt_bool_expr, children, children).map(
lambda x: smt.If(x[0], x[1], x[2])
)
),
),
)


smt_bool_val: st.SearchStrategy[smt.BoolRef] = st.sampled_from(
[smt.BoolVal(True), smt.BoolVal(False)]
)

smt_real_val = st.fractions().map(smt.RealVal)
smt_real_expr = st.recursive(
smt_real_val,
binops,
st.one_of(
smt_real_val,
names.map(smt.Real),
),
lambda children: st.one_of(
binop(children, op.add),
binop(children, op.sub),
binop(children, op.mul),
binop(children, op.truediv),
st.deferred(
lambda: st.tuples(smt_bool_expr, children, children).map(
lambda x: smt.If(x[0], x[1], x[2])
)
),
),
)


def compares(strat) -> st.SearchStrategy:
return st.one_of(
binop(strat, op.eq),
binop(strat, op.ne),
binop(strat, op.lt),
binop(strat, op.le),
binop(strat, op.gt),
binop(strat, op.ge),
)


smt_bool_expr = st.recursive(
st.one_of(
smt_bool_val,
names.map(smt.Bool),
compares(smt_int_expr),
compares(smt_real_expr),
),
lambda children: st.one_of(
binop(children, smt.And),
binop(children, smt.Or),
binop(children, smt.Xor),
st.tuples(children, children).map(lambda x: x[0] == x[1]),
st.tuples(children, children).map(lambda x: smt.Implies(x[0], x[1])),
),
)


smt_string_val = st.text().map(smt.StringVal)


Expand Down
25 changes: 25 additions & 0 deletions kdrag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,30 @@ def expr_to_lean(expr: smt.ExprRef):
"""


def free_vars(t: smt.ExprRef) -> set[smt.ExprRef]:
"""
Return free variables in an expression. Looks at kernel.defns to determine if contacts are free.
If you have meaningful constants no registered there, this may not work.
>>> x,y = smt.Ints("x y")
>>> free_vars(smt.Lambda([x], x + y + 1))
{y}
"""
fvs = set()
todo = [t]
while todo:
t = todo.pop()
if smt.is_var(t) or is_value(t) or smt.is_constructor(t):
continue
if smt.is_const(t) and t.decl() not in kd.kernel.defns:
fvs.add(t)
elif isinstance(t, smt.QuantifierRef):
todo.append(t.body())
elif smt.is_app(t):
todo.extend(t.children())
return fvs


def prune(
thm: smt.BoolRef | smt.QuantifierRef | kd.kernel.Proof, by=[], timeout=1000
) -> list[smt.ExprRef | kd.kernel.Proof]:
Expand Down Expand Up @@ -296,6 +320,7 @@ def decls(t: smt.ExprRef):


def is_value(t: smt.ExprRef):
# TODO, could make faster check using Z3 internals
return (
smt.is_int_value(t)
or smt.is_rational_value(t)
Expand Down
12 changes: 11 additions & 1 deletion tests/test_hypothesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,14 @@ def test_forall1():

@given(hyp.smt_sorts)
def test_reflect_sort(s):
assert reflect.sort_of_type(reflect.type_of_sort(s)) == s
assert reflect.sort_of_type(reflect.type_of_sort(s)) == s

@pytest.mark.slow
@given(hyp.smt_bool_expr)
def test_bool_expr(e):
assert e.sort() == smt.BoolSort()

@pytest.mark.slow
@given(hyp.smt_int_expr)
def test_int_expr(e):
assert e.sort() == smt.IntSort()

0 comments on commit bab3137

Please sign in to comment.