Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: composition pipeline and qpf body parsing #19

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 109 additions & 91 deletions Qpf/Macro/Comp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -152,118 +152,132 @@ def DVec.toExpr {n : Nat} {αs : Q(Fin2 $n → Type u)} (xs : DVec (fun (i : Fin
| .fs i => $as i
)

def Vector.mapM
{m : Type u → Type v}
{α : Type w}
{β : Type u}
[Monad m]
(l : Vector α n)
(f : α → m β) : m $ Vector β n :=
match l with
| ⟨.cons hd tl, b⟩ => do
let v ← f hd
let x ← Vector.mapM ⟨tl, by rfl⟩ f
pure ⟨v :: x.val, by {
rw [List.length_cons] at b
rw [List.length_cons, x.property]
exact b
}⟩
| ⟨.nil, h⟩ => pure ⟨[], by exact h⟩
Comment on lines +155 to +171
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If Vector.mmap exists upstream, let's use that one

A general note on style:

  • We generally prefer putting variables on fewer lines, rather than getting a very tall definition
  • The type (and binders) should be indented one level more than the body of the definition
  • There's no need for { ... } around the multiline proof, by is white-space sensitive
  • rfl is also a term, there is (usually!) no need to use the tactic version (as in by rfl)
Suggested change
def Vector.mapM
{m : Type u → Type v}
{α : Type w}
{β : Type u}
[Monad m]
(l : Vector α n)
(f : α → m β) : m $ Vector β n :=
match l with
| ⟨.cons hd tl, b⟩ => do
let v ← f hd
let x ← Vector.mapM ⟨tl, by rfl⟩ f
pure ⟨v :: x.val, by {
rw [List.length_cons] at b
rw [List.length_cons, x.property]
exact b
}⟩
| ⟨.nil, h⟩ => pure ⟨[], by exact h⟩
def Vector.mapM {m : Type u → Type v}
{α : Type w} {β : Type u} [Monad m]
(l : Vector α n) (f : α → m β) : m $ Vector β n :=
match l with
| ⟨.cons hd tl, b⟩ => do
let v ← f hd
let x ← Vector.mapM ⟨tl, rfl⟩ f
pure ⟨v :: x.val, by simpa [List.length_cons, x.property] using b⟩
| ⟨.nil, h⟩ => pure ⟨[], by exact h⟩


structure ElabQpfResult (u : Level) (arity : Nat) where
F : Q(TypeFun.{u,u} $arity)
functor : Q(MvFunctor $F)
qpf : Q(@MvQPF _ $F $functor)
deriving Inhabited

open PrettyPrinter in
/--
Elaborate the body of a qpf
-/
partial def elabQpf {arity : Nat} (vars : Vector Q(Type u) arity) (target : Q(Type u)) (targetStx : Option Term := none) (normalized := false) :
TermElabM (ElabQpfResult u arity) := do
trace[QPF] "elabQPF :: {vars.toList} -> {target}"
let vars' := vars.toList;
/- def isLiveVar (varIds : Vector FVarId n) (id : FVarId) := (List.indexOf' id varIds.toList).isSome -/
def isLiveVar (varIds : Vector FVarId n) (id : FVarId) := varIds.toList.contains id

let varIds := vars'.map fun expr => expr.fvarId!
let isLiveVar : FVarId → Bool
:= fun fvarId => (List.indexOf' fvarId varIds).isSome
open PrettyPrinter in
mutual
partial def handleLiveFVar (vars : Vector FVarId arity) (target : FVarId) : TermElabM (ElabQpfResult u arity) := do
trace[QPF] f!"target {Expr.fvar target} is a free variable"
let ind ← match List.indexOf' target vars.toList with
| none => throwError "Free variable {Expr.fvar target} is not one of the qpf arguments"
| some ind => pure ind

if target.isFVar && isLiveVar target.fvarId! then
trace[QPF] f!"target {target} is a free variable"
let ind ← match List.indexOf' target vars' with
| none => throwError "Free variable {target} is not one of the qpf arguments"
| some ind => pure ind
let ind : Fin2 arity := cast (by simp) ind.inv
let prj := q(@Prj.{u} $arity $ind)
trace[QPF] "represented by: {prj}"

let ind : Fin2 arity := cast (by simp [vars']) ind.inv
let prj := q(@Prj.{u} $arity $ind)
trace[QPF] "represented by: {prj}"
pure ⟨prj, q(Prj.mvfunctor _), q(Prj.mvqpf _)⟩
return { F := prj, functor := q(Prj.mvfunctor _), qpf := q(Prj.mvqpf _) }

else if !target.hasAnyFVar isLiveVar then
trace[QPF] "target {target} is a constant"
let const := q(Const.{u} $arity $target)
trace[QPF] "represented by: {const}"
pure ⟨const, q(Const.MvFunctor), q(Const.mvqpf)⟩
partial def handleConst (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do
trace[QPF] "target {target} is a constant"
let const := q(Const.{u} $arity $target)
trace[QPF] "represented by: {const}"

else if target.isApp then
let ⟨m, F, args⟩ ← (Comp.parseApp isLiveVar target)
trace[QPF] "target {target} is an application of {F} to {args.toList}"

/-
Optimization: check if the application is of the form `F α β γ .. = F α β γ ..`.
In such cases, we can directly return `F`, rather than generate a composition of projections.
-/
let is_trivial :=
args.length == arity
&& args.toList.enum.all fun ⟨i, arg⟩ =>
arg.isFVar && isLiveVar arg.fvarId! && vars'.indexOf arg == i
if is_trivial then
trace[QPF] "The application is trivial"
let mvFunctor ← synthInstanceQ q(MvFunctor $F)
let mvQPF ← synthInstanceQ q(MvQPF $F)
pure ⟨F, mvFunctor, mvQPF⟩
else
let G ← args.toList.mapM fun a =>
elabQpf vars a none false

/-
HACK: We know that `m`, which was defines as `args.length`, is equal to `G.length`.
It's a bit difficult to prove this. Thus, we simply assert it
-/
if hm : m ≠ G.length then
throwError "This shouldn't happen" -- TODO: come up with a better error message
else
have hm : m = G.length := by simpa using hm
return { F := const, functor := q(Const.MvFunctor), qpf := q(Const.mvqpf)}

partial def handleApp (vars : Vector FVarId arity) (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do
let vars' := vars.toList

let ⟨numArgs, F, args⟩ ← (Comp.parseApp (isLiveVar vars) target)
trace[QPF] "target {target} is an application of {F} to {args.toList}"

let Ffunctor ← synthInstanceQ q(MvFunctor $F)
let Fqpf ← synthInstanceQ q(@MvQPF _ $F $Ffunctor)
/-
Optimization: check if the application is of the form `F α β γ .. = G α β γ ..`.
In such cases, we can directly return `G`, rather than generate a composition of projections.
-/

let G : Vec _ m := fun i => G.get (hm ▸ i.inv)
let GExpr : Q(Vec (TypeFun.{u,u} $arity) $m) :=
Vec.toExpr (fun i => (G i).F)
let GFunctor : Q(∀ i, MvFunctor.{u,u} ($GExpr i)) :=
let αs := q(fun i => MvFunctor.{u,u} ($GExpr i))
@DVec.toExpr _ _ αs (fun i => (G i).functor)
let GQpf : Q(∀ i, @MvQPF.{u,u} _ _ ($GFunctor i)) :=
let αs := q(fun i => @MvQPF.{u,u} _ _ ($GFunctor i))
@DVec.toExpr _ _ αs (fun i => (G i).qpf)
-- O(n²), equivalent to a zipping and would be O(n) and much more readable
/- let is_trivial := -/
/- args.length == arity -/
/- && args.toList.enum.all fun ⟨i, arg⟩ => -/
-- Equivalent (?), O(n), and more readable
-- TODO: is this really equivalent
let is_trivial := (args.toList.mapM Expr.fvarId?).any (· == vars')
/- let is_trivial := (args.toList.mapM Expr.fvarId?).any (BEq.beq vars') -/

if is_trivial then
trace[QPF] "The application is trivial"
let functor ← synthInstanceQ q(MvFunctor $F)
let qpf ← synthInstanceQ q(MvQPF $F)

return { F, functor, qpf }
else
let G ← Vector.mapM args (elabQpf vars · none false)

let Ffunctor ← synthInstanceQ q(MvFunctor $F)
let Fqpf ← synthInstanceQ q(@MvQPF _ $F $Ffunctor)

let G : Vec _ numArgs := fun i => G.get i.inv
let GExpr : Q(Vec (TypeFun.{u,u} $arity) $numArgs) :=
Vec.toExpr (fun i => (G i).F)
let GFunctor : Q(∀ i, MvFunctor.{u,u} ($GExpr i)) :=
let αs := q(fun i => MvFunctor.{u,u} ($GExpr i))
@DVec.toExpr _ _ αs (fun i => (G i).functor)
let GQpf : Q(∀ i, @MvQPF.{u,u} _ _ ($GFunctor i)) :=
let αs := q(fun i => @MvQPF.{u,u} _ _ ($GFunctor i))
@DVec.toExpr _ _ αs (fun i => (G i).qpf)

let comp := q(@Comp $numArgs _ $F $GExpr)
trace[QPF] "G := {GExpr}"
trace[QPF] "comp := {comp}"

let functor := q(Comp.instMvFunctorComp)
let qpf := q(Comp.instMvQPFCompInstMvFunctorCompFin2
(fF := $Ffunctor) (q := $Fqpf) (fG := _) (q' := $GQpf)
)

let comp := q(@Comp $m _ $F $GExpr)
trace[QPF] "G := {GExpr}"
trace[QPF] "comp := {comp}"
return { F := comp, functor, qpf }

let functor := q(Comp.instMvFunctorComp)
let qpf := q(Comp.instMvQPFCompInstMvFunctorCompFin2
(fF := $Ffunctor) (q := $Fqpf) (fG := _) (q' := $GQpf)
)
pure ⟨comp, functor, qpf⟩
partial def handleArrow (binderType body : Expr) (vars : Vector FVarId arity) (targetStx : Option Term := none) (normalized := false): TermElabM (ElabQpfResult u arity) := do
let newTarget ← mkAppM ``MvQPF.Arrow.Arrow #[binderType, body]
elabQpf vars newTarget targetStx normalized

else if target.isArrow then
match target with
| Expr.forallE _ e₁ e₂ .. =>
let newTarget ← mkAppM ``MvQPF.Arrow.Arrow #[e₁, e₂]
elabQpf vars newTarget targetStx normalized
| _ => unreachable!
/-- Elaborate the body of a qpf -/
partial def elabQpf {arity : Nat} (vars : Vector FVarId arity) (target : Q(Type u)) (targetStx : Option Term := none) (normalized := false) :
TermElabM (ElabQpfResult u arity) := do
trace[QPF] "elabQPF :: {(vars.map Expr.fvar).toList} -> {target}"
let isLiveVar := isLiveVar vars

if let some target := target.fvarId?.filter isLiveVar then
handleLiveFVar vars target
else if !target.hasAnyFVar isLiveVar then
handleConst target
else if target.isApp then -- Could be pattern-matched here as well
handleApp vars target
else if let .forallE _ binderType body .. := target then
handleArrow binderType body vars (normalized := normalized) (targetStx := targetStx)
else if !normalized then
let target ← whnfR target
elabQpf vars target targetStx true
else
if !normalized then
let target ← whnfR target
elabQpf vars target targetStx true
else
let extra :=
if target.isForall then
"Dependent arrows / forall are not supported"
else
""
throwError f!"Unexpected target expression :\n {target}\n{extra}\nNote that the expression contains live variables, hence, must be functorial"


let extra := if target.isForall then "Dependent arrows / forall are not supported" else ""
throwError f!"Unexpected target expression :\n {target}\n{extra}\nNote that the expression contains live variables, hence, must be functorial"
end


structure QpfCompositionBodyView where
Expand Down Expand Up @@ -306,6 +320,10 @@ def elabQpfCompositionBody (view: QpfCompositionBodyView) :
let target_expr ← elabTermEnsuringTypeQ (u:=u.succ.succ) view.target q(Type u)
let arity := vars.toList.length
let vars : Vector _ arity := ⟨vars.toList, rfl⟩

let some vars := Vector.mapM vars Expr.fvarId? |
throwError "Expected all args to be fvars"

let res ← elabQpf vars target_expr view.target

res.F.check
Expand Down
58 changes: 25 additions & 33 deletions Qpf/Macro/Data/Replace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ structure CtorArgs where

/- TODO(@William): make these correspond by combining expr and vars into a product -/
structure Replace where
(expr: Array Term)
(vars: Array Name)
(vals: Array (Name × Term))
/- (expr: Array Term) -/
/- (vars: Array Name) -/
(ctor: CtorArgs)

def Replace.vars (r : Replace): Array Name := r.vals.map Prod.fst
def Replace.expr (r : Replace): Array Term := r.vals.map Prod.snd


variable (m) [Monad m] [MonadQuotation m] [MonadError m] [MonadTrace m] [MonadOptions m]
[AddMessageContext m] [MonadLiftT IO m]
Expand All @@ -32,25 +36,23 @@ private abbrev ReplaceM := StateT Replace m
variable {m}

private def Replace.new : m Replace :=
do pure ⟨#[], #[], ⟨#[], #[]⟩⟩
do pure ⟨#[], ⟨#[], #[]⟩⟩

private def CtorArgs.reset : ReplaceM m Unit := do
let r ← StateT.get
let n := r.vars.size
let ctor: CtorArgs := ⟨#[], (Array.range n).map fun _ => #[]⟩
StateT.set ⟨r.expr, r.vars, ctor
StateT.set { r with ctor }
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like an unrelated change, please put that in it's own PR.
If you put the trivial stuff separately, I can quickly review and merge that, removing noise from the current diff


private def CtorArgs.get : ReplaceM m CtorArgs := do
pure (←StateT.get).ctor

/--
The arity of the shape type created *after* replacing, i.e., the size of `r.expr`
-/
def Replace.arity (r : Replace) : Nat :=
r.expr.size
def Replace.arity (r : Replace) : Nat := r.vals.size

def Replace.getBinderIdents (r : Replace) : Array Ident :=
r.vars.map mkIdent
def Replace.getBinderIdents (r : Replace) : Array Ident := r.vars.map mkIdent

open Parser.Term in
def Replace.getBinders {m} [Monad m] [MonadQuotation m] (r : Replace) : m <| TSyntax ``bracketedBinder := do
Expand All @@ -75,17 +77,15 @@ private def ReplaceM.identFor (stx : Term) : ReplaceM m Ident := do
| some id => do
let ctor_per_type := ctor.per_type.set! id $ (ctor.per_type.get! id).push argName
let ctor := ⟨ctor_args, ctor_per_type⟩
StateT.set ⟨r.expr, r.vars, ctor
StateT.set { r with ctor }
pure $ r.vars.get! id
| none => do
let ctor_per_type := ctor.per_type.push #[argName]
let name ← mkFreshBinderName
StateT.set ⟨r.expr.push stx, r.vars.push name, ⟨ctor_args, ctor_per_type⟩
StateT.set { vals := r.vals.push (name, stx), ctor := ⟨ctor_args, ctor_per_type⟩ }
pure name

return mkIdent name




open Lean.Parser in
Expand All @@ -99,11 +99,9 @@ private partial def shapeOf' : Syntax → ReplaceM m Syntax
let ctor_arg ← ReplaceM.identFor ⟨arg⟩
let ctor_tail ← shapeOf' tail

-- dbg_trace ">> {arg} ==> {ctor_arg}"
pure $ mkNode ``Term.arrow #[ctor_arg, arrow, ctor_tail]

| ctor_type =>
pure ctor_type
| ctor_type => pure ctor_type



Expand All @@ -117,12 +115,6 @@ private partial def setResultingType (res_type : Syntax) : Syntax → ReplaceM m
pure $ mkNode ``Term.arrow #[arg, arrow, tail]
| _ =>
pure res_type

-- TODO: this should be deprecated in favour of {v with ...} syntax
def CtorView.withType? (ctor : CtorView) (type? : Option Syntax) : CtorView := {
ctor with type?
}

/-
TODO: currently these functions ignore dead variables, everything is replaced.
This is OK, we can supply a "dead" value to a live variable, but we lose the ability to have
Expand Down Expand Up @@ -203,9 +195,9 @@ Replace.run <| do

CtorArgs.reset

let type? ← ctor.type?.mapM $ shapeOf'
let type? ← ctor.type?.mapM shapeOf'

pure $ (CtorView.withType? ctor type?, ←CtorArgs.get)
pure ({ ctor with type? }, ←CtorArgs.get)

let r ← StateT.get
let ctors := pairs.map Prod.fst;
Expand All @@ -216,8 +208,8 @@ Replace.run <| do

-- HACK: It seems that `Array.append` causes a stack overflow, so we go through `List` for now
-- TODO: fix this after updating to newer Lean version
let per_type := per_type.appendList $ (List.range diff).map (fun _ => (#[] : Array Name));
ctorArgs.args, per_type
let per_type := per_type.appendList $ List.replicate diff (#[] : Array Name)
{ ctorArgs with per_type }

-- Now that we know how many free variables were introduced, we can fix up the resulting type
-- of each constructor to be `Shape α_0 α_1 ... α_n`
Expand All @@ -226,24 +218,24 @@ Replace.run <| do

let ctors ← ctors.mapM fun ctor => do
let type? ← ctor.type?.mapM (setResultingType res)
pure $ CtorView.withType? ctor type?
pure { ctor with type? }

pure (ctors, ctorArgs)




/-- Replace syntax in *all* subexpressions -/
partial def Replace.replaceAllStx (find replace : Syntax) : Syntax → Syntax :=
fun stx =>
if stx == find then
replace
else
stx.setArgs (stx.getArgs.map (replaceAllStx find replace))
partial def Replace.replaceAllStx (find replace stx : Syntax) : Syntax :=
if stx == find then
replace
else
stx.setArgs (stx.getArgs.map (replaceAllStx find replace))



open Parser in
-- TODO: In this occasion is it with pulling stx out, it makes this a lot less noisy
/--
Given a sequence of arrows e₁ → e₂ → ... → eₙ, check that `eₙ == recType`, and replace all
*other* occurences (i.e., in e₁ ... eₖ₋₁) of `recType` with `newParam`.
Expand Down Expand Up @@ -292,7 +284,7 @@ def makeNonRecursive (view : DataView) : MetaM (DataView × Name) := do

let ctors ← view.ctors.mapM fun ctor => do
let type? ← ctor.type?.mapM (Replace.replaceStx expected recId <| TSyntax.mk ·)
return CtorView.withType? ctor type?
pure { ctor with type? }

let view := view.setCtors ctors
pure (view, rec)
Loading