Skip to content

Commit

Permalink
lens fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
utkn committed Dec 23, 2024
1 parent f8f35c2 commit 3461081
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 27 deletions.
39 changes: 29 additions & 10 deletions Lampe/Lampe.lean
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ example {a b : Tp.denote p .field} :

nr_def structProjection<>(x : Field, y : Field) -> Field {
let s = Pair<Field> { x, y };
(s as Pair<Field>).a : Field
(s as Pair<Field>).a
}

example {x y : Tp.denote p .field} :
Expand All @@ -253,7 +253,7 @@ example {x y : Tp.denote p .field} :
nr_def structWrite<>(x : Field, y : Field) -> Field {
let mut s = Pair<Field> { x, y };
(s as Pair<Field>).a = (5 : Field);
(s as Pair<Field>).a : Field
(s as Pair<Field>).a
}

example {_: 5 < p.natVal} {x y : Tp.denote p .field} :
Expand Down Expand Up @@ -290,22 +290,22 @@ example {_: 5 < p.natVal} :

nr_def callDecl<>(x: Field, y : Field) -> Field {
let s = @structConstruct<>(x, y) : Pair<Field>;
(s as Pair<Field>).a : Field
(s as Pair<Field>).a
}

example {x y : Tp.denote p .field} :
STHoare p ⟨[(structConstruct.name, structConstruct.fn)], []⟩
⟦⟧ (callDecl.fn.body _ h![] |>.body h![x, y]) (fun v => v = x) := by
simp only [callDecl]
steps
rotate_right 1
exact (fun v => v.fst = x)
all_goals tauto
steps <;> tauto
. simp only [structConstruct]
steps
simp_all
. sl
aesop
simp_all [SLP.wand, SLP.entails, SLP.forall']
. intros
generalize («Pair#a» _) = mem at *
simp only at mem
subst_vars
sorry

nr_def createSlice<>() -> [bool] {
&[true, false]
Expand All @@ -322,3 +322,22 @@ nr_def createArray<>() -> [Field; 2] {
example : STHoare p Γ ⟦⟧ (createArray.fn.body _ h![] |>.body h![]) (fun v => v.toList.get? 1 = some 2) := by
simp only [createArray, Expr.array]
steps <;> aesop

nr_struct_def Lens <> {
a : `(Field, Field),
}

nr_def simpleLens<>() -> Field {
let s = Lens<> { `(1 : Field, 2 : Field) };
((s as Lens<>).a).1 : Field
}

example {_ : 2 < p.natVal} :
STHoare p Γ ⟦⟧ (simpleLens.fn.body _ h![] |>.body h![]) fun v => v.val = 2 := by
simp only [simpleLens]
steps
intros
simp_all
subst_vars
apply ZMod.val_cast_of_lt
tauto
43 changes: 26 additions & 17 deletions Lampe/Lampe/Syntax.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ declare_syntax_cat nr_type
declare_syntax_cat nr_expr
declare_syntax_cat nr_block_contents
declare_syntax_cat nr_param_decl
declare_syntax_cat nr_lval

syntax ident : nr_ident
syntax ident "::" nr_ident : nr_ident
Expand Down Expand Up @@ -60,6 +61,9 @@ def mkBuiltin [Monad m] [MonadQuotation m] [MonadExceptOf Exception m] [MonadErr
@[reducible]
def typeof {rep : Tp → Type _} (_ : rep tp) := tp

@[reducible]
def typeofMem (_ : Builtin.Member tp tps) := tp

@[reducible]
def tupleFields (tp : Tp) := match tp with
| Tp.tuple _ fields => fields
Expand Down Expand Up @@ -90,7 +94,15 @@ partial def mkNrType [Monad m] [MonadQuotation m] [MonadExceptOf Exception m] [M
| `(nr_type| ${ $i }) => pure i
| _ => throwUnsupportedSyntax

-- nr_param_decl
syntax ident ":" nr_type : nr_param_decl

-- nr_lval
syntax ident : nr_lval
syntax nr_lval "[" nr_expr "]" : nr_lval
syntax nr_lval "." num : nr_lval

-- nr_expr
syntax num ":" nr_type : nr_expr -- Literal
syntax ident : nr_expr -- Reference
syntax "{" sepBy(nr_expr, ";", ";", allowTrailingSep) "}" : nr_expr -- Block
Expand All @@ -112,7 +124,7 @@ syntax "^" nr_ident "(" nr_expr,* ")" ":" nr_type : nr_expr -- Lambda call
syntax "@" nr_ident "<" nr_type,* ">" "(" nr_expr,* ")" ":" nr_type : nr_expr -- Decl call
syntax "(" nr_type "as" nr_ident "<" nr_type,* ">" ")"
"::" nr_ident "<" nr_type,* ">" "(" nr_expr,* ")" ":" nr_type : nr_expr -- Trait call
syntax "(" nr_expr "as" nr_ident "<" nr_type,* ">" ")" "." ident ":" nr_type : nr_expr -- Struct access
syntax "(" nr_expr "as" nr_ident "<" nr_type,* ">" ")" "." ident : nr_expr -- Struct access
syntax nr_expr "." num ":" nr_type : nr_expr -- Tuple access

syntax nr_fn_decl := nr_ident "<" ident,* ">" "(" nr_param_decl,* ")" "->" nr_type "{" sepBy(nr_expr, ";", ";", allowTrailingSep) "}"
Expand All @@ -122,10 +134,6 @@ syntax nr_trait_impl := "<" ident,* ">" nr_ident "<" nr_type,* ">" "for" nr_type
"{" sepBy(nr_trait_fn_def, ";", ";", allowTrailingSep) "}"
syntax nr_struct_def := "<" ident,* ">" "{" sepBy(nr_param_decl, ",", ",", allowTrailingSep) "}"

def Expr.letMutIn (definition : Expr rep tp) (body : rep tp.ref → Expr rep tp') : Expr rep tp' :=
let refDef := Expr.letIn definition fun v => Expr.call h![] _ (tp.ref) (.builtin .ref) h![v]
Expr.letIn refDef body

def Expr.ref (val : rep tp) : Expr rep tp.ref :=
Expr.call h![] _ tp.ref (.builtin .ref) h![val]

Expand Down Expand Up @@ -176,7 +184,7 @@ def wrapSimple [MonadSyntax m] (e : TSyntax `term) (ident : Option Lean.Ident) (
`(Lampe.Expr.letIn $e fun $ident => $rest)

def mkFieldName (structName : String) (fieldName : String) : Lean.Ident :=
mkIdent $ Name.mkSimple (structName |>.append "#" |>.append fieldName)
mkIdent $ Name.mkSimple (structName ++ "#" ++ fieldName)

def mkStructDefIdent (structName : String) : Lean.Ident :=
mkIdent $ Name.mkSimple structName
Expand All @@ -197,7 +205,7 @@ partial def mkBlock [MonadSyntax m] (items: List (TSyntax `nr_expr)) (k : TSynta
let body ← mkBlock (n :: rest) k
`(Lampe.Expr.letIn (Expr.ref $eVal) fun $v => $body)
| e => do
mkExpr e none fun _ => mkBlock (n::rest) k
mkExpr e none fun _ => mkBlock (n :: rest) k
| [e] => match e with
| `(nr_expr | let $_ = $e)
| `(nr_expr | let mut $_ = $e)
Expand All @@ -210,7 +218,7 @@ partial def mkArgs [MonadSyntax m] (args : List (TSyntax `nr_expr)) (k : List (T
mkExpr h none fun h => do
mkArgs t fun t => k (h :: t)

partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.Ident) (k : TSyntax `term → m (TSyntax `term)): m (TSyntax `term) := match e with
partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.Ident) (k : TSyntax `term → m (TSyntax `term)) : m (TSyntax `term) := match e with
| `(nr_expr|$n:num : $tp) => do wrapSimple (←`(Lampe.Expr.lit $(←mkNrType tp) $n)) vname k
| `(nr_expr| true) => do wrapSimple (←`(Lampe.Expr.lit Tp.bool 1)) vname k
| `(nr_expr| false) => do wrapSimple (←`(Lampe.Expr.lit Tp.bool 0)) vname k
Expand All @@ -226,7 +234,7 @@ partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.I
wrapSimple (←`(Lampe.Expr.array $(Syntax.mkNumLit $ toString len) $(←mkHListLit argVals))) vname k
| `(nr_expr| { $exprs;* }) => mkBlock exprs.getElems.toList k
| `(nr_expr| $i:ident) => do
if ←isAutoDeref i.getId then wrapSimple (← `(Lampe.Expr.readRef $i)) vname k else match vname with
if ←isAutoDeref i.getId then wrapSimple (←`(Lampe.Expr.readRef $i)) vname k else match vname with
| none => k i
| some _ => wrapSimple (←`(Lampe.Expr.var $i)) vname k
| `(nr_expr| # $i:ident ($args,*): $tp) => do
Expand Down Expand Up @@ -288,13 +296,19 @@ partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.I
mkArgs args.getElems.toList fun argVals => do
let argTps ← argVals.mapM fun arg => `(typeof $arg)
wrapSimple (←`(Lampe.Expr.call h![] _ (.tuple none $(←mkListLit argTps)) (.builtin Builtin.mkTuple) $(←mkHListLit argVals))) vname k
| `(nr_expr| ( $structExpr:nr_expr as $structName:nr_ident < $structGenVals,* > ) . $structField:ident : $outTy:nr_type ) => do
let outTp ← mkNrType outTy
| `(nr_expr| $tupleExpr:nr_expr . $idx:num : $outTp:nr_type) => do
let accessorSyn ← mkRecMember idx.getNat
let outTp ← mkNrType outTp
mkExpr tupleExpr none fun t => do
wrapSimple (←`(Lampe.Expr.call h![] [typeof $t] $outTp (.builtin (@Builtin.projectTuple $outTp (tupleFields $ typeof $t) $accessorSyn)) h![$t])) vname k
| `(nr_expr| ( $structExpr:nr_expr as $structName:nr_ident < $structGenVals,* > ) . $structField:ident ) => do
let structGenValsSyn ← mkHListLit (←structGenVals.getElems.toList.mapM fun gVal => mkNrType gVal)
let accessor := mkFieldName (←mkNrIdent structName) (structField.getId.toString)
let accessorSyn ← `($accessor $structGenValsSyn)
-- let outTp ← mkNrType outTp
let outTp ← `(typeofMem $accessorSyn)
mkExpr structExpr none fun s => do
`(Lampe.Expr.call h![] [typeof $s] $outTp (.builtin (@Builtin.projectTuple $outTp _ $accessorSyn)) h![$s])
wrapSimple (←`(Lampe.Expr.call h![] [typeof $s] $outTp (.builtin (@Builtin.projectTuple $outTp _ $accessorSyn)) h![$s])) vname k
| `(nr_expr| ( $r:ident as $structName:nr_ident < $structGenVals,* > ) . $structField:ident = $rhs:nr_expr) => do
let structGenValsSyn ← mkHListLit (←structGenVals.getElems.toList.mapM fun gVal => mkNrType gVal)
let accessor := mkFieldName (←mkNrIdent structName) (structField.getId.toString)
Expand All @@ -305,11 +319,6 @@ partial def mkExpr [MonadSyntax m] (e : TSyntax `nr_expr) (vname : Option Lean.I
mkExpr rhs none fun rhs => do
let accessorSyn ← mkRecMember idx.getNat
wrapSimple (←`(Lampe.Expr.tupleWriteMember $r $accessorSyn $rhs)) vname k
| `(nr_expr| $tupleExpr:nr_expr . $idx:num : $outTy:nr_type) => do
let outTp ← mkNrType outTy
mkExpr tupleExpr none fun t => do
let accessorSyn ← mkRecMember idx.getNat
`(Lampe.Expr.call h![] [typeof $t] $outTp (.builtin (@Builtin.projectTuple $outTp (tupleFields $ typeof $t) $accessorSyn)) h![$t])
| _ => throwUnsupportedSyntax

end
Expand Down

0 comments on commit 3461081

Please sign in to comment.