From 82f69a1a19757a203522495ad85d5a18178eaf5e Mon Sep 17 00:00:00 2001 From: utkn Date: Tue, 17 Dec 2024 00:26:33 +0100 Subject: [PATCH] small emitter improvements. slice and array syntax --- Lampe/Lampe.lean | 10 ++++ Lampe/Lampe/Ast.lean | 1 + Lampe/Lampe/Builtin/Array.lean | 14 +++++ Lampe/Lampe/Builtin/Slice.lean | 8 +++ Lampe/Lampe/Semantics.lean | 10 ++-- Lampe/Lampe/Syntax.lean | 96 ++++++++++++++++++++-------------- Lampe/Lampe/Tp.lean | 25 +++++++++ src/lean/builtin.rs | 67 ++++++++++++++++-------- src/lean/mod.rs | 61 ++++++++++----------- src/lean/syntax.rs | 36 +++++-------- src/lib.rs | 22 +++++--- 11 files changed, 220 insertions(+), 130 deletions(-) diff --git a/Lampe/Lampe.lean b/Lampe/Lampe.lean index 0eca852..a79d495 100644 --- a/Lampe/Lampe.lean +++ b/Lampe/Lampe.lean @@ -268,3 +268,13 @@ example {p} {x y : Tp.denote p .field} : simp_all . sl aesop + +nr_def createSlice<>() -> [Field] { + let a = &[1 : Field, 2 : Field]; + a +} + +nr_def createArray<>() -> [Field; 2] { + let a = [1 : Field, 2 : Field]; + a +} diff --git a/Lampe/Lampe/Ast.lean b/Lampe/Lampe/Ast.lean index fb62b94..16e33f5 100644 --- a/Lampe/Lampe/Ast.lean +++ b/Lampe/Lampe/Ast.lean @@ -33,6 +33,7 @@ inductive FunctionIdent (rep : Tp → Type) : Type where inductive Expr (rep : Tp → Type) : Tp → Type where | lit : (tp : Tp) → Nat → Expr rep tp +| list : List (Expr rep Tp.bool) → Expr rep (.slice tp) | var : rep tp → Expr rep tp | letIn : Expr rep t₁ → (rep t₁ → Expr rep t₂) → Expr rep t₂ | call : HList Kind.denote tyKinds → (argTypes : List Tp) → (res : Tp) → FunctionIdent rep → HList rep argTypes → Expr rep res diff --git a/Lampe/Lampe/Builtin/Array.lean b/Lampe/Lampe/Builtin/Array.lean index 896c910..c14adc5 100644 --- a/Lampe/Lampe/Builtin/Array.lean +++ b/Lampe/Lampe/Builtin/Array.lean @@ -1,6 +1,20 @@ import Lampe.Builtin.Basic namespace Lampe.Builtin +/-- +Defines the builtin array constructor. +-/ +def mkArray (n : Nat) := newGenericPureBuiltin + (fun (argTps, tp) => ⟨argTps, (.array tp n)⟩) + (fun (argTps, tp) args => ⟨argTps = List.replicate n tp ∧ n < 2^32, + fun h => Mathlib.Vector.ofFn fun i => List.get (HList.toList args (by tauto)) (by + have hn : BitVec.toNat (n := 32) ↑n = n := by + simp_all + rw [hn] at i + convert i + apply HList.toList_len_is_n + )⟩) + /-- Defines the function that evaluates to an array's length `n`. This builtin evaluates to an `U 32`. Hence, we assume that `n < 2^32`. diff --git a/Lampe/Lampe/Builtin/Slice.lean b/Lampe/Lampe/Builtin/Slice.lean index 36a9ffa..ba51286 100644 --- a/Lampe/Lampe/Builtin/Slice.lean +++ b/Lampe/Lampe/Builtin/Slice.lean @@ -1,6 +1,14 @@ import Lampe.Builtin.Basic namespace Lampe.Builtin +/-- +Defines the builtin slice constructor. +-/ +def mkSlice (n : Nat) := newGenericPureBuiltin + (fun (argTps, tp) => ⟨argTps, (.slice tp)⟩) + (fun (argTps, tp) args => ⟨argTps = List.replicate n tp, + fun h => HList.toList args h⟩) + /-- Defines the indexing of a slice `l : List tp` with `i : U 32` We make the following assumptions: diff --git a/Lampe/Lampe/Semantics.lean b/Lampe/Lampe/Semantics.lean index 6ae9857..5ae1c7a 100644 --- a/Lampe/Lampe/Semantics.lean +++ b/Lampe/Lampe/Semantics.lean @@ -35,13 +35,14 @@ inductive TraitResolution (Γ : Env): TraitImplRef → List (Ident × Function) TraitResolution Γ ref (impl.impl implGenerics) inductive Omni : Env → State p → Expr (Tp.denote p) tp → (Option (State p × Tp.denote p tp) → Prop) → Prop where +| skip {Q} : Q (some (st, ())) → Omni Γ st (.skip) Q | litField {Q} : Q (some (st, n)) → Omni Γ st (.lit .field n) Q +| litU {Q} : Q (some (st, ↑n)) → Omni Γ st (.lit (.u s) n) Q +| litI {Q} : Q (some (st, ↑n)) → Omni Γ st (.lit (.i s) n) Q | litFalse {Q} : Q (some (st, false)) → Omni Γ st (.lit .bool 0) Q | litTrue {Q} : Q (some (st, true)) → Omni Γ st (.lit .bool 1) Q | litRef {Q} : Q (some (st, ⟨r⟩)) → Omni Γ st (.lit (.ref tp) r) Q -| litU {Q} : Q (some (st, ↑n)) → Omni Γ st (.lit (.u s) n) Q | var {Q} : Q (some (st, v)) → Omni Γ st (.var v) Q -| skip {Q} : Q (some (st, ())) → Omni Γ st (.skip) Q | iteTrue {mainBranch elseBranch} : Omni Γ st mainBranch Q → Omni Γ st (Expr.ite true mainBranch elseBranch) Q @@ -115,11 +116,12 @@ theorem Omni.frame {p Γ tp} {st₁ st₂ : State p} {e : Expr (Tp.denote p) tp} ) := by intro h induction h with - | litField hq | skip hq + | litField hq + | litU hq + | litI hq | litFalse hq | litTrue hq - | litU hq | litRef hq | var hq => intro diff --git a/Lampe/Lampe/Syntax.lean b/Lampe/Lampe/Syntax.lean index 6ea8290..f70117d 100644 --- a/Lampe/Lampe/Syntax.lean +++ b/Lampe/Lampe/Syntax.lean @@ -36,6 +36,7 @@ syntax ident : nr_type syntax "${" term "}" : nr_type syntax nr_ident "<" nr_type,* ">" : nr_type syntax "[" nr_type "]" : nr_type +syntax "[" nr_type ";" term "]" : nr_type def mkListLit [Monad m] [MonadQuotation m] [MonadExceptOf Exception m] [MonadError m] : List (TSyntax `term) → m (TSyntax `term) | [] => `([]) @@ -65,32 +66,33 @@ partial def mkNrType [Monad m] [MonadQuotation m] [MonadExceptOf Exception m] [M `(Struct.tp $name $(←mkHListLit generics)) | `(nr_type| ${ $i }) => pure i | `(nr_type| [ $tp ]) => do `(Tp.slice $(←mkNrType tp)) +| `(nr_type| [ $tp ; $len:num ]) => do `(Tp.array $(←mkNrType tp) $len) | _ => throwUnsupportedSyntax partial def mkBuiltin [Monad m] [MonadQuotation m] [MonadExceptOf Exception m] [MonadError m] (i : String) : m (TSyntax `term) := match i with -| "add" => ``(Builtin.fAdd) -| "sub" => ``(Builtin.fSub) -| "mul" => ``(Builtin.fMul) -| "div" => ``(Builtin.fDiv) -| "eq" => ``(Builtin.fEq) -| "assert" => ``(Builtin.assert) -| "not" => ``(Builtin.bNot) -| "lt" => ``(Builtin.lt) -| "index" => ``(Builtin.index) -| "cast" => ``(Builtin.cast) -| "modulus_num_bits" => ``(Builtin.fModNumBits) -| "to_le_bytes" => ``(Builtin.toLeBytes) -| "fresh" => ``(Builtin.fresh) -| "slice_len" => ``(Builtin.sliceLen) -| "slice_push_back" => ``(Builtin.slicePushBack) -| "slice_push_front" => ``(Builtin.slicePushFront) -| "slice_pop_back" => ``(Builtin.slicePopBack) -| "slice_index" => ``(Builtin.sliceIndex) -| "slice_pop_front" => ``(Builtin.slicePopFront) -| "slice_insert" => ``(Builtin.sliceInsert) -| "ref" => ``(Builtin.ref) -| "read_ref" => ``(Builtin.readRef) -| "write_ref" => ``(Builtin.writeRef) +| "add" => `(Builtin.fAdd) +| "sub" => `(Builtin.fSub) +| "mul" => `(Builtin.fMul) +| "div" => `(Builtin.fDiv) +| "eq" => `(Builtin.fEq) +| "assert" => `(Builtin.assert) +| "not" => `(Builtin.bNot) +| "lt" => `(Builtin.lt) +| "index" => `(Builtin.index) +| "cast" => `(Builtin.cast) +| "modulus_num_bits" => `(Builtin.fModNumBits) +| "to_le_bytes" => `(Builtin.toLeBytes) +| "fresh" => `(Builtin.fresh) +| "slice_len" => `(Builtin.sliceLen) +| "slice_push_back" => `(Builtin.slicePushBack) +| "slice_push_front" => `(Builtin.slicePushFront) +| "slice_pop_back" => `(Builtin.slicePopBack) +| "slice_index" => `(Builtin.sliceIndex) +| "slice_pop_front" => `(Builtin.slicePopFront) +| "slice_insert" => `(Builtin.sliceInsert) +| "ref" => `(Builtin.ref) +| "read_ref" => `(Builtin.readRef) +| "write_ref" => `(Builtin.writeRef) | _ => throwError "Unknown builtin {i}" syntax ident ":" nr_type : nr_param_decl @@ -108,12 +110,14 @@ syntax "if" nr_expr nr_expr ("else" nr_expr)? : nr_expr syntax "for" ident "in" nr_expr ".." nr_expr nr_expr : nr_expr syntax "(" nr_expr ")" : nr_expr syntax "*(" nr_expr ")" : nr_expr -syntax "|" nr_param_decl,* "|" "->" nr_type nr_expr : nr_expr -- Lambda +syntax "[" nr_expr,* "]" : nr_expr -- Array constructor +syntax "&" "[" nr_expr,* "]" : nr_expr -- Slice constructor +syntax "|" nr_param_decl,* "|" "->" nr_type nr_expr : nr_expr -- Lambda constructor syntax "#" nr_ident "(" nr_expr,* ")" ":" nr_type : nr_expr -- Builtin call syntax "^" nr_ident "(" nr_expr,* ")" ":" nr_type : nr_expr -- Lambda call syntax "@" nr_ident "<" nr_type,* ">" "(" nr_expr,* ")" ":" nr_type : nr_expr -- Decl call - -syntax "(" nr_type "as" nr_ident "<" nr_type,* ">" ")" "::" nr_ident "<" nr_type,* ">" "(" nr_expr,* ")" ":" nr_type : nr_expr -- Trait call +syntax "(" nr_type "as" nr_ident "<" nr_type,* ">" ")" + "::" nr_ident "<" nr_type,* ">" "(" nr_expr,* ")" ":" nr_type : nr_expr -- Trait call syntax nr_expr "[" nr_ident "<" nr_type,* ">" "." ident "]" : nr_expr -- Struct access syntax nr_ident "<" nr_type,* ">" "{" nr_expr,* "}" : nr_expr -- Struct constructor @@ -137,6 +141,12 @@ def Expr.readRef (ref : rep tp.ref): Expr rep tp := def Expr.writeRef (ref : rep tp.ref) (val : rep tp): Expr rep .unit := Expr.call h![] _ .unit (.builtin .writeRef) h![ref, val] +def Expr.slice (vals : HList rep tps) : Expr rep (.slice tp) := + Expr.call h![] _ (.slice tp) (.builtin $ .mkSlice (tps.length)) vals + +def Expr.array (vals : HList rep tps) : Expr rep (.array tp n) := + Expr.call h![] _ (.array tp n) (.builtin $ .mkArray n.toNat) vals + structure DesugarState where autoDeref : Name → Bool nextFresh : Nat @@ -198,21 +208,29 @@ partial def mkArgs [MonadSyntax m] (args : List (TSyntax `nr_expr)) (k : List (T mkArgs t fun t => k (h :: t) partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.Ident) (k : TSyntax `term → m (TSyntax `term)): m (TSyntax `term) := match e with -| `(nr_expr|$n:num : $tp) => do wrapSimple (←``(Lampe.Expr.lit $(←mkNrType tp) $n)) vname k -| `(nr_expr| true) => do wrapSimple (←``(Lampe.Expr.lit Tp.bool 1)) vname k -| `(nr_expr| false) => do wrapSimple (←``(Lampe.Expr.lit Tp.bool 0)) vname k +| `(nr_expr|$n:num : $tp) => do wrapSimple (←`(Lampe.Expr.lit $(←mkNrType tp) $n)) vname k +| `(nr_expr| true) => do wrapSimple (←`(Lampe.Expr.lit Tp.bool 1)) vname k +| `(nr_expr| false) => do wrapSimple (←`(Lampe.Expr.lit Tp.bool 0)) vname k +| `(nr_expr | & [ $args,* ]) => do + let args := args.getElems.toList + mkArgs args fun argVals => do + wrapSimple (←`(Lampe.Expr.slice $(←mkHListLit argVals))) vname k +| `(nr_expr | [ $args,* ]) => do + let args := args.getElems.toList + mkArgs args fun argVals => do + wrapSimple (←`(Lampe.Expr.array $(←mkHListLit argVals))) vname k | `(nr_expr| { $exprs;* }) => mkBlock exprs.getElems.toList k | `(nr_expr| $i:ident) => do - if ←isAutoDeref i.getId then wrapSimple (← ``(Lampe.Expr.readRef $i)) vname k else match vname with + if ←isAutoDeref i.getId then wrapSimple (← `(Lampe.Expr.readRef $i)) vname k else match vname with | none => k i - | some _ => wrapSimple (←``(Lampe.Expr.var $i)) vname k + | some _ => wrapSimple (←`(Lampe.Expr.var $i)) vname k | `(nr_expr| # $i:ident ($args,*): $tp) => do mkArgs args.getElems.toList fun argVals => do wrapSimple (←`(Lampe.Expr.call h![] _ $(←mkNrType tp) (.builtin $(←mkBuiltin i.getId.toString)) $(←mkHListLit argVals))) vname k | `(nr_expr| for $i in $lo .. $hi $body) => do mkExpr lo none fun lo => mkExpr hi none fun hi => do - let body ← mkExpr body none (fun x => ``(Lampe.Expr.var $x)) + let body ← mkExpr body none (fun x => `(Lampe.Expr.var $x)) wrapSimple (←`(Lampe.Expr.loop $lo $hi fun $i => $body)) vname k | `(nr_expr| $v:ident = $e) => do mkExpr e none fun eVal => do @@ -220,12 +238,12 @@ partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.I | `(nr_expr| ( $e )) => mkExpr e vname k | `(nr_expr| if $cond $mainBody else $elseBody) => do mkExpr cond none fun cond => do - let mainBody ← mkExpr mainBody none fun x => ``(Lampe.Expr.var $x) - let elseBody ← mkExpr elseBody none fun x => ``(Lampe.Expr.var $x) + let mainBody ← mkExpr mainBody none fun x => `(Lampe.Expr.var $x) + let elseBody ← mkExpr elseBody none fun x => `(Lampe.Expr.var $x) wrapSimple (←`(Lampe.Expr.ite $cond $mainBody $elseBody)) vname k | `(nr_expr| if $cond $mainBody) => do mkExpr cond none fun cond => do - let mainBody ← mkExpr mainBody none fun x => ``(Lampe.Expr.var $x) + let mainBody ← mkExpr mainBody none fun x => `(Lampe.Expr.var $x) wrapSimple (←`(Lampe.Expr.ite $cond $mainBody (Lampe.Expr.skip))) vname k | `(nr_expr| | $params,* | -> $outTp $lambdaBody) => do let outTp ← mkNrType outTp @@ -235,7 +253,7 @@ partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.I let args ← mkHListLit (← params.getElems.toList.mapM fun param => match param with | `(nr_param_decl|$i:ident : $_) => `($i) | _ => throwUnsupportedSyntax) - let body ← mkExpr lambdaBody none fun x => ``(Lampe.Expr.var $x) + let body ← mkExpr lambdaBody none fun x => `(Lampe.Expr.var $x) wrapSimple (←`(Lampe.Expr.lambda $argTps $outTp (fun $args => $body))) vname k | `(nr_expr| ^ $i:ident ($args,*) : $tp) => do mkArgs args.getElems.toList fun argVals => do @@ -255,8 +273,8 @@ partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.I let callGenVals ← mkHListLit (←callGenVals.getElems.toList.mapM fun gVal => mkNrType gVal) mkArgs args.getElems.toList fun argVals => do wrapSimple (←`(@Lampe.Expr.call _ $callGenKinds $callGenVals _ $(←mkNrType tp) (.decl $(Syntax.mkStrLit (←mkNrIdent declName))) $(←mkHListLit argVals))) vname k -| `(nr_expr| $structName:nr_ident < $genericVals,* > { $args,* }) => do - let structGenValsSyn ← mkHListLit (←genericVals.getElems.toList.mapM fun gVal => mkNrType gVal) +| `(nr_expr| $structName:nr_ident < $structGenVals,* > { $args,* }) => do + let structGenValsSyn ← mkHListLit (←structGenVals.getElems.toList.mapM fun gVal => mkNrType gVal) let paramTpsSyn ← `(Struct.fieldTypes $(mkStructDefIdent $ ←mkNrIdent structName) $structGenValsSyn) let structName ← mkNrIdent structName mkArgs args.getElems.toList fun argVals => do @@ -350,7 +368,7 @@ def mkStructProjector [Monad m] [MonadQuotation m] [MonadExceptOf Exception m] [ | _ => throwUnsupportedSyntax elab "expr![" expr:nr_expr "]" : term => do - let term ← MonadSyntax.run $ mkExpr expr none fun x => ``(Expr.var $x) + let term ← MonadSyntax.run $ mkExpr expr none fun x => `(Expr.var $x) Elab.Term.elabTerm term.raw none elab "nrfn![" "fn" fn:nr_fn_decl "]" : term => do diff --git a/Lampe/Lampe/Tp.lean b/Lampe/Lampe/Tp.lean index 77191f4..1719335 100644 --- a/Lampe/Lampe/Tp.lean +++ b/Lampe/Lampe/Tp.lean @@ -81,4 +81,29 @@ example : newMember [.bool, .field, .field] ⟨0, (by tauto)⟩ = Member.head := example : newMember [.bool, .field, .field] ⟨1, (by tauto)⟩ = Member.head.tail := rfl example : newMember [.bool, .field, .field] ⟨2, (by tauto)⟩ = Member.head.tail.tail := rfl +lemma replicate_cons (hl : x :: xs = List.replicate n a) : + x = a ∧ xs = List.replicate (n-1) a := by + unfold List.replicate at hl + constructor + . aesop + . cases xs <;> aesop + +@[reducible] +def HList.toList (hList : HList rep tps) (h_same : tps = List.replicate n tp) : List (rep tp) := match hList with +| .nil => [] +| .cons x xs => match tps with + | [] => [] + | _ :: _ => (by + have hl := replicate_cons h_same + obtain ⟨hl₁, hl₂⟩ := hl + exact (hl₁ ▸ x) :: (HList.toList xs hl₂)) + +theorem HList.toList_len_is_n (h_same : tps = List.replicate n tp) : + (HList.toList hl h_same).length = n := by + cases hl + aesop + sorry + + + end Lampe diff --git a/src/lean/builtin.rs b/src/lean/builtin.rs index 82f8736..7eddfc5 100644 --- a/src/lean/builtin.rs +++ b/src/lean/builtin.rs @@ -8,6 +8,27 @@ use noirc_frontend::{ use itertools::Itertools; +pub const INDEX_BUILTIN_NAME: &str = "index"; +pub const ZEROED_BUILTIN_NAME: &str = "zeroed"; +pub const NEG_BUILTIN_NAME: &str = "neg"; +pub const NOT_BUILTIN_NAME: &str = "not"; +pub const ADD_BUILTIN_NAME: &str = "add"; +pub const SUB_BUILTIN_NAME: &str = "sub"; +pub const MUL_BUILTIN_NAME: &str = "mul"; +pub const DIV_BUILTIN_NAME: &str = "div"; +pub const MOD_BUILTIN_NAME: &str = "rem"; +pub const EQ_BUILTIN_NAME: &str = "eq"; +pub const NEQ_BUILTIN_NAME: &str = "neq"; +pub const GT_BUILTIN_NAME: &str = "gt"; +pub const LT_BUILTIN_NAME: &str = "lt"; +pub const GEQ_BUILTIN_NAME: &str = "geq"; +pub const LEQ_BUILTIN_NAME: &str = "leq"; +pub const AND_BUILTIN_NAME: &str = "and"; +pub const OR_BUILTIN_NAME: &str = "or"; +pub const XOR_BUILTIN_NAME: &str = "xor"; +pub const SHL_BUILTIN_NAME: &str = "shl"; +pub const SHR_BUILTIN_NAME: &str = "shr"; + #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum BuiltinType { Field, @@ -45,7 +66,7 @@ impl TryInto for Type { Type::Array(_, _) => todo!(), Type::Slice(_) => todo!(), Type::String(_) => todo!(), - _ => Err(format!("unknown builtin type {self}")), + _ => Err(format!("unknown builtin type `{:?}`", self)), } } } @@ -69,7 +90,7 @@ impl BuiltinType { } } -pub fn try_func_as_builtin(func_name: &str, func_meta: &FuncMeta) -> Option { +pub fn try_func_into_builtin(func_name: &str, func_meta: &FuncMeta) -> Option { let param_types: Result, _> = func_meta .parameters .0 @@ -78,22 +99,22 @@ pub fn try_func_as_builtin(func_name: &str, func_meta: &FuncMeta) -> Option Some(format!("zeroed")), + ("zeroed", _) => Some(ZEROED_BUILTIN_NAME.into()), _ => None, } } -pub fn try_prefix_as_builtin(op: UnaryOp, rhs_type: BuiltinType) -> Option { +pub fn try_prefix_into_builtin(op: UnaryOp, rhs_type: BuiltinType) -> Option { match op { - UnaryOp::Minus if rhs_type.is_arithmetic() => Some(format!("neg")), - UnaryOp::Not if rhs_type.is_bitwise() => Some(format!("not")), + UnaryOp::Minus if rhs_type.is_arithmetic() => Some(NEG_BUILTIN_NAME.into()), + UnaryOp::Not if rhs_type.is_bitwise() => Some(NOT_BUILTIN_NAME.into()), UnaryOp::MutableReference => todo!(), UnaryOp::Dereference { .. } => todo!(), _ => None, } } -pub fn try_infix_as_builtin( +pub fn try_infix_into_builtin( op: BinaryOpKind, lhs_type: BuiltinType, rhs_type: BuiltinType, @@ -103,24 +124,24 @@ pub fn try_infix_as_builtin( } match op { // Arithmetic - BinaryOpKind::Add if lhs_type.is_arithmetic() => Some(format!("add")), - BinaryOpKind::Subtract if lhs_type.is_arithmetic() => Some(format!("sub")), - BinaryOpKind::Divide if lhs_type.is_arithmetic() => Some(format!("div")), - BinaryOpKind::Multiply if lhs_type.is_arithmetic() => Some(format!("mul")), - BinaryOpKind::Modulo if lhs_type.is_arithmetic() => Some(format!("rem")), + BinaryOpKind::Add if lhs_type.is_arithmetic() => Some(ADD_BUILTIN_NAME.into()), + BinaryOpKind::Subtract if lhs_type.is_arithmetic() => Some(SUB_BUILTIN_NAME.into()), + BinaryOpKind::Divide if lhs_type.is_arithmetic() => Some(DIV_BUILTIN_NAME.into()), + BinaryOpKind::Multiply if lhs_type.is_arithmetic() => Some(MUL_BUILTIN_NAME.into()), + BinaryOpKind::Modulo if lhs_type.is_arithmetic() => Some(MOD_BUILTIN_NAME.into()), // Cmp - BinaryOpKind::Equal => Some(format!("eq")), - BinaryOpKind::NotEqual => Some(format!("neq")), - BinaryOpKind::Greater if lhs_type.is_arithmetic() => Some(format!("gt")), - BinaryOpKind::GreaterEqual if lhs_type.is_arithmetic() => Some(format!("geq")), - BinaryOpKind::Less if lhs_type.is_arithmetic() => Some(format!("lt")), - BinaryOpKind::LessEqual if lhs_type.is_arithmetic() => Some(format!("leq")), + BinaryOpKind::Equal => Some(EQ_BUILTIN_NAME.into()), + BinaryOpKind::NotEqual => Some(NEQ_BUILTIN_NAME.into()), + BinaryOpKind::Greater if lhs_type.is_arithmetic() => Some(GT_BUILTIN_NAME.into()), + BinaryOpKind::GreaterEqual if lhs_type.is_arithmetic() => Some(GEQ_BUILTIN_NAME.into()), + BinaryOpKind::Less if lhs_type.is_arithmetic() => Some(LT_BUILTIN_NAME.into()), + BinaryOpKind::LessEqual if lhs_type.is_arithmetic() => Some(LEQ_BUILTIN_NAME.into()), // Bit - BinaryOpKind::And if lhs_type.is_bitwise() => Some(format!("and")), - BinaryOpKind::Or if lhs_type.is_bitwise() => Some(format!("or")), - BinaryOpKind::Xor if lhs_type.is_bitwise() => Some(format!("xor")), - BinaryOpKind::ShiftLeft if lhs_type.is_bitwise() => Some(format!("shl")), - BinaryOpKind::ShiftRight if lhs_type.is_bitwise() => Some(format!("shr")), + BinaryOpKind::And if lhs_type.is_bitwise() => Some(AND_BUILTIN_NAME.into()), + BinaryOpKind::Or if lhs_type.is_bitwise() => Some(OR_BUILTIN_NAME.into()), + BinaryOpKind::Xor if lhs_type.is_bitwise() => Some(XOR_BUILTIN_NAME.into()), + BinaryOpKind::ShiftLeft if lhs_type.is_bitwise() => Some(SHL_BUILTIN_NAME.into()), + BinaryOpKind::ShiftRight if lhs_type.is_bitwise() => Some(SHR_BUILTIN_NAME.into()), _ => None, } } diff --git a/src/lean/mod.rs b/src/lean/mod.rs index 9fbee72..286208e 100644 --- a/src/lean/mod.rs +++ b/src/lean/mod.rs @@ -664,7 +664,7 @@ impl LeanEmitter { let rhs_ty = self.context.def_interner.id_type(infix.rhs); let lhs_builtin_ty = lhs_ty.try_into().unwrap(); let rhs_builtin_ty = rhs_ty.try_into().unwrap(); - let builtin_func_name = builtin::try_infix_as_builtin( + let builtin_name = builtin::try_infix_into_builtin( infix.operator.kind, lhs_builtin_ty, rhs_builtin_ty, @@ -674,18 +674,18 @@ impl LeanEmitter { let lhs = self.emit_expr(ind, infix.lhs)?; let rhs = self.emit_expr(ind, infix.rhs)?; - syntax::expr::format_infix_builtin_call(&builtin_func_name, &lhs, &rhs, &out_ty_str) + syntax::expr::format_infix_builtin_call(&builtin_name, &lhs, &rhs, &out_ty_str) } HirExpression::Prefix(prefix) => { let rhs_ty = self.context.def_interner.id_type(prefix.rhs); let rhs_builtin_ty = rhs_ty.try_into().unwrap(); - let builtin_func_name = - builtin::try_prefix_as_builtin(prefix.operator, rhs_builtin_ty) + let builtin_name = + builtin::try_prefix_into_builtin(prefix.operator, rhs_builtin_ty) .expect("not a builtin"); let rhs = self.emit_expr(ind, prefix.rhs)?; - syntax::expr::format_prefix_builtin_call(&builtin_func_name, &rhs, &out_ty_str) + syntax::expr::format_prefix_builtin_call(&builtin_name, &rhs, &out_ty_str) } HirExpression::Ident(ident, _) => { let name = self.context.def_interner.definition_name(ident.id); @@ -709,11 +709,11 @@ impl LeanEmitter { }; if let Some(builtin_fn_name) = - builtin::try_func_as_builtin(&fn_name, function_info) + builtin::try_func_into_builtin(&fn_name, function_info) { - syntax::expr::format_func_ident(&builtin_fn_name, &generics_str, true) + syntax::expr::format_builtin_ident(&builtin_fn_name) } else { - syntax::expr::format_func_ident(&fn_name, &generics_str, false) + syntax::expr::format_func_ident(&fn_name, &generics_str) } } DefinitionKind::Global(..) @@ -722,10 +722,14 @@ impl LeanEmitter { } } HirExpression::Index(index) => { - let collection = self.emit_expr(ind, index.collection)?; - let index = self.emit_expr(ind, index.index)?; + let index_builtin_ident = + syntax::expr::format_builtin_ident(builtin::INDEX_BUILTIN_NAME); - syntax::expr::format_index(&collection, &index) + let collection_expr_str = self.emit_expr(ind, index.collection)?; + let index_expr_str = self.emit_expr(ind, index.index)?; + let args_str = format!("{collection_expr_str}, {index_expr_str}"); + + syntax::expr::format_call(&index_builtin_ident, &args_str, &out_ty_str) } HirExpression::Literal(lit) => self.emit_literal(ind, lit, expr)?, HirExpression::Constructor(constructor) => { @@ -785,28 +789,11 @@ impl LeanEmitter { Type::Function(_, _, env) if matches!(*env, Type::Tuple(..)) => true, _ => false, }; - - syntax::expr::format_call(&function, &args_str, &out_ty_str, is_lambda) - } - HirExpression::MethodCall(method_call) => { - let receiver = self.emit_expr(ind, method_call.object)?; - let generics = match method_call.generics { - Some(gs) => { - let generic_strings = - gs.iter().map(|g| self.emit_fully_qualified_type(g)).collect_vec(); - generic_strings.join(", ") - } - _ => String::new(), - }; - - let arguments: Vec = method_call - .arguments - .iter() - .map(|arg| self.emit_expr(ind, *arg)) - .try_collect()?; - let args_string = arguments.join(", "); - - syntax::expr::format_method_call(&receiver, &generics, &args_string) + if is_lambda { + syntax::expr::format_lambda_call(&function, &args_str, &out_ty_str) + } else { + syntax::expr::format_call(&function, &args_str, &out_ty_str) + } } HirExpression::Cast(cast) => { let source = self.emit_expr(ind, cast.lhs)?; @@ -866,6 +853,9 @@ impl LeanEmitter { syntax::expr::format_lambda(&captures, &args, &body, &ret_type) } + HirExpression::MethodCall(_) => { + panic!("Method call expressions should not exist after type checking") + } HirExpression::Comptime(_) => { panic!("Comptime expressions should not exist after compilation is done") } @@ -875,7 +865,10 @@ impl LeanEmitter { HirExpression::Unquote(_) => { panic!("Unquote expressions should not exist after macro resolution") } - HirExpression::Error => panic!("Encountered error expression where none should exist"), + + HirExpression::Error => { + panic!("Encountered error expression where none should exist") + } }; Ok(expression) diff --git a/src/lean/syntax.rs b/src/lean/syntax.rs index 467c617..7900cd9 100644 --- a/src/lean/syntax.rs +++ b/src/lean/syntax.rs @@ -91,26 +91,18 @@ pub(super) mod expr { format!("{struct_ident}<{struct_generic_vals}> {{ {fields_ordered} }}") } - pub fn format_call(func_expr: &str, func_args: &str, out_ty: &str, is_lambda: bool) -> String { - if is_lambda { - format!("(^{func_expr}({func_args}) : {out_ty})") - } else if func_expr.starts_with(BUILTIN_PREFIX) { + pub fn format_lambda_call(lam_expr: &str, func_args: &str, out_ty: &str) -> String { + format!("(^{lam_expr}({func_args}) : {out_ty})") + } + + pub fn format_call(func_expr: &str, func_args: &str, out_ty: &str) -> String { + if func_expr.starts_with(BUILTIN_PREFIX) { format!("({func_expr}({func_args}) : {out_ty})") } else { format!("(@{func_expr}({func_args}) : {out_ty})") } } - #[inline] - pub fn format_method_call(receiver: &str, generic_vals: &str, args: &str) -> String { - format!("{receiver}<{generic_vals}>({args})") - } - - #[inline] - pub fn format_index(lhs_expr: &str, index: &str) -> String { - format!("{lhs_expr}[{index}]") - } - #[inline] pub fn format_member_access(struct_name: &str, target_expr: &str, member: Ident) -> String { format!("{target_expr}[{struct_name}.{member}]") @@ -150,13 +142,13 @@ pub(super) mod expr { normalize_ident(ident) } - pub fn format_func_ident(ident: &str, generics: &str, is_builtin: bool) -> String { + pub fn format_builtin_ident(builtin_name: &str) -> String { + format!("{BUILTIN_PREFIX}{builtin_name}") + } + + pub fn format_func_ident(ident: &str, generics: &str) -> String { let ident = normalize_ident(ident); - if is_builtin { - format!("{BUILTIN_PREFIX}{ident}") - } else { - format!("{ident}<{generics}>") - } + format!("{ident}<{generics}>") } #[inline] @@ -184,8 +176,8 @@ pub(super) mod stmt { use super::*; #[inline] - pub fn format_let_in(name: &str, binding_type: &str, bound_expr: &str) -> String { - format!("let {name}: {binding_type} = {bound_expr}") + pub fn format_let_in(name: &str, _binding_type: &str, bound_expr: &str) -> String { + format!("let {name} = {bound_expr}") } #[inline] diff --git a/src/lib.rs b/src/lib.rs index 5759fec..8c7ff32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,6 @@ pub use crate::noir::project::Project; /// The source type for use with the library, exported here for easy access. pub use crate::noir::source::Source; - /// Takes the definition of a Noir project and converts it into equivalent /// definitions in the Lean theorem prover and programming language. /// @@ -99,10 +98,10 @@ mod test { // fn assigns(x: u8) { // let mut y = 3; // y += x; - // + // let mut foo = Option2::none(); // foo._is_some = false; - // + // let mut arr = [1, 2]; // arr[0] = 10; // } @@ -146,11 +145,11 @@ mod test { Self { _is_some: false, _value: std::unsafe::zeroed() } } - // /// Constructs a Some wrapper around the given value - // pub fn some(_value: T) -> Self { - // Self { _is_some: true, _value } - // } - // + /// Constructs a Some wrapper around the given value + pub fn some(_value: T) -> Self { + Self { _is_some: true, _value } + } + /// True if this Option is None pub fn is_none(self) -> bool { !self.is_some() @@ -173,6 +172,13 @@ mod test { self } } + + fn main() { + let opt = Option2::some(5); + opt.is_some(); + let l = &[1, 2, 3]; + l[0]; + } "#; let source = Source::new(file_name, source);