Skip to content

Commit

Permalink
refactor: update types in elabQpf to reflect the code
Browse files Browse the repository at this point in the history
William Sørensen committed Jun 26, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent cdf79d0 commit 1720c1e
Showing 1 changed file with 85 additions and 90 deletions.
175 changes: 85 additions & 90 deletions Qpf/Macro/Comp.lean
Original file line number Diff line number Diff line change
@@ -159,130 +159,121 @@ structure ElabQpfResult (u : Level) (arity : Nat) where
qpf : Q(@MvQPF _ $F $functor)
deriving Inhabited

-- TODO: Move to Vector.mmap
def Vector.mapM {m : Type u → Type v}
{α : Type w} {β : Type u} [Monad m]
(l : Vector α n) (f : α → m β) : m $ Vector β n :=
match l with
| ⟨.cons hd tl, b⟩ => do
let v ← f hd
let x ← Vector.mapM ⟨tl, rfl⟩ f
pure ⟨v :: x.val, by simpa [List.length_cons, x.property] using b⟩
| ⟨.nil, h⟩ => pure ⟨[], by exact h⟩

def isLiveVar (varIds : Vector FVarId n) (id : FVarId) := varIds.toList.contains id

open PrettyPrinter in
mutual
partial def handleLiveFVar (vars : Vector Q(Type u) arity) (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do
let vars' := vars.toList;

trace[QPF] f!"target {target} is a free variable"
let ind ← match List.indexOf' target vars' with
| none => throwError "Free variable {target} is not one of the qpf arguments"
partial def handleLiveFVar (vars : Vector FVarId arity) (target : FVarId) : TermElabM (ElabQpfResult u arity) := do
trace[QPF] f!"target {Expr.fvar target} is a free variable"
let ind ← match List.indexOf' target vars.toList with
| none => throwError "Free variable {Expr.fvar target} is not one of the qpf arguments"
| some ind => pure ind

let ind : Fin2 arity := cast (by simp [vars']) ind.inv
let ind : Fin2 arity := cast (by simp) ind.inv
let prj := q(@Prj.{u} $arity $ind)
trace[QPF] "represented by: {prj}"
pure ⟨prj, q(Prj.mvfunctor _), q(Prj.mvqpf _)⟩

partial def handleConst (vars : Vector Q(Type u) arity) (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do
pure { F := prj, functor := q(Prj.mvfunctor _), qpf := q(Prj.mvqpf _) }

partial def handleConst (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do
trace[QPF] "target {target} is a constant"
let const := q(Const.{u} $arity $target)
trace[QPF] "represented by: {const}"
pure ⟨const, q(Const.MvFunctor), q(Const.mvqpf)⟩

partial def handleApp (vars : Vector Q(Type u) arity) (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do
let vars' := vars.toList;
pure { F := const, functor := q(Const.MvFunctor), qpf := q(Const.mvqpf)}

let varIds := vars'.map fun expr => expr.fvarId!
let isLiveVar : FVarId → Bool
:= fun fvarId => (List.indexOf' fvarId varIds).isSome
partial def handleApp (vars : Vector FVarId arity) (target : Q(Type u)) : TermElabM (ElabQpfResult u arity) := do
let vars' := vars.toList

letm, F, args⟩ ← (Comp.parseApp isLiveVar target)
letnumArgs, F, args⟩ ← (Comp.parseApp (isLiveVar vars) target)
trace[QPF] "target {target} is an application of {F} to {args.toList}"

/-
Optimization: check if the application is of the form `F α β γ .. = F α β γ ..`.
In such cases, we can directly return `F`, rather than generate a composition of projections.
Optimization: check if the application is of the form `F α β γ .. = G α β γ ..`.
In such cases, we can directly return `G`, rather than generate a composition of projections.
-/
let is_trivial :=
args.length == arity
&& args.toList.enum.all fun ⟨i, arg⟩ =>
arg.isFVar && isLiveVar arg.fvarId! && vars'.indexOf arg == i
if is_trivial then
trace[QPF] "The application is trivial"
let mvFunctor ← synthInstanceQ q(MvFunctor $F)
let mvQPF ← synthInstanceQ q(MvQPF $F)
pure ⟨F, mvFunctor, mvQPF⟩
else
let G ← args.toList.mapM fun a =>
elabQpf vars a none false

/-
HACK: We know that `m`, which was defines as `args.length`, is equal to `G.length`.
It's a bit difficult to prove this. Thus, we simply assert it
-/
if hm : m ≠ G.length then
throwError "This shouldn't happen" -- TODO: come up with a better error message
else
have hm : m = G.length := by simpa using hm


-- O(n²), equivalent to a zipping and would be O(n) and much more readable
/- let is_trivial := -/
/- args.length == arity -/
/- && args.toList.enum.all fun ⟨i, arg⟩ => -/
-- Equivalent (?), O(n), and more readable
-- TODO: is this really equivalent
let is_trivial := (args.toList.mapM Expr.fvarId?).any (· == vars')
/- let is_trivial := (args.toList.mapM Expr.fvarId?).any (BEq.beq vars') -/

let Ffunctor ← synthInstanceQ q(MvFunctor $F)
let Fqpf ← synthInstanceQ q(@MvQPF _ $F $Ffunctor)

let G : Vec _ m := fun i => G.get (hm ▸ i.inv)
let GExpr : Q(Vec (TypeFun.{u,u} $arity) $m) :=
Vec.toExpr (fun i => (G i).F)
let GFunctor : Q(∀ i, MvFunctor.{u,u} ($GExpr i)) :=
let αs := q(fun i => MvFunctor.{u,u} ($GExpr i))
@DVec.toExpr _ _ αs (fun i => (G i).functor)
let GQpf : Q(∀ i, @MvQPF.{u,u} _ _ ($GFunctor i)) :=
let αs := q(fun i => @MvQPF.{u,u} _ _ ($GFunctor i))
@DVec.toExpr _ _ αs (fun i => (G i).qpf)

let comp := q(@Comp $m _ $F $GExpr)
trace[QPF] "G := {GExpr}"
trace[QPF] "comp := {comp}"
if is_trivial then
trace[QPF] "The application is trivial"
let functor ← synthInstanceQ q(MvFunctor $F)
let qpf ← synthInstanceQ q(MvQPF $F)

let functor := q(Comp.instMvFunctorComp)
let qpf := q(Comp.instMvQPFCompInstMvFunctorCompFin2
(fF := $Ffunctor) (q := $Fqpf) (fG := _) (q' := $GQpf)
)
pure ⟨comp, functor, qpf⟩
return { F, functor, qpf }
else
let G ← Vector.mapM args (elabQpf vars · none false)

let Ffunctor ← synthInstanceQ q(MvFunctor $F)
let Fqpf ← synthInstanceQ q(@MvQPF _ $F $Ffunctor)

let G : Vec _ numArgs := fun i => G.get i.inv
let GExpr : Q(Vec (TypeFun.{u,u} $arity) $numArgs) :=
Vec.toExpr (fun i => (G i).F)
let GFunctor : Q(∀ i, MvFunctor.{u,u} ($GExpr i)) :=
let αs := q(fun i => MvFunctor.{u,u} ($GExpr i))
@DVec.toExpr _ _ αs (fun i => (G i).functor)
let GQpf : Q(∀ i, @MvQPF.{u,u} _ _ ($GFunctor i)) :=
let αs := q(fun i => @MvQPF.{u,u} _ _ ($GFunctor i))
@DVec.toExpr _ _ αs (fun i => (G i).qpf)

let comp := q(@Comp $numArgs _ $F $GExpr)
trace[QPF] "G := {GExpr}"
trace[QPF] "comp := {comp}"

let functor := q(Comp.instMvFunctorComp)
let qpf := q(Comp.instMvQPFCompInstMvFunctorCompFin2
(fF := $Ffunctor) (q := $Fqpf) (fG := _) (q' := $GQpf)
)

return { F := comp, functor, qpf }

partial def handleArrow (vars : Vector Q(Type u) arity) (target : Q(Type u)) (targetStx : Option Term := none) (normalized := false) : TermElabM (ElabQpfResult u arity) := do
match target with
| Expr.forallE _ e₁ e₂ .. =>
let newTarget ← mkAppM ``MvQPF.Arrow.Arrow #[e₁, e₂]
elabQpf vars newTarget targetStx normalized
| _ => unreachable!
partial def handleArrow (binderType body : Expr) (vars : Vector FVarId arity) (targetStx : Option Term := none) (normalized := false): TermElabM (ElabQpfResult u arity) := do
let newTarget ← mkAppM ``MvQPF.Arrow.Arrow #[binderType, body]
elabQpf vars newTarget targetStx normalized

/--
Elaborate the body of a qpf
-/
partial def elabQpf {arity : Nat} (vars : Vector Q(Type u) arity) (target : Q(Type u)) (targetStx : Option Term := none) (normalized := false) :
partial def elabQpf {arity : Nat} (vars : Vector FVarId arity) (target : Q(Type u)) (targetStx : Option Term := none) (normalized := false) :
TermElabM (ElabQpfResult u arity) := do
trace[QPF] "elabQPF :: {vars.toList} -> {target}"
let vars' := vars.toList;
trace[QPF] "elabQPF :: {(vars.map Expr.fvar).toList} -> {target}"
let isLiveVar := isLiveVar vars

let varIds := vars'.map fun expr => expr.fvarId!
let isLiveVar : FVarId → Bool
:= fun fvarId => (List.indexOf' fvarId varIds).isSome

if target.isFVar && isLiveVar target.fvarId! then
if let some target := target.fvarId?.filter isLiveVar then
handleLiveFVar vars target
else if !target.hasAnyFVar isLiveVar then
handleConst vars target
else if target.isApp then
handleConst target
else if target.isApp then -- Could be pattern-matched here as well
handleApp vars target
else if target.isArrow then
handleArrow vars target (targetStx := targetStx) (normalized := normalized)
else if let .forallE _ binderType body .. := target then
handleArrow binderType body vars (normalized := normalized) (targetStx := targetStx)
else if !normalized then
let target ← whnfR target
elabQpf vars target targetStx true
else
if !normalized then
let target ← whnfR target
elabQpf vars target targetStx true
else
let extra :=
if target.isForall then
"Dependent arrows / forall are not supported"
else
""
throwError f!"Unexpected target expression :\n {target}\n{extra}\nNote that the expression contains live variables, hence, must be functorial"
let extra := if target.isForall then "Dependent arrows / forall are not supported" else ""
throwError f!"Unexpected target expression :\n {target}\n{extra}\nNote that the expression contains live variables, hence, must be functorial"
end


structure QpfCompositionBodyView where
(type? : Option Syntax := none)
(target : Term)
@@ -323,6 +314,10 @@ def elabQpfCompositionBody (view: QpfCompositionBodyView) :
let target_expr ← elabTermEnsuringTypeQ (u:=u.succ.succ) view.target q(Type u)
let arity := vars.toList.length
let vars : Vector _ arity := ⟨vars.toList, rfl⟩

let some vars := Vector.mapM vars Expr.fvarId? |
throwError "Expected all args to be fvars"

let res ← elabQpf vars target_expr view.target

res.F.check

0 comments on commit 1720c1e

Please sign in to comment.