Skip to content

Commit

Permalink
small emitter improvements. slice and array syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
utkn committed Dec 16, 2024
1 parent c81e159 commit 82f69a1
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 130 deletions.
10 changes: 10 additions & 0 deletions Lampe/Lampe.lean
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,13 @@ example {p} {x y : Tp.denote p .field} :
simp_all
. sl
aesop

nr_def createSlice<>() -> [Field] {
let a = &[1 : Field, 2 : Field];
a
}

nr_def createArray<>() -> [Field; 2] {
let a = [1 : Field, 2 : Field];
a
}
1 change: 1 addition & 0 deletions Lampe/Lampe/Ast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ inductive FunctionIdent (rep : Tp → Type) : Type where

inductive Expr (rep : Tp → Type) : Tp → Type where
| lit : (tp : Tp) → Nat → Expr rep tp
| list : List (Expr rep Tp.bool) → Expr rep (.slice tp)
| var : rep tp → Expr rep tp
| letIn : Expr rep t₁ → (rep t₁ → Expr rep t₂) → Expr rep t₂
| call : HList Kind.denote tyKinds → (argTypes : List Tp) → (res : Tp) → FunctionIdent rep → HList rep argTypes → Expr rep res
Expand Down
14 changes: 14 additions & 0 deletions Lampe/Lampe/Builtin/Array.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
import Lampe.Builtin.Basic
namespace Lampe.Builtin

/--
Defines the builtin array constructor.
-/
def mkArray (n : Nat) := newGenericPureBuiltin
(fun (argTps, tp) => ⟨argTps, (.array tp n)⟩)
(fun (argTps, tp) args => ⟨argTps = List.replicate n tp ∧ n < 2^32,
fun h => Mathlib.Vector.ofFn fun i => List.get (HList.toList args (by tauto)) (by
have hn : BitVec.toNat (n := 32) ↑n = n := by
simp_all
rw [hn] at i
convert i
apply HList.toList_len_is_n
)⟩)

/--
Defines the function that evaluates to an array's length `n`.
This builtin evaluates to an `U 32`. Hence, we assume that `n < 2^32`.
Expand Down
8 changes: 8 additions & 0 deletions Lampe/Lampe/Builtin/Slice.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import Lampe.Builtin.Basic
namespace Lampe.Builtin

/--
Defines the builtin slice constructor.
-/
def mkSlice (n : Nat) := newGenericPureBuiltin
(fun (argTps, tp) => ⟨argTps, (.slice tp)⟩)
(fun (argTps, tp) args => ⟨argTps = List.replicate n tp,
fun h => HList.toList args h⟩)

/--
Defines the indexing of a slice `l : List tp` with `i : U 32`
We make the following assumptions:
Expand Down
10 changes: 6 additions & 4 deletions Lampe/Lampe/Semantics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ inductive TraitResolution (Γ : Env): TraitImplRef → List (Ident × Function)
TraitResolution Γ ref (impl.impl implGenerics)

inductive Omni : Env → State p → Expr (Tp.denote p) tp → (Option (State p × Tp.denote p tp) → Prop) → Prop where
| skip {Q} : Q (some (st, ())) → Omni Γ st (.skip) Q
| litField {Q} : Q (some (st, n)) → Omni Γ st (.lit .field n) Q
| litU {Q} : Q (some (st, ↑n)) → Omni Γ st (.lit (.u s) n) Q
| litI {Q} : Q (some (st, ↑n)) → Omni Γ st (.lit (.i s) n) Q
| litFalse {Q} : Q (some (st, false)) → Omni Γ st (.lit .bool 0) Q
| litTrue {Q} : Q (some (st, true)) → Omni Γ st (.lit .bool 1) Q
| litRef {Q} : Q (some (st, ⟨r⟩)) → Omni Γ st (.lit (.ref tp) r) Q
| litU {Q} : Q (some (st, ↑n)) → Omni Γ st (.lit (.u s) n) Q
| var {Q} : Q (some (st, v)) → Omni Γ st (.var v) Q
| skip {Q} : Q (some (st, ())) → Omni Γ st (.skip) Q
| iteTrue {mainBranch elseBranch} :
Omni Γ st mainBranch Q →
Omni Γ st (Expr.ite true mainBranch elseBranch) Q
Expand Down Expand Up @@ -115,11 +116,12 @@ theorem Omni.frame {p Γ tp} {st₁ st₂ : State p} {e : Expr (Tp.denote p) tp}
) := by
intro h
induction h with
| litField hq
| skip hq
| litField hq
| litU hq
| litI hq
| litFalse hq
| litTrue hq
| litU hq
| litRef hq
| var hq =>
intro
Expand Down
96 changes: 57 additions & 39 deletions Lampe/Lampe/Syntax.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ syntax ident : nr_type
syntax "${" term "}" : nr_type
syntax nr_ident "<" nr_type,* ">" : nr_type
syntax "[" nr_type "]" : nr_type
syntax "[" nr_type ";" term "]" : nr_type

def mkListLit [Monad m] [MonadQuotation m] [MonadExceptOf Exception m] [MonadError m] : List (TSyntax `term) → m (TSyntax `term)
| [] => `([])
Expand Down Expand Up @@ -65,32 +66,33 @@ partial def mkNrType [Monad m] [MonadQuotation m] [MonadExceptOf Exception m] [M
`(Struct.tp $name $(←mkHListLit generics))
| `(nr_type| ${ $i }) => pure i
| `(nr_type| [ $tp ]) => do `(Tp.slice $(←mkNrType tp))
| `(nr_type| [ $tp ; $len:num ]) => do `(Tp.array $(←mkNrType tp) $len)
| _ => throwUnsupportedSyntax

partial def mkBuiltin [Monad m] [MonadQuotation m] [MonadExceptOf Exception m] [MonadError m] (i : String) : m (TSyntax `term) := match i with
| "add" => ``(Builtin.fAdd)
| "sub" => ``(Builtin.fSub)
| "mul" => ``(Builtin.fMul)
| "div" => ``(Builtin.fDiv)
| "eq" => ``(Builtin.fEq)
| "assert" => ``(Builtin.assert)
| "not" => ``(Builtin.bNot)
| "lt" => ``(Builtin.lt)
| "index" => ``(Builtin.index)
| "cast" => ``(Builtin.cast)
| "modulus_num_bits" => ``(Builtin.fModNumBits)
| "to_le_bytes" => ``(Builtin.toLeBytes)
| "fresh" => ``(Builtin.fresh)
| "slice_len" => ``(Builtin.sliceLen)
| "slice_push_back" => ``(Builtin.slicePushBack)
| "slice_push_front" => ``(Builtin.slicePushFront)
| "slice_pop_back" => ``(Builtin.slicePopBack)
| "slice_index" => ``(Builtin.sliceIndex)
| "slice_pop_front" => ``(Builtin.slicePopFront)
| "slice_insert" => ``(Builtin.sliceInsert)
| "ref" => ``(Builtin.ref)
| "read_ref" => ``(Builtin.readRef)
| "write_ref" => ``(Builtin.writeRef)
| "add" => `(Builtin.fAdd)
| "sub" => `(Builtin.fSub)
| "mul" => `(Builtin.fMul)
| "div" => `(Builtin.fDiv)
| "eq" => `(Builtin.fEq)
| "assert" => `(Builtin.assert)
| "not" => `(Builtin.bNot)
| "lt" => `(Builtin.lt)
| "index" => `(Builtin.index)
| "cast" => `(Builtin.cast)
| "modulus_num_bits" => `(Builtin.fModNumBits)
| "to_le_bytes" => `(Builtin.toLeBytes)
| "fresh" => `(Builtin.fresh)
| "slice_len" => `(Builtin.sliceLen)
| "slice_push_back" => `(Builtin.slicePushBack)
| "slice_push_front" => `(Builtin.slicePushFront)
| "slice_pop_back" => `(Builtin.slicePopBack)
| "slice_index" => `(Builtin.sliceIndex)
| "slice_pop_front" => `(Builtin.slicePopFront)
| "slice_insert" => `(Builtin.sliceInsert)
| "ref" => `(Builtin.ref)
| "read_ref" => `(Builtin.readRef)
| "write_ref" => `(Builtin.writeRef)
| _ => throwError "Unknown builtin {i}"

syntax ident ":" nr_type : nr_param_decl
Expand All @@ -108,12 +110,14 @@ syntax "if" nr_expr nr_expr ("else" nr_expr)? : nr_expr
syntax "for" ident "in" nr_expr ".." nr_expr nr_expr : nr_expr
syntax "(" nr_expr ")" : nr_expr
syntax "*(" nr_expr ")" : nr_expr
syntax "|" nr_param_decl,* "|" "->" nr_type nr_expr : nr_expr -- Lambda
syntax "[" nr_expr,* "]" : nr_expr -- Array constructor
syntax "&" "[" nr_expr,* "]" : nr_expr -- Slice constructor
syntax "|" nr_param_decl,* "|" "->" nr_type nr_expr : nr_expr -- Lambda constructor
syntax "#" nr_ident "(" nr_expr,* ")" ":" nr_type : nr_expr -- Builtin call
syntax "^" nr_ident "(" nr_expr,* ")" ":" nr_type : nr_expr -- Lambda call
syntax "@" nr_ident "<" nr_type,* ">" "(" nr_expr,* ")" ":" nr_type : nr_expr -- Decl call

syntax "(" nr_type "as" nr_ident "<" nr_type,* ">" ")" "::" nr_ident "<" nr_type,* ">" "(" nr_expr,* ")" ":" nr_type : nr_expr -- Trait call
syntax "(" nr_type "as" nr_ident "<" nr_type,* ">" ")"
"::" nr_ident "<" nr_type,* ">" "(" nr_expr,* ")" ":" nr_type : nr_expr -- Trait call
syntax nr_expr "[" nr_ident "<" nr_type,* ">" "." ident "]" : nr_expr -- Struct access
syntax nr_ident "<" nr_type,* ">" "{" nr_expr,* "}" : nr_expr -- Struct constructor

Expand All @@ -137,6 +141,12 @@ def Expr.readRef (ref : rep tp.ref): Expr rep tp :=
def Expr.writeRef (ref : rep tp.ref) (val : rep tp): Expr rep .unit :=
Expr.call h![] _ .unit (.builtin .writeRef) h![ref, val]

def Expr.slice (vals : HList rep tps) : Expr rep (.slice tp) :=
Expr.call h![] _ (.slice tp) (.builtin $ .mkSlice (tps.length)) vals

def Expr.array (vals : HList rep tps) : Expr rep (.array tp n) :=
Expr.call h![] _ (.array tp n) (.builtin $ .mkArray n.toNat) vals

structure DesugarState where
autoDeref : Name → Bool
nextFresh : Nat
Expand Down Expand Up @@ -198,34 +208,42 @@ partial def mkArgs [MonadSyntax m] (args : List (TSyntax `nr_expr)) (k : List (T
mkArgs t fun t => k (h :: t)

partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.Ident) (k : TSyntax `term → m (TSyntax `term)): m (TSyntax `term) := match e with
| `(nr_expr|$n:num : $tp) => do wrapSimple (←``(Lampe.Expr.lit $(←mkNrType tp) $n)) vname k
| `(nr_expr| true) => do wrapSimple (←``(Lampe.Expr.lit Tp.bool 1)) vname k
| `(nr_expr| false) => do wrapSimple (←``(Lampe.Expr.lit Tp.bool 0)) vname k
| `(nr_expr|$n:num : $tp) => do wrapSimple (←`(Lampe.Expr.lit $(←mkNrType tp) $n)) vname k
| `(nr_expr| true) => do wrapSimple (←`(Lampe.Expr.lit Tp.bool 1)) vname k
| `(nr_expr| false) => do wrapSimple (←`(Lampe.Expr.lit Tp.bool 0)) vname k
| `(nr_expr | & [ $args,* ]) => do
let args := args.getElems.toList
mkArgs args fun argVals => do
wrapSimple (←`(Lampe.Expr.slice $(←mkHListLit argVals))) vname k
| `(nr_expr | [ $args,* ]) => do
let args := args.getElems.toList
mkArgs args fun argVals => do
wrapSimple (←`(Lampe.Expr.array $(←mkHListLit argVals))) vname k
| `(nr_expr| { $exprs;* }) => mkBlock exprs.getElems.toList k
| `(nr_expr| $i:ident) => do
if ←isAutoDeref i.getId then wrapSimple (← ``(Lampe.Expr.readRef $i)) vname k else match vname with
if ←isAutoDeref i.getId then wrapSimple (← `(Lampe.Expr.readRef $i)) vname k else match vname with
| none => k i
| some _ => wrapSimple (←``(Lampe.Expr.var $i)) vname k
| some _ => wrapSimple (←`(Lampe.Expr.var $i)) vname k
| `(nr_expr| # $i:ident ($args,*): $tp) => do
mkArgs args.getElems.toList fun argVals => do
wrapSimple (←`(Lampe.Expr.call h![] _ $(←mkNrType tp) (.builtin $(←mkBuiltin i.getId.toString)) $(←mkHListLit argVals))) vname k
| `(nr_expr| for $i in $lo .. $hi $body) => do
mkExpr lo none fun lo =>
mkExpr hi none fun hi => do
let body ← mkExpr body none (fun x => ``(Lampe.Expr.var $x))
let body ← mkExpr body none (fun x => `(Lampe.Expr.var $x))
wrapSimple (←`(Lampe.Expr.loop $lo $hi fun $i => $body)) vname k
| `(nr_expr| $v:ident = $e) => do
mkExpr e none fun eVal => do
wrapSimple (←`(Lampe.Expr.writeRef $v $eVal)) vname k
| `(nr_expr| ( $e )) => mkExpr e vname k
| `(nr_expr| if $cond $mainBody else $elseBody) => do
mkExpr cond none fun cond => do
let mainBody ← mkExpr mainBody none fun x => ``(Lampe.Expr.var $x)
let elseBody ← mkExpr elseBody none fun x => ``(Lampe.Expr.var $x)
let mainBody ← mkExpr mainBody none fun x => `(Lampe.Expr.var $x)
let elseBody ← mkExpr elseBody none fun x => `(Lampe.Expr.var $x)
wrapSimple (←`(Lampe.Expr.ite $cond $mainBody $elseBody)) vname k
| `(nr_expr| if $cond $mainBody) => do
mkExpr cond none fun cond => do
let mainBody ← mkExpr mainBody none fun x => ``(Lampe.Expr.var $x)
let mainBody ← mkExpr mainBody none fun x => `(Lampe.Expr.var $x)
wrapSimple (←`(Lampe.Expr.ite $cond $mainBody (Lampe.Expr.skip))) vname k
| `(nr_expr| | $params,* | -> $outTp $lambdaBody) => do
let outTp ← mkNrType outTp
Expand All @@ -235,7 +253,7 @@ partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.I
let args ← mkHListLit (← params.getElems.toList.mapM fun param => match param with
| `(nr_param_decl|$i:ident : $_) => `($i)
| _ => throwUnsupportedSyntax)
let body ← mkExpr lambdaBody none fun x => ``(Lampe.Expr.var $x)
let body ← mkExpr lambdaBody none fun x => `(Lampe.Expr.var $x)
wrapSimple (←`(Lampe.Expr.lambda $argTps $outTp (fun $args => $body))) vname k
| `(nr_expr| ^ $i:ident ($args,*) : $tp) => do
mkArgs args.getElems.toList fun argVals => do
Expand All @@ -255,8 +273,8 @@ partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.I
let callGenVals ← mkHListLit (←callGenVals.getElems.toList.mapM fun gVal => mkNrType gVal)
mkArgs args.getElems.toList fun argVals => do
wrapSimple (←`(@Lampe.Expr.call _ $callGenKinds $callGenVals _ $(←mkNrType tp) (.decl $(Syntax.mkStrLit (←mkNrIdent declName))) $(←mkHListLit argVals))) vname k
| `(nr_expr| $structName:nr_ident < $genericVals,* > { $args,* }) => do
let structGenValsSyn ← mkHListLit (←genericVals.getElems.toList.mapM fun gVal => mkNrType gVal)
| `(nr_expr| $structName:nr_ident < $structGenVals,* > { $args,* }) => do
let structGenValsSyn ← mkHListLit (←structGenVals.getElems.toList.mapM fun gVal => mkNrType gVal)
let paramTpsSyn ← `(Struct.fieldTypes $(mkStructDefIdent $ ←mkNrIdent structName) $structGenValsSyn)
let structName ← mkNrIdent structName
mkArgs args.getElems.toList fun argVals => do
Expand Down Expand Up @@ -350,7 +368,7 @@ def mkStructProjector [Monad m] [MonadQuotation m] [MonadExceptOf Exception m] [
| _ => throwUnsupportedSyntax

elab "expr![" expr:nr_expr "]" : term => do
let term ← MonadSyntax.run $ mkExpr expr none fun x => ``(Expr.var $x)
let term ← MonadSyntax.run $ mkExpr expr none fun x => `(Expr.var $x)
Elab.Term.elabTerm term.raw none

elab "nrfn![" "fn" fn:nr_fn_decl "]" : term => do
Expand Down
25 changes: 25 additions & 0 deletions Lampe/Lampe/Tp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,29 @@ example : newMember [.bool, .field, .field] ⟨0, (by tauto)⟩ = Member.head :=
example : newMember [.bool, .field, .field] ⟨1, (by tauto)⟩ = Member.head.tail := rfl
example : newMember [.bool, .field, .field] ⟨2, (by tauto)⟩ = Member.head.tail.tail := rfl

lemma replicate_cons (hl : x :: xs = List.replicate n a) :
x = a ∧ xs = List.replicate (n-1) a := by
unfold List.replicate at hl
constructor
. aesop
. cases xs <;> aesop

@[reducible]
def HList.toList (hList : HList rep tps) (h_same : tps = List.replicate n tp) : List (rep tp) := match hList with
| .nil => []
| .cons x xs => match tps with
| [] => []
| _ :: _ => (by
have hl := replicate_cons h_same
obtain ⟨hl₁, hl₂⟩ := hl
exact (hl₁ ▸ x) :: (HList.toList xs hl₂))

theorem HList.toList_len_is_n (h_same : tps = List.replicate n tp) :
(HList.toList hl h_same).length = n := by
cases hl
aesop
sorry



end Lampe
Loading

0 comments on commit 82f69a1

Please sign in to comment.