Skip to content

Commit

Permalink
knuckleclosure
Browse files Browse the repository at this point in the history
  • Loading branch information
philzook58 committed Jan 23, 2025
1 parent c0baae0 commit 2bdd21a
Showing 1 changed file with 93 additions and 16 deletions.
109 changes: 93 additions & 16 deletions kdrag/reflect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Callable, no_type_check
from collections import namedtuple
import operator
from dataclasses import dataclass


def sort_of_type(t: type) -> smt.SortRef:
Expand Down Expand Up @@ -102,7 +103,23 @@ def namedtuple_of_constructor(sort: smt.DatatypeSortRef, idx: int):
return namedtuple(decl.name(), fields)


# could env be just a python module? That's kind of intriguing
@dataclass
class KnuckleClosure:
"""
A closure that can be used to evaluate expressions in a given environment.
We don't use lambda so that we can inspect
"""

lam: smt.QuantifierRef
env: dict[str, object]

def __call__(self, *args):
# TODO: Should I open binder more eagerly before call?
vs, body = kd.utils.open_binder(self.lam)
return eval_(
body,
globals={**{v.decl().name(): arg for v, arg in zip(vs, args)}, **self.env},
)


# This is fiendishly difficult to typecheck probably
Expand All @@ -126,12 +143,8 @@ def eval_(e: smt.ExprRef, globals={}):
"""
if isinstance(e, smt.QuantifierRef):
if e.is_lambda():
vs, body = kd.utils.open_binder(e)
# also possibly lookup Lambda in globals.
# and/or use KnuckleClosure.
return lambda *args: eval_(
body, {**{v.decl().name(): arg for v, arg in zip(vs, args)}, **globals}
)
return KnuckleClosure(e, globals)
else:
raise ValueError("Quantifier not implemented", e)
elif isinstance(e, smt.IntNumRef): # smt.is_int_value(e):
Expand All @@ -141,6 +154,7 @@ def eval_(e: smt.ExprRef, globals={}):
elif isinstance(e, smt.FPNumRef):
raise ValueError("FPNumRef not implemented")
elif smt.is_app(e):
# Lazy evaluation of if, and, or, implies
if smt.is_if(e):
c = eval_(e.arg(0), globals=globals)
if isinstance(c, bool):
Expand All @@ -157,7 +171,37 @@ def eval_(e: smt.ExprRef, globals={}):
else:
# possibly lookup "If" in environment
raise ValueError("If condition not a boolean or expression", c)

elif smt.is_and(e):
acc = []
for child in e.children():
echild = eval_(child, globals=globals)
if echild is False:
return False
elif echild is True:
continue
else:
acc.append(echild)
return smt.And(acc)
elif smt.is_or(e):
acc = []
for child in e.children():
echild = eval_(child, globals=globals)
if echild is True:
return True
elif echild is False:
continue
else:
acc.append(echild)
return smt.Or(acc)
elif smt.is_implies(e):
cond = eval_(e.arg(0), globals=globals)
if cond is True:
return eval_(e.arg(1), globals=globals)
elif cond is False:
return True
else:
return smt.Implies(cond, eval_(e.arg(1), globals=globals))
# eval all children
children = list(map(lambda x: eval_(x, globals), e.children()))
decl = e.decl()
if decl in kd.kernel.defns:
Expand All @@ -181,12 +225,14 @@ def eval_(e: smt.ExprRef, globals={}):
return getattr(children[0], e.decl().name())
elif smt.is_select(e): # apply
return children[0](*children[1:])
# elif is_store(e): hmm
elif smt.is_store(e):
raise ValueError("Store not implemented", e)
# #return children[0]._replace(children[1], children[2])
elif smt.is_const_array(e):
return lambda x: children[0] # Maybe return a Closure here?
elif smt.is_map(e):
return map(children[0], children[1])
raise ValueError("Map not implemented", e)
# return map(children[0], children[1])
elif smt.is_constructor(e):
sort, decl = e.sort(), e.decl()
i = 0 # Can't have 0 constructors. Makes typechecker happy
Expand All @@ -200,14 +246,8 @@ def eval_(e: smt.ExprRef, globals={}):
return True
elif smt.is_false(e):
return False
elif smt.is_and(e):
return functools.reduce(operator.and_, children)
elif smt.is_or(e):
return functools.reduce(operator.or_, children)
elif smt.is_not(e):
return ~children[0]
elif smt.is_implies(e):
return (~children[0]) | children[1]
elif smt.is_eq(e):
return children[0] == children[1]
elif smt.is_lt(e):
Expand Down Expand Up @@ -236,7 +276,7 @@ def eval_(e: smt.ExprRef, globals={}):
# return e.as_string()
# elif isisntance(e, ArithRef):
elif smt.is_add(e):
return sum(children)
return functools.reduce(operator.add, children)
elif smt.is_mul(e):
return functools.reduce(operator.mul, children)
elif smt.is_sub(e):
Expand Down Expand Up @@ -269,6 +309,8 @@ def reify(s: smt.SortRef, x: object) -> smt.ExprRef:
>>> reify(smt.RealSort(), fractions.Fraction(10,16))
5/8
"""
if isinstance(x, KnuckleClosure):
return x.lam # TODO: Do I need to substitute in the env? Probably. That stinks. recurse into subterms, find name matches, reify those out of env
if isinstance(x, smt.ExprRef):
if x.sort() != s:
raise ValueError(f"Sort mismatch of {x} : {x.sort()} != {s}")
Expand Down Expand Up @@ -302,3 +344,38 @@ def reify(s: smt.SortRef, x: object) -> smt.ExprRef:
return smt.StringVal(x)
else:
raise ValueError(f"Cannot reify {x} as an expression")


def infer_sort(x: object) -> smt.SortRef:
if isinstance(x, int):
return smt.IntSort()
elif isinstance(x, fractions.Fraction):
return smt.RealSort()
elif isinstance(x, bool):
return smt.BoolSort()
elif isinstance(x, str):
return smt.StringSort()
elif isinstance(x, list):
assert len(x) > 0
return smt.SeqSort(infer_sort(x[0]))
elif isinstance(x, KnuckleClosure):
return x.lam.sort()
else:
raise ValueError(f"Cannot infer sort of {x}")


def nbe(x: smt.ExprRef) -> smt.ExprRef:
"""
Normalization by evaluation.
>>> nbe(smt.IntVal(41) + smt.IntVal(1))
42
>>> x,y = smt.Ints("x y")
>>> nbe(smt.Lambda([x], x + 1)[3])
4
>>> nbe(smt.Lambda([x], x + 1))
Lambda(x, x + 1)
>>> nbe(smt.Lambda([x], smt.IntVal(3) + 1))
Lambda(x, 3 + 1)
"""
return reify(x.sort(), eval_(x))

0 comments on commit 2bdd21a

Please sign in to comment.