Skip to content

Commit

Permalink
Merge branch 'extract-rec-form-handling' into generalize-constructor-…
Browse files Browse the repository at this point in the history
…generation-further
  • Loading branch information
William Sørensen committed Aug 16, 2024
2 parents ab6aa44 + f3bb9f0 commit c6ee24d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 39 deletions.
44 changes: 5 additions & 39 deletions Qpf/Macro/Data/Ind.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Qpf.Macro.Data.RecForm
import Qpf.Macro.Data.View
import Qpf.Macro.Common
import Mathlib.Data.QPF.Multivariate.Constructions.Fix
Expand All @@ -7,30 +8,8 @@ open Lean.Parser (Parser)
open Lean Meta Elab.Command Elab.Term Parser.Term
open Lean.Parser.Tactic (inductionAlt)

/--
The recursive form encodes how a function argument is recursive.
Examples ty R α:
α → R α → List (R α) → R α
[nonRec, directRec, composed ]
-/
inductive RecursionForm :=
| nonRec (stx: Term)
| directRec
-- | composed -- Not supported yet
deriving Repr, BEq

partial def getArgTypes (v : Term) : List Term := match v.raw with
| .node _ ``arrow #[arg, _, deeper] =>
⟨arg⟩ :: getArgTypes ⟨deeper⟩
| rest => [⟨rest⟩]

def flattenForArg (n : Name) := Name.str .anonymous $ n.toStringWithSep "_" true

def containsStx (top : Term) (search : Term) : Bool :=
(top.raw.find? (· == search)).isSome

/-- Both `bracketedBinder` and `matchAlts` have optional arguments,
which cause them to not by recognized as parsers in quotation syntax
(that is, ``` `(bracketedBinder| ...) ``` does not work).
Expand All @@ -54,22 +33,6 @@ def addShapeToName : Name → Name
section
variable {m} [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [AddMessageContext m]

/-- Extract takes a constructor and extracts its recursive forms.
This function assumes the pre-processor has run
It also assumes you don't have polymorphic recursive types such as
data Ql α | nil | l : α → Ql Bool → Ql α -/
def extract (topName : Name) (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm :=
(view.declName.replacePrefix topName .anonymous , ·) <$> (do
let some type := view.type? | pure []
let type_ls := (getArgTypes ⟨type⟩).dropLast

type_ls.mapM fun v =>
if v == rec_type then pure .directRec
else if containsStx v rec_type then
throwErrorAt v.raw "Cannot handle composed recursive types"
else pure $ .nonRec v)

/-- Generate the binders for the different recursors -/
def mkRecursorBinder
(rec_type : Term) (name : Name)
Expand All @@ -87,6 +50,7 @@ def mkRecursorBinder
let ty ← form.foldlM (fun acc => (match · with
| ⟨.nonRec x, name⟩ => `(($name : $x) → $acc)
| ⟨.directRec, name⟩ => `(($name : $rec_type) → $acc)
| ⟨.composed x, _⟩ => throwErrorAt x "Cannot handle recursive forms"
)) out

`(bb | ($(mkIdent $ flattenForArg name) : $ty))
Expand Down Expand Up @@ -174,13 +138,15 @@ def generateRecBody (ctors : Array (Name × List RecursionForm)) (includeMotive
match f with
| .directRec => `(⟨_, $nm⟩)
| .nonRec _ => `(_)
| .composed _ => throwError "Cannot handle composed"

let nonMotiveArgs ← names.mapM fun _ => `(_)
let motiveArgs ← if includeMotive then
names.filterMapM fun ⟨nm, f⟩ =>
match f with
| .directRec => some <$> `($nm)
| .nonRec _ => pure none
| .composed _ => throwError "Cannot handle composed"
else pure #[]


Expand All @@ -194,7 +160,7 @@ def generateRecBody (ctors : Array (Name × List RecursionForm)) (includeMotive
def genRecursors (view : DataView) : CommandElabM Unit := do
let rec_type := view.getExpectedType

let mapped view.ctors.mapM (extract view.declName · rec_type)
let mapped := view.ctors.map (RecursionForm.extractWithName view.declName · rec_type)

let ih_types ← mapped.mapM fun ⟨name, base⟩ =>
mkRecursorBinder (rec_type) (name) base true
Expand Down
62 changes: 62 additions & 0 deletions Qpf/Macro/Data/RecForm.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import Qpf.Macro.Data.Replace

open Lean.Parser (Parser)
open Lean Meta Elab.Command Elab.Term Parser.Term
open Lean.Parser.Tactic (inductionAlt)

/--
The recursive form encodes how a function argument is recursive.
Examples ty R α:
α → R α → List (R α) → R α
[nonRec, directRec, composed ]
-/
inductive RecursionForm :=
| nonRec (stx : Term)
| directRec
| composed (stx : Term) -- Not supported yet
deriving Repr, BEq

namespace RecursionForm

variable {m} [Monad m] [MonadQuotation m]

private def containsStx (top : Term) (search : Term) : Bool :=
(top.raw.find? (· == search)).isSome

partial def getArgTypes (v : Term) : List Term := match v.raw with
| .node _ ``arrow #[arg, _, deeper] =>
⟨arg⟩ :: getArgTypes ⟨deeper⟩
| rest => [⟨rest⟩]

partial def toType (retTy : Term) : List Term → m Term
| [] => pure retTy
| hd :: tl => do `($hd → $(← toType retTy tl))

/-- Extract takes a constructor and extracts its recursive forms.
This function assumes the pre-processor has run
It also assumes you don't have polymorphic recursive types such as
data Ql α | nil | l : α → Ql Bool → Ql α -/
def extract (view : CtorView) (rec_type : Term) : List RecursionForm := do
if let some type := view.type? then
let type_ls := (getArgTypes ⟨type⟩).dropLast

type_ls.map fun v =>
if v == rec_type then .directRec
else if containsStx v rec_type then
.composed v
else .nonRec v
else []

def extractWithName (topName : Name) (view : CtorView) (rec_type : Term) : Name × List RecursionForm :=
(view.declName.replacePrefix topName .anonymous , extract view rec_type)

def replaceRec (old new : Term) : RecursionForm → Term
| .nonRec x => x
| .directRec => new
| .composed x => ⟨Replace.replaceAllStx old new x⟩

def toTerm (recType : Term) : RecursionForm → Term
| .nonRec x | .composed x => x
| .directRec => recType

end RecursionForm

0 comments on commit c6ee24d

Please sign in to comment.