Skip to content

Commit

Permalink
refactor: composition pipeline and qpf body parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
William Sørensen committed Jun 25, 2024
1 parent 201cf79 commit 2811848
Showing 1 changed file with 109 additions and 91 deletions.
200 changes: 109 additions & 91 deletions Qpf/Macro/Comp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -152,118 +152,132 @@ def DVec.toExpr {n : Nat} {αs : Q(Fin2 $n → Type u)} (xs : DVec (fun (i : Fin
| .fs i => $as i
)

def Vector.mapM
{m : Type u → Type v}
{α : Type w}
{β : Type u}
[Monad m]
(l : Vector α n)
(f : α → m β) : m $ Vector β n :=
match l with
| ⟨.cons hd tl, b⟩ => do
let v ← f hd
let x ← Vector.mapM ⟨tl, by rfl⟩ f
pure ⟨v :: x.val, by {
rw [List.length_cons] at b
rw [List.length_cons, x.property]
exact b
}⟩
| ⟨.nil, h⟩ => pure ⟨[], by exact h⟩

structure ElabQpfResult (u : Level) (arity : Nat) where
F : Q(TypeFun.{u,u} $arity)
functor : Q(MvFunctor $F)
qpf : Q(@MvQPF _ $F $functor)
deriving Inhabited

open PrettyPrinter in
/--
Elaborate the body of a qpf
-/
partial def elabQpf {arity : Nat} (vars : Vector Q(Type u) arity) (target : Q(Type u)) (targetStx : Option Term := none) (normalized := false) :
TermElabM (ElabQpfResult u arity) := do
trace[QPF] "elabQPF :: {vars.toList} -> {target}"
let vars' := vars.toList;
/- def isLiveVar (varIds : Vector FVarId n) (id : FVarId) := (List.indexOf' id varIds.toList).isSome -/
def isLiveVar (varIds : Vector FVarId n) (id : FVarId) := varIds.toList.contains id

let varIds := vars'.map fun expr => expr.fvarId!
let isLiveVar : FVarId → Bool
:= fun fvarId => (List.indexOf' fvarId varIds).isSome
open PrettyPrinter in
mutual
partial def handleLiveFVar (vars : Vector FVarId arity) (target : FVarId) : TermElabM (ElabQpfResult u arity) := do
trace[QPF] f!"target {Expr.fvar target} is a free variable"
let ind ← match List.indexOf' target vars.toList with
| none => throwError "Free variable {Expr.fvar target} is not one of the qpf arguments"
| some ind => pure ind

if target.isFVar && isLiveVar target.fvarId! then
trace[QPF] f!"target {target} is a free variable"
let ind ← match List.indexOf' target vars' with
| none => throwError "Free variable {target} is not one of the qpf arguments"
| some ind => pure ind
let ind : Fin2 arity := cast (by simp) ind.inv
let prj := q(@Prj.{u} $arity $ind)
trace[QPF] "represented by: {prj}"

let ind : Fin2 arity := cast (by simp [vars']) ind.inv
let prj := q(@Prj.{u} $arity $ind)
trace[QPF] "represented by: {prj}"
pure ⟨prj, q(Prj.mvfunctor _), q(Prj.mvqpf _)⟩
return { F := prj, functor := q(Prj.mvfunctor _), qpf := q(Prj.mvqpf _) }

else if !target.hasAnyFVar isLiveVar then
trace[QPF] "target {target} is a constant"
let const := q(Const.{u} $arity $target)
trace[QPF] "represented by: {const}"
pure ⟨const, q(Const.MvFunctor), q(Const.mvqpf)⟩
partial def handleConst (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do
trace[QPF] "target {target} is a constant"
let const := q(Const.{u} $arity $target)
trace[QPF] "represented by: {const}"

else if target.isApp then
let ⟨m, F, args⟩ ← (Comp.parseApp isLiveVar target)
trace[QPF] "target {target} is an application of {F} to {args.toList}"

/-
Optimization: check if the application is of the form `F α β γ .. = F α β γ ..`.
In such cases, we can directly return `F`, rather than generate a composition of projections.
-/
let is_trivial :=
args.length == arity
&& args.toList.enum.all fun ⟨i, arg⟩ =>
arg.isFVar && isLiveVar arg.fvarId! && vars'.indexOf arg == i
if is_trivial then
trace[QPF] "The application is trivial"
let mvFunctor ← synthInstanceQ q(MvFunctor $F)
let mvQPF ← synthInstanceQ q(MvQPF $F)
pure ⟨F, mvFunctor, mvQPF⟩
else
let G ← args.toList.mapM fun a =>
elabQpf vars a none false

/-
HACK: We know that `m`, which was defines as `args.length`, is equal to `G.length`.
It's a bit difficult to prove this. Thus, we simply assert it
-/
if hm : m ≠ G.length then
throwError "This shouldn't happen" -- TODO: come up with a better error message
else
have hm : m = G.length := by simpa using hm
return { F := const, functor := q(Const.MvFunctor), qpf := q(Const.mvqpf)}

partial def handleApp (vars : Vector FVarId arity) (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do
let vars' := vars.toList

let ⟨numArgs, F, args⟩ ← (Comp.parseApp (isLiveVar vars) target)
trace[QPF] "target {target} is an application of {F} to {args.toList}"

let Ffunctor ← synthInstanceQ q(MvFunctor $F)
let Fqpf ← synthInstanceQ q(@MvQPF _ $F $Ffunctor)
/-
Optimization: check if the application is of the form `F α β γ .. = G α β γ ..`.
In such cases, we can directly return `G`, rather than generate a composition of projections.
-/

let G : Vec _ m := fun i => G.get (hm ▸ i.inv)
let GExpr : Q(Vec (TypeFun.{u,u} $arity) $m) :=
Vec.toExpr (fun i => (G i).F)
let GFunctor : Q(∀ i, MvFunctor.{u,u} ($GExpr i)) :=
let αs := q(fun i => MvFunctor.{u,u} ($GExpr i))
@DVec.toExpr _ _ αs (fun i => (G i).functor)
let GQpf : Q(∀ i, @MvQPF.{u,u} _ _ ($GFunctor i)) :=
let αs := q(fun i => @MvQPF.{u,u} _ _ ($GFunctor i))
@DVec.toExpr _ _ αs (fun i => (G i).qpf)
-- O(n²), equivalent to a zipping and would be O(n) and much more readable
/- let is_trivial := -/
/- args.length == arity -/
/- && args.toList.enum.all fun ⟨i, arg⟩ => -/
-- Equivalent (?), O(n), and more readable
-- TODO: is this really equivalent
let is_trivial := (args.toList.mapM Expr.fvarId?).any (· == vars')
/- let is_trivial := (args.toList.mapM Expr.fvarId?).any (BEq.beq vars') -/

if is_trivial then
trace[QPF] "The application is trivial"
let functor ← synthInstanceQ q(MvFunctor $F)
let qpf ← synthInstanceQ q(MvQPF $F)

return { F, functor, qpf }
else
let G ← Vector.mapM args (elabQpf vars · none false)

let Ffunctor ← synthInstanceQ q(MvFunctor $F)
let Fqpf ← synthInstanceQ q(@MvQPF _ $F $Ffunctor)

let G : Vec _ numArgs := fun i => G.get i.inv
let GExpr : Q(Vec (TypeFun.{u,u} $arity) $numArgs) :=
Vec.toExpr (fun i => (G i).F)
let GFunctor : Q(∀ i, MvFunctor.{u,u} ($GExpr i)) :=
let αs := q(fun i => MvFunctor.{u,u} ($GExpr i))
@DVec.toExpr _ _ αs (fun i => (G i).functor)
let GQpf : Q(∀ i, @MvQPF.{u,u} _ _ ($GFunctor i)) :=
let αs := q(fun i => @MvQPF.{u,u} _ _ ($GFunctor i))
@DVec.toExpr _ _ αs (fun i => (G i).qpf)

let comp := q(@Comp $numArgs _ $F $GExpr)
trace[QPF] "G := {GExpr}"
trace[QPF] "comp := {comp}"

let functor := q(Comp.instMvFunctorComp)
let qpf := q(Comp.instMvQPFCompInstMvFunctorCompFin2
(fF := $Ffunctor) (q := $Fqpf) (fG := _) (q' := $GQpf)
)

let comp := q(@Comp $m _ $F $GExpr)
trace[QPF] "G := {GExpr}"
trace[QPF] "comp := {comp}"
return { F := comp, functor, qpf }

let functor := q(Comp.instMvFunctorComp)
let qpf := q(Comp.instMvQPFCompInstMvFunctorCompFin2
(fF := $Ffunctor) (q := $Fqpf) (fG := _) (q' := $GQpf)
)
pure ⟨comp, functor, qpf⟩
partial def handleArrow (binderType body : Expr) (vars : Vector FVarId arity) (targetStx : Option Term := none) (normalized := false): TermElabM (ElabQpfResult u arity) := do
let newTarget ← mkAppM ``MvQPF.Arrow.Arrow #[binderType, body]
elabQpf vars newTarget targetStx normalized

else if target.isArrow then
match target with
| Expr.forallE _ e₁ e₂ .. =>
let newTarget ← mkAppM ``MvQPF.Arrow.Arrow #[e₁, e₂]
elabQpf vars newTarget targetStx normalized
| _ => unreachable!
/-- Elaborate the body of a qpf -/
partial def elabQpf {arity : Nat} (vars : Vector FVarId arity) (target : Q(Type u)) (targetStx : Option Term := none) (normalized := false) :
TermElabM (ElabQpfResult u arity) := do
trace[QPF] "elabQPF :: {(vars.map Expr.fvar).toList} -> {target}"
let isLiveVar := isLiveVar vars

if let some target := target.fvarId?.filter isLiveVar then
handleLiveFVar vars target
else if !target.hasAnyFVar isLiveVar then
handleConst target
else if target.isApp then -- Could be pattern-matched here as well
handleApp vars target
else if let .forallE _ binderType body .. := target then
handleArrow binderType body vars (normalized := normalized) (targetStx := targetStx)
else if !normalized then
let target ← whnfR target
elabQpf vars target targetStx true
else
if !normalized then
let target ← whnfR target
elabQpf vars target targetStx true
else
let extra :=
if target.isForall then
"Dependent arrows / forall are not supported"
else
""
throwError f!"Unexpected target expression :\n {target}\n{extra}\nNote that the expression contains live variables, hence, must be functorial"


let extra := if target.isForall then "Dependent arrows / forall are not supported" else ""
throwError f!"Unexpected target expression :\n {target}\n{extra}\nNote that the expression contains live variables, hence, must be functorial"
end


structure QpfCompositionBodyView where
Expand Down Expand Up @@ -306,6 +320,10 @@ def elabQpfCompositionBody (view: QpfCompositionBodyView) :
let target_expr ← elabTermEnsuringTypeQ (u:=u.succ.succ) view.target q(Type u)
let arity := vars.toList.length
let vars : Vector _ arity := ⟨vars.toList, rfl⟩

let some vars := Vector.mapM vars Expr.fvarId? |
throwError "Expected all args to be fvars"

let res ← elabQpf vars target_expr view.target

res.F.check
Expand Down

0 comments on commit 2811848

Please sign in to comment.