From f3bb9f0c4f97b19a6204693b31a20bc9a8e2e8e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?William=20S=C3=B8rensen?= Date: Fri, 16 Aug 2024 10:55:07 +0100 Subject: [PATCH] refactor: extract handling of RecForms --- Qpf/Macro/Data/Ind.lean | 44 +++----------------------- Qpf/Macro/Data/RecForm.lean | 62 +++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 39 deletions(-) create mode 100644 Qpf/Macro/Data/RecForm.lean diff --git a/Qpf/Macro/Data/Ind.lean b/Qpf/Macro/Data/Ind.lean index a052154..ca092ef 100644 --- a/Qpf/Macro/Data/Ind.lean +++ b/Qpf/Macro/Data/Ind.lean @@ -1,3 +1,4 @@ +import Qpf.Macro.Data.RecForm import Qpf.Macro.Data.View import Qpf.Macro.Common import Mathlib.Data.QPF.Multivariate.Constructions.Fix @@ -7,30 +8,8 @@ open Lean.Parser (Parser) open Lean Meta Elab.Command Elab.Term Parser.Term open Lean.Parser.Tactic (inductionAlt) -/-- - The recursive form encodes how a function argument is recursive. - - Examples ty R α: - - α → R α → List (R α) → R α - [nonRec, directRec, composed ] --/ -inductive RecursionForm := - | nonRec (stx: Term) - | directRec - -- | composed -- Not supported yet -deriving Repr, BEq - -partial def getArgTypes (v : Term) : List Term := match v.raw with - | .node _ ``arrow #[arg, _, deeper] => - ⟨arg⟩ :: getArgTypes ⟨deeper⟩ - | rest => [⟨rest⟩] - def flattenForArg (n : Name) := Name.str .anonymous $ n.toStringWithSep "_" true -def containsStx (top : Term) (search : Term) : Bool := - (top.raw.find? (· == search)).isSome - /-- Both `bracketedBinder` and `matchAlts` have optional arguments, which cause them to not by recognized as parsers in quotation syntax (that is, ``` `(bracketedBinder| ...) ``` does not work). @@ -54,22 +33,6 @@ def addShapeToName : Name → Name section variable {m} [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [AddMessageContext m] -/-- Extract takes a constructor and extracts its recursive forms. - -This function assumes the pre-processor has run -It also assumes you don't have polymorphic recursive types such as -data Ql α | nil | l : α → Ql Bool → Ql α -/ -def extract (topName : Name) (view : CtorView) (rec_type : Term) : m $ Name × List RecursionForm := - (view.declName.replacePrefix topName .anonymous , ·) <$> (do - let some type := view.type? | pure [] - let type_ls := (getArgTypes ⟨type⟩).dropLast - - type_ls.mapM fun v => - if v == rec_type then pure .directRec - else if containsStx v rec_type then - throwErrorAt v.raw "Cannot handle composed recursive types" - else pure $ .nonRec v) - /-- Generate the binders for the different recursors -/ def mkRecursorBinder (rec_type : Term) (name : Name) @@ -87,6 +50,7 @@ def mkRecursorBinder let ty ← form.foldlM (fun acc => (match · with | ⟨.nonRec x, name⟩ => `(($name : $x) → $acc) | ⟨.directRec, name⟩ => `(($name : $rec_type) → $acc) + | ⟨.composed x, _⟩ => throwErrorAt x "Cannot handle recursive forms" )) out `(bb | ($(mkIdent $ flattenForArg name) : $ty)) @@ -174,6 +138,7 @@ def generateRecBody (ctors : Array (Name × List RecursionForm)) (includeMotive match f with | .directRec => `(⟨_, $nm⟩) | .nonRec _ => `(_) + | .composed _ => throwError "Cannot handle composed" let nonMotiveArgs ← names.mapM fun _ => `(_) let motiveArgs ← if includeMotive then @@ -181,6 +146,7 @@ def generateRecBody (ctors : Array (Name × List RecursionForm)) (includeMotive match f with | .directRec => some <$> `($nm) | .nonRec _ => pure none + | .composed _ => throwError "Cannot handle composed" else pure #[] @@ -194,7 +160,7 @@ def generateRecBody (ctors : Array (Name × List RecursionForm)) (includeMotive def genRecursors (view : DataView) : CommandElabM Unit := do let rec_type := view.getExpectedType - let mapped ← view.ctors.mapM (extract view.declName · rec_type) + let mapped := view.ctors.map (RecursionForm.extractWithName view.declName · rec_type) let ih_types ← mapped.mapM fun ⟨name, base⟩ => mkRecursorBinder (rec_type) (name) base true diff --git a/Qpf/Macro/Data/RecForm.lean b/Qpf/Macro/Data/RecForm.lean new file mode 100644 index 0000000..1d3bfa1 --- /dev/null +++ b/Qpf/Macro/Data/RecForm.lean @@ -0,0 +1,62 @@ +import Qpf.Macro.Data.Replace + +open Lean.Parser (Parser) +open Lean Meta Elab.Command Elab.Term Parser.Term +open Lean.Parser.Tactic (inductionAlt) + +/-- + The recursive form encodes how a function argument is recursive. + Examples ty R α: + α → R α → List (R α) → R α + [nonRec, directRec, composed ] +-/ +inductive RecursionForm := + | nonRec (stx : Term) + | directRec + | composed (stx : Term) -- Not supported yet +deriving Repr, BEq + +namespace RecursionForm + +variable {m} [Monad m] [MonadQuotation m] + +private def containsStx (top : Term) (search : Term) : Bool := + (top.raw.find? (· == search)).isSome + +partial def getArgTypes (v : Term) : List Term := match v.raw with + | .node _ ``arrow #[arg, _, deeper] => + ⟨arg⟩ :: getArgTypes ⟨deeper⟩ + | rest => [⟨rest⟩] + +partial def toType (retTy : Term) : List Term → m Term + | [] => pure retTy + | hd :: tl => do `($hd → $(← toType retTy tl)) + +/-- Extract takes a constructor and extracts its recursive forms. +This function assumes the pre-processor has run +It also assumes you don't have polymorphic recursive types such as +data Ql α | nil | l : α → Ql Bool → Ql α -/ +def extract (view : CtorView) (rec_type : Term) : List RecursionForm := do + if let some type := view.type? then + let type_ls := (getArgTypes ⟨type⟩).dropLast + + type_ls.map fun v => + if v == rec_type then .directRec + else if containsStx v rec_type then + .composed v + else .nonRec v + else [] + +def extractWithName (topName : Name) (view : CtorView) (rec_type : Term) : Name × List RecursionForm := + (view.declName.replacePrefix topName .anonymous , extract view rec_type) + +def replaceRec (old new : Term) : RecursionForm → Term + | .nonRec x => x + | .directRec => new + | .composed x => ⟨Replace.replaceAllStx old new x⟩ + +def toTerm (recType : Term) : RecursionForm → Term + | .nonRec x | .composed x => x + | .directRec => recType + +end RecursionForm