Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support constructor arguments given as explicit binders #18

Merged
merged 5 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Qpf/Macro/Data.lean
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,9 @@ open Macro Comp in
def elabData : CommandElab := fun stx => do
let modifiers ← elabModifiers stx[0]
let decl := stx[1]

let view ← dataSyntaxToView modifiers decl
let view ← preProcessCtors view -- Transforms binders into simple lambda types

let (nonRecView, ⟨r, shape, _P, eff⟩) ← runTermElabM fun _ => do
let (nonRecView, _rho) ← makeNonRecursive view;
Expand Down
45 changes: 36 additions & 9 deletions Qpf/Macro/Data/Replace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ structure CtorArgs where
(args : Array Name)
(per_type : Array (Array Name))

/- TODO(@William): make these correspond by combining expr and vars into a product -/
structure Replace where
(expr: Array Term)
(vars: Array Name)
Expand Down Expand Up @@ -59,7 +60,7 @@ def Replace.getBinders {m} [Monad m] [MonadQuotation m] (r : Replace) : m <| TSy




/- TODO: Figure out how to break this up into section -/



Expand Down Expand Up @@ -117,13 +118,9 @@ private partial def setResultingType (res_type : Syntax) : Syntax → ReplaceM m
| _ =>
pure res_type


-- TODO: this should be deprecated in favour of {v with ...} syntax
def CtorView.withType? (ctor : CtorView) (type? : Option Syntax) : CtorView := {
ref := ctor.ref
modifiers := ctor.modifiers
declName := ctor.declName
binders := ctor.binders
type? := type?
ctor with type?
}

/-
Expand All @@ -147,6 +144,34 @@ def Replace.run : ReplaceM m α → m (α × Replace) :=
let r ← Replace.new
StateT.run x r

-- Have a look at how, e.g., def, deals with binders,
-- this method might already exist look for BinderView
def getBinderView (ref : Syntax): m BinderView := match ref with
| .node _ `Lean.Parser.Term.explicitBinder
#[_, id, (.node _ `null #[_, ty]), _, _] =>
return .mk ref id ty .default
/- pure (ids, ty) -/
| _ => Elab.throwUnsupportedSyntax

/-- This function takes in a DataView with possibly explicit binders.
It then runs a simple scheme to translate them into (non-dependent) lambdas.
Then it also tries to infer an output type to handle the case with no type.
Finally it stiches all of this together to an output type-/
def preProcessCtors (view : DataView) : m DataView := do
let ctors ← view.ctors.mapM fun ctor => do
let namedArgs ← ctor.binders.getArgs.mapM getBinderView
let flatArgs :=
(namedArgs.map (fun b => b.id.getArgs.map (fun _ => ⟨b.type⟩)))
|>.flatten.reverse

let ty := if let some x := ctor.type? then x else view.getExpectedType

let out_ty ← flatArgs.foldlM (fun acc curr => `($curr → $acc)) (⟨ty⟩)

pure { ctor with binders := .missing, type? := some out_ty }

pure { view with ctors }

/--
Extract the constructors for a "shape" functor from the original constructors.
It replaces all constructor arguments with fresh variables, ensuring that repeated occurences
Expand All @@ -169,6 +194,8 @@ Replace.run <| do
let ctors := view.ctors

let pairs ← ctors.mapM fun ctor => do
/- We do not need to check for binders as the preprocessort fixes this.
We keep the test in case it goes wrong. -/
if !ctor.binders.isNone then
throwErrorAt ctor.binders "Constructor binders are not supported yet, please provide all arguments in the type"

Expand All @@ -184,7 +211,7 @@ Replace.run <| do
let ctors := pairs.map Prod.fst;
let ctorArgs := pairs.map fun ⟨_, ctorArgs⟩ =>
let per_type := ctorArgs.per_type

let diff := r.vars.size - ctorArgs.per_type.size

-- HACK: It seems that `Array.append` causes a stack overflow, so we go through `List` for now
Expand Down Expand Up @@ -268,4 +295,4 @@ def makeNonRecursive (view : DataView) : MetaM (DataView × Name) := do
return CtorView.withType? ctor type?

let view := view.setCtors ctors
pure (view, rec)
pure (view, rec)
2 changes: 1 addition & 1 deletion Test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ import Test.List
import Test.Misc
import Test.Tree
-- import Test.Variable
-- import Test.WithBindersInCtor
import Test.WithBindersInCtor
import Test.Wrap
14 changes: 11 additions & 3 deletions Test/WithBindersInCtor.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import Qpf
import Qpf.Macro.Data

data QpfListWithBinder α
data QpfListWithBinder α
| cons (h : α) (tl : QpfListWithBinder α)
| nil
| cons (hd : α) (tl : QpfListWithBinder α)

data Wrap α
| mk : α → Wrap α
data Wrap₂ α
| mk (a : α) : Wrap₂ α
data Wrap₃ α
| mk (a : α)

Loading