From d7db483d4708fa0fcd4e5a03e2cc3e1f7bbe94b8 Mon Sep 17 00:00:00 2001 From: William Spencer Date: Mon, 4 Mar 2024 17:37:58 -0600 Subject: [PATCH] Further updates, expanded kronecker product, and frameworked CastCategory --- ViCaR/Classes/BraidedMonoidal.v | 2 +- ViCaR/Classes/CastCategory.v | 36 + ViCaR/Classes/Category.v | 20 + ViCaR/Classes/ExamplesAutomation.v | 214 +++ ViCaR/GeneralLemmas.v | 225 ++- examples/DirectSumMatrixExample.v | 666 +++++++ examples/KronComm.v | 2600 ++++++++++++++++++++++++++ examples/KronComm_orig.v | 1213 ++++++++++++ examples/KronMatrixExample.v | 103 ++ examples/MatrixExampleBase.v | 263 +++ examples/MatrixPermBase.v | 166 ++ examples/ZXExample.v | 2773 +++++++++++----------------- 12 files changed, 6613 insertions(+), 1668 deletions(-) create mode 100644 ViCaR/Classes/CastCategory.v create mode 100644 ViCaR/Classes/ExamplesAutomation.v create mode 100644 examples/DirectSumMatrixExample.v create mode 100644 examples/KronComm.v create mode 100644 examples/KronComm_orig.v create mode 100644 examples/KronMatrixExample.v create mode 100644 examples/MatrixExampleBase.v create mode 100644 examples/MatrixPermBase.v diff --git a/ViCaR/Classes/BraidedMonoidal.v b/ViCaR/Classes/BraidedMonoidal.v index 243a353..1a63fcd 100644 --- a/ViCaR/Classes/BraidedMonoidal.v +++ b/ViCaR/Classes/BraidedMonoidal.v @@ -10,7 +10,7 @@ Notation CommuteBifunctor' F := ({| id2_map := ltac:(intros; apply id2_map); compose2_map := ltac:(intros; apply compose2_map); morphism2_compat := ltac:(intros; apply morphism2_compat; easy); -|}). +|}) (only parsing). diff --git a/ViCaR/Classes/CastCategory.v b/ViCaR/Classes/CastCategory.v new file mode 100644 index 0000000..54b3bc2 --- /dev/null +++ b/ViCaR/Classes/CastCategory.v @@ -0,0 +1,36 @@ +Require Export Category. +Require Import ExamplesAutomation. + +Local Open Scope Cat. + +Class CastCategory {C : Type} (cC : Category C) : Type := { + cast {A B A' B' : C} (HA : A = A') (HB : B = B') : + A' ~> B' -> A ~> B; + (* Will need some coherence conditions, probably + cast_id, cast_compose? Might be it. Oh, we may want + cast HA HB (cast (eq_sym HA) (eq_sym HB) f) ≃ f as + well, just giving bijectivity. All should be trivial + for any sensible cast. This seems sufficient on no + reflection, so let's see how far it can go: *) + cast_invertible {A B A' B' : C} + (HA : A' = A) (HB : B' = B) (HA' : A = A') (HB' : B = B') f : + cast HA' HB' (cast HA HB f) ≃ f; +}. + +Definition CastCategory_of_DecEq_Category {C : Type} (cC: Category C) + (dec : forall A B : C, {A = B} + {A <> B}) : + CastCategory cC := {| + cast := fun A B A' B' HA HB => + match HA in (_ = a) return (a ~> B' -> A ~> B) with (* Tell coq that A = A' *) + | eq_refl => + fun f => + match HB in (_ = a) return (A ~> a -> A ~> B) with + | eq_refl => fun f' => f' + end f + end; + cast_invertible := ltac:(intros A B A' B' HA HB; + destruct HA, HB; intros HA' HB'; + rewrite (@Eqdep_dec.UIP_dec C dec A' _ _ (eq_refl)); + rewrite (@Eqdep_dec.UIP_dec C dec B' _ _ (eq_refl)); + reflexivity); +|}. \ No newline at end of file diff --git a/ViCaR/Classes/Category.v b/ViCaR/Classes/Category.v index 6d8d9dc..d84691d 100644 --- a/ViCaR/Classes/Category.v +++ b/ViCaR/Classes/Category.v @@ -374,4 +374,24 @@ Definition Bifunctor_of_ProductCategoryFunctor {C1 C2 D : Type} `{cC1 : Category morphism2_compat := ltac:(intros; apply morphism_compat; constructor; easy); |}. +Lemma equiv_of_iso_compose_l {C : Type} `{cC : Category C} {A A' B : C} + (I : Isomorphism A A') (f g : A' ~> B) (H : I ∘ f ≃ I ∘ g) : + f ≃ g. +Proof. + rewrite <- (left_unit (f:=f)), <- (left_unit (f:=g)). + rewrite <- I.(id_B), 2!assoc, H. + easy. +Qed. + +Lemma equiv_of_iso_compose_r {C : Type} `{cC : Category C} {A B' B : C} + (I : Isomorphism B' B) (f g : A ~> B') (H : f ∘ I ≃ g ∘ I) : + f ≃ g. +Proof. + rewrite <- (right_unit (f:=f)), <- (right_unit (f:=g)). + rewrite <- I.(id_A), <- 2!assoc, H. + easy. +Qed. + + + Local Close Scope Cat. diff --git a/ViCaR/Classes/ExamplesAutomation.v b/ViCaR/Classes/ExamplesAutomation.v new file mode 100644 index 0000000..ee4c565 --- /dev/null +++ b/ViCaR/Classes/ExamplesAutomation.v @@ -0,0 +1,214 @@ +From VyZX Require Import PermutationAutomation. +Require Import String. + +Ltac print_state := + try (match goal with | H : ?p |- _ => idtac H ":" p; fail end); + idtac "---------------------------------------------------------"; + match goal with |- ?P => idtac P; idtac "" +end. + +Ltac is_C0 x := assert (x = C0) by (cbv; lca). + +Ltac is_C1 x := assert (x = C1) by (cbv; lca). + +Tactic Notation "print_C" constr(x) := (tryif is_C0 x then constr:("0"%string) else + tryif is_C1 x then constr:("1"%string) else constr:("X"%string)). + +Ltac print_LHS_matU := + intros; + (let i := fresh "i" in + let j := fresh "j" in + let Hi := fresh "Hi" in + let Hj := fresh "Hj" in + intros i j Hi Hj; try solve_end; + repeat (* Enumerate rows and columns; see `by_cell` *) + (destruct i as [| i]; [ | apply <- Nat.succ_lt_mono in Hi ]; + try solve_end); clear Hi; + repeat + (destruct j as [| j]; [ | apply <- Nat.succ_lt_mono in Hj ]; + try solve_end); clear Hj + ); + match goal with |- ?x = ?y ?i ?j => autounfold with U_db; simpl; + match goal with + | |- ?x = _ => + tryif is_C0 x then idtac "A[" i "," j "] = 0" else + tryif is_C1 x then idtac "A[" i "," j "] = 1" else + idtac "A[" i "," j "] = X" + end +end. + +Ltac simpl_bools := + repeat (simpl; rewrite ?andb_true_r, ?andb_false_r, ?orb_true_r, ?orb_false_r). + +Ltac simplify_bools_lia_one_free := + let act_T b := ((replace_bool_lia b true || replace_bool_lia b false); simpl) in + let act_F b := ((replace_bool_lia b false || replace_bool_lia b true); simpl) in + match goal with + | |- context[?b && _] => act_F b; rewrite ?andb_true_l, ?andb_false_l + | |- context[_ && ?b] => act_F b; rewrite ?andb_true_r, ?andb_false_r + | |- context[?b || _] => act_T b; rewrite ?orb_true_l, ?orb_false_l + | |- context[_ || ?b] => act_T b; rewrite ?orb_true_r, ?orb_false_r + | |- context[negb ?b] => act_T b; simpl negb + | |- context[if ?b then _ else _] => act_T b + end; simpl_bools. + +Ltac simplify_bools_lia_one_kernel := + let fail_if_iffy H := + match H with + | context [ if _ then _ else _ ] => fail 1 + | _ => idtac + end + in + let fail_if_compound H := + fail_if_iffy H; + match H with + | context [ ?a && ?b ] => fail 1 + | context [ ?a || ?b ] => fail 1 + | _ => idtac + end + in + let act_T b := (fail_if_compound b; + (replace_bool_lia b true || replace_bool_lia b false); simpl) in + let act_F b := (fail_if_compound b; + (replace_bool_lia b false || replace_bool_lia b true); simpl) in + match goal with + | |- context[?b && _] => act_F b; rewrite ?andb_true_l, ?andb_false_l + | |- context[_ && ?b] => act_F b; rewrite ?andb_true_r, ?andb_false_r + | |- context[?b || _] => act_T b; rewrite ?orb_true_l, ?orb_false_l + | |- context[_ || ?b] => act_T b; rewrite ?orb_true_r, ?orb_false_r + | |- context[negb ?b] => act_T b; simpl negb + | |- context[if ?b then _ else _] => act_T b + end; simpl_bools. + +Ltac simplify_bools_lia_one := + simplify_bools_lia_one_kernel || simplify_bools_lia_one. + +Ltac simplify_bools_lia := + repeat simplify_bools_lia_one. + +Ltac bdestruct_one_old := + let fail_if_iffy H := + match H with + | context [ if _ then _ else _ ] => fail 1 + | _ => idtac + end + in + match goal with + | |- context [ ?a + fail_if_iffy a; fail_if_iffy b; bdestruct (a + fail_if_iffy a; fail_if_iffy b; bdestruct (a <=? b) + | |- context [ ?a =? ?b ] => + fail_if_iffy a; fail_if_iffy b; bdestruct (a =? b) + | |- context [ if ?b then _ else _ ] => fail_if_iffy b; destruct b eqn:? + end. + +Ltac bdestruct_one_new := + let fail_if_iffy H := + match H with + | context [ if _ then _ else _ ] => fail 1 + | _ => idtac + end + in + let fail_if_booley H := + fail_if_iffy H; + match H with + | context [ ?a fail 1 + | context [ ?a <=? ?b ] => fail 1 + | context [ ?a =? ?b ] => fail 1 + | context [ ?a && ?b ] => fail 1 + | context [ ?a || ?b ] => fail 1 + | context [ negb ?a ] => fail 1 + | context [ xorb ?a ?b ] => fail 1 + | _ => idtac + end + in + let rec destruct_kernel H := + match H with + | context [ if ?b then _ else _ ] => destruct_kernel b + | context [ ?a + tryif fail_if_booley a then + (tryif fail_if_booley b then bdestruct (a + tryif fail_if_booley a then + (tryif fail_if_booley b then bdestruct (a <=? b) + else destruct_kernel b) else (destruct_kernel a) + | context [ ?a =? ?b ] => + tryif fail_if_booley a then + (tryif fail_if_booley b then bdestruct (a =? b); try subst + else destruct_kernel b) else (destruct_kernel a) + | context [ ?a && ?b ] => + destruct_kernel a || destruct_kernel b + | context [ ?a || ?b ] => + destruct_kernel a || destruct_kernel b + | context [ xorb ?a ?b ] => + destruct_kernel a || destruct_kernel b + | context [ negb ?a ] => + destruct_kernel a + | _ => idtac + end + in + simpl_bools; + match goal with + | |- context [ ?a =? ?b ] => + fail_if_iffy a; fail_if_iffy b; bdestruct (a =? b); try subst + | |- context [ ?a + fail_if_iffy a; fail_if_iffy b; bdestruct (a + fail_if_iffy a; fail_if_iffy b; bdestruct (a <=? b) + | |- context [ if ?b then _ else _ ] => fail_if_iffy b; destruct b eqn:? + end; + simpl_bools. + +Ltac bdestruct_one ::= bdestruct_one_new || bdestruct_one_old. + +Ltac bdestructΩ'simp := + let tryeasylia := try easy; try lca; try lia in + tryeasylia; + repeat (bdestruct_one; subst; simpl_bools; simpl; tryeasylia); tryeasylia. + + +Ltac simplify_mods_one := + let __fail_if_has_mods a := + match a with + | context [ _ mod _ ] => fail 1 + | _ => idtac + end + in + match goal with + | |- context [ ?a mod ?b ] => + __fail_if_has_mods a; __fail_if_has_mods b; + simplify_mods_of a b + | H:context [ ?a mod ?b ] |- _ => + __fail_if_has_mods a; __fail_if_has_mods b; + simplify_mods_of a b + end. + +Ltac case_mods_one := + match goal with + | |- context [ ?a mod ?b ] => + bdestruct (a ( + match type of A with Matrix ?m' ?n' => + match type of B with Matrix ?n'' ?o'' => + let Hm' := fresh "Hm'" in let Hn' := fresh "Hn'" in + let Hn'' := fresh "Hn''" in let Ho'' := fresh "Hoo'" in + assert (Hm' : m = m') by lia; + assert (Hn' : n = n') by lia; + assert (Hn'' : n = n'') by lia; + assert (Ho' : o = o'') by lia; + replace (@Mmult m n o A B) with (@Mmult m' n' o A B) + by (first [try (rewrite Hm' at 1); try (rewrite Hn' at 1); reflexivity | f_equal; lia]); + apply WF_mult; [ + auto with wf_db | + apply WF_Matrix_dim_change; + auto with wf_db + ] + end end) : wf_db. \ No newline at end of file diff --git a/ViCaR/GeneralLemmas.v b/ViCaR/GeneralLemmas.v index 5886e3d..0541609 100644 --- a/ViCaR/GeneralLemmas.v +++ b/ViCaR/GeneralLemmas.v @@ -23,6 +23,87 @@ Proof. reflexivity. Qed. +Lemma stack_compose_distr_test : forall {C : Type} + `{Cat : Category C} `{MonCat : MonoidalCategory C} + {A B D M N P : C} (f : A ~> B) (g : B ~> D) + (h : M ~> N) (i : N ~> P), + (f ∘ g) ⊗ (h ∘ i) ≃ (f ⊗ h) ∘ (g ⊗ i). +Proof. + intros. + rewrite compose2_map. + easy. +Qed. + +Local Notation "A ⊗ B" := (morphism2_map (Bifunctor:=tensor) A B) (only printing). +Notation "'id_' A" := (c_identity A) (at level 10, no associativity). + +Lemma stack_distr_pushout_r_bot : forall {C : Type} + `{Cat : Category C} `{MonCat : MonoidalCategory C} + {a b d m n} (f : a ~> b) (g : b ~> d) (h : m ~> n), + f ∘ g ⊗ h ≃ f ⊗ h ∘ (g ⊗ (id_ n)). +Proof. + intros. + rewrite <- compose2_map, right_unit. + easy. +Qed. + +(* TODO: the other two; _l_bot and _l_top *) + +Lemma stack_distr_pushout_r_top : forall {C : Type} + `{Cat : Category C} `{MonCat : MonoidalCategory C} + {a b m n o} (f : a ~> b) (g : m ~> n) (h : n ~> o), + f ⊗ (g ∘ h) ≃ f ⊗ g ∘ (id_ b ⊗ h). +Proof. + intros. + rewrite <- compose2_map, right_unit. + easy. +Qed. + +Ltac fencestep := + let test_simple t := match t with + | context[(_ ⊗ _) ∘ _] => fail 2 + | context[_ ∘ (_ ⊗ _)] => fail 2 + | context[(_ ∘ _) ⊗ _] => fail 2 + | context[_ ⊗ (_ ∘ _)] => fail 2 + | _ => idtac + end + in + first [ match goal with + |- context[(?f ∘ ?g) ⊗ (?h ∘ ?i)] => + test_simple f; test_simple g; test_simple h; test_simple i; + rewrite (compose2_map f g h i) + end | match goal with + |- context[(?f) ⊗ (?g ∘ ?h)] => + test_simple f; test_simple g; test_simple h; + rewrite (stack_distr_pushout_r_top f g h) + end | match goal with + |- context[(?f ∘ ?g) ⊗ (?h)] => + test_simple f; test_simple g; test_simple h; + rewrite (stack_distr_pushout_r_bot f g h) + end]. + + + +(* Ltac fencepost := + + let rec process_term term := + match term with + | ?A ⊗ ?B => match A, B with + | ?A' ⊗ B', _ => process_term A + | ?f ∘ ?g (* TODO: test if these are "simple", in some sense. + Also see if I can come up with an (even informal) algorithm, or + even as much as an explicit spec (e.g., strict fenceposing may + be required to have functions go in descending order, so + f id id i id f + id g id id g id + id id h id , etc. is good, but not id id ... + id id id id id id + ). I suspect the strict spec will make reasoning + much easier, i.e. to process f⊗g, we *must* push it out + to f⊗id ∘ id⊗g. *) *) + + + Lemma nwire_stack_compose_topleft_general : forall {C : Type} {Cat : Category C} {MonCat : MonoidalCategory C} {topIn botIn topOut botOut : C} @@ -47,7 +128,145 @@ Proof. easy. Qed. -Definition cast {C : Type} `{Category C} (A B : C) {A' B' : C} +Require Import ExamplesAutomation. + +Lemma stack_id_compose_split_top : forall {C : Type} + {Cat: Category C} {MonCat : MonoidalCategory C} + {topIn topMid topOut bot : C} + (f0 : topIn ~> topMid) (f1 : topMid ~> topOut), + (f0 ∘ f1) ⊗ (id_ bot) ≃ f0 ⊗ id_ bot ∘ (f1 ⊗ id_ bot). +Proof. + intros. + rewrite <- compose2_map, left_unit. + easy. +Qed. + +Lemma stack_id_compose_split_bot : forall {C : Type} + {Cat: Category C} {MonCat : MonoidalCategory C} + {top botIn botMid botOut : C} + (f0 : botIn ~> botMid) (f1 : botMid ~> botOut), + (id_ top) ⊗ (f0 ∘ f1) ≃ id_ top ⊗ f0 ∘ (id_ top ⊗ f1). +Proof. + intros. + rewrite <- compose2_map, left_unit. + easy. +Qed. + +(* Ignore this stuff for now: *) + +Ltac _fencepost term := + match term with + | id_ ?A => idtac + | ?f ⊗ id_ _ => _fencepost f; idtac "fix:"; print_state + | id_ _ ⊗ ?f => _fencepost f; idtac "fix:"; print_state + | ?f ∘ ?g => idtac f "∘" g; _fencepost f; _fencepost g + | ?f ⊗ ?g => idtac f "⊗" g; match type of f with + | ?topIn ~> ?topOut => match type of g with + | ?botIn ~> ?botOut => rewrite <- (nwire_stackcompose_topright_general f g); + _fencepost (f ⊗ c_identity botIn); _fencepost (c_identity topOut ⊗ g) + end end + | ?f => idtac "should be clean:" f + end. + +Ltac __fencepost term := + match term with + | id_ ?A => idtac + | ?f ∘ ?g => (* idtac f "∘" g; *) __fencepost f; __fencepost g + | ?f ⊗ id_ _ => __fencepost f (* ; idtac "fix:"; print_state *) + | id_ _ ⊗ ?f => __fencepost f (* ; idtac "fix:"; print_state *) + | (?f ∘ ?g) ⊗ (?h ∘ ?i) => first [ + first [progress __fencepost f; rewrite ?assoc | + progress __fencepost g; rewrite ?assoc | + progress __fencepost h; rewrite ?assoc | + progress __fencepost i; rewrite ?assoc]; + idtac "hit1"; __fencepost term; idtac "hit2" | + rewrite (compose2_map f g h i); __fencepost (f⊗h); __fencepost (g⊗i)] + | (?f ∘ ?g) ⊗ ?h => first [ + progress __fencepost f; rewrite ?assoc | + progress __fencepost g; rewrite ?assoc | + progress __fencepost h; rewrite ?assoc | + rewrite (stack_distr_pushout_r_bot f g h)] + | ?f ⊗ (?g ∘ ?h) => first [ + progress __fencepost f; rewrite ?assoc | + progress __fencepost g; rewrite ?assoc | + progress __fencepost h; rewrite ?assoc | + rewrite (stack_distr_pushout_r_top f g h)] + | ?f ⊗ ?g => first [ + progress __fencepost f; rewrite ?assoc | + progress __fencepost g; rewrite ?assoc | + rewrite <- (nwire_stackcompose_topright_general f g)] + (* | ?f ⊗ ?g => idtac f "⊗" g; match type of f with + | ?topIn ~> ?topOut => match type of g with + | ?botIn ~> ?botOut => rewrite <- (nwire_stackcompose_topright_general f g); + __fencepost (f ⊗ c_identity botIn); __fencepost (c_identity topOut ⊗ g) + end end *) + | ?f => idtac "should be clean:" f + end. + +Ltac weak_fencepost term := + let fence f := progress (weak_fencepost f) in + match term with + | id_ ?A => idtac + | ?f ∘ ?g => (* idtac f "∘" g; *) weak_fencepost f; weak_fencepost g + (* | ?f ⊗ id_ _ => fence f (* ; idtac "fix:"; print_state *) + | id_ _ ⊗ ?f => fence f (* ; idtac "fix:"; print_state *) *) + | (?f ∘ ?g) ⊗ (?h ∘ ?i) => first [ + (* first [fence f | fence g | fence h | fence i]; + rewrite ?assoc; weak_fencepost term | *) + rewrite (compose2_map f g h i); + weak_fencepost (f⊗h); weak_fencepost (g⊗i)] + | (?f ∘ ?g) ⊗ ?h => first [ + (* first [fence f | fence g | fence h]; + rewrite ?assoc; weak_fencepost term | *) + rewrite (stack_distr_pushout_r_bot f g h)] + | ?f ⊗ (?g ∘ ?h) => first [ + (* first [fence f | fence g | fence h]; + idtac term; rewrite ?assoc; idtac term; weak_fencepost term | *) + rewrite (stack_distr_pushout_r_top f g h)] + | ?f ⊗ ?g => first [ + (* first [fence f | fence g]; + rewrite ?assoc; weak_fencepost term | *) + rewrite <- (nwire_stackcompose_topright_general f g)] + (* | ?f ⊗ ?g => idtac f "⊗" g; match type of f with + | ?topIn ~> ?topOut => match type of g with + | ?botIn ~> ?botOut => rewrite <- (nwire_stackcompose_topright_general f g); + __fencepost (f ⊗ c_identity botIn); __fencepost (c_identity topOut ⊗ g) + end end *) + | ?f => idtac "should be clean:" f + end. + +Ltac weak_fencepost' term := + match term with + | id_ ?A => idtac + | ?f ∘ ?g => (* idtac f "∘" g; *) weak_fencepost f; weak_fencepost g + (* | ?f ⊗ id_ _ => fence f (* ; idtac "fix:"; print_state *) + | id_ _ ⊗ ?f => fence f (* ; idtac "fix:"; print_state *) *) + | (?f ∘ ?g) ⊗ (?h ∘ ?i) => first [ + (* first [fence f | fence g | fence h | fence i]; + rewrite ?assoc; weak_fencepost term | *) + rewrite (compose2_map f g h i); + weak_fencepost (f⊗h); weak_fencepost (g⊗i)] + | (?f ∘ ?g) ⊗ ?h => first [ + (* first [fence f | fence g | fence h]; + rewrite ?assoc; weak_fencepost term | *) + rewrite (stack_distr_pushout_r_bot f g h)] + | ?f ⊗ (?g ∘ ?h) => first [ + (* first [fence f | fence g | fence h]; + idtac term; rewrite ?assoc; idtac term; weak_fencepost term | *) + rewrite (stack_distr_pushout_r_top f g h)] + | ?f ⊗ ?g => first [ + (* first [fence f | fence g]; + rewrite ?assoc; weak_fencepost term | *) + rewrite <- (nwire_stackcompose_topright_general f g)] + (* | ?f ⊗ ?g => idtac f "⊗" g; match type of f with + | ?topIn ~> ?topOut => match type of g with + | ?botIn ~> ?botOut => rewrite <- (nwire_stackcompose_topright_general f g); + __fencepost (f ⊗ c_identity botIn); __fencepost (c_identity topOut ⊗ g) + end end *) + | ?f => idtac "should be clean:" f + end. + +Definition cast_fn {C : Type} `{Category C} (A B : C) {A' B' : C} (prfA : A = A') (prfB : B = B') (f : A' ~> B') : A ~> B. Proof. destruct prfA. @@ -56,9 +275,9 @@ Proof. Defined. Add Parametric Morphism {C : Type} `{cC : Category C} {A B : C} {A' B' : C} - {prfA : A = A'} {prfB : B = B'} : (@cast C cC A B A' B' prfA prfB) + {prfA : A = A'} {prfB : B = B'} : (@cast_fn C cC A B A' B' prfA prfB) with signature - (@Category.equiv C cC A' B') ==> (@Category.equiv C cC A B) as cast_equiv_morphism. + (@Category.equiv C cC A' B') ==> (@Category.equiv C cC A B) as cast_fn_equiv_morphism. Proof. intros. subst. diff --git a/examples/DirectSumMatrixExample.v b/examples/DirectSumMatrixExample.v new file mode 100644 index 0000000..1c4da5d --- /dev/null +++ b/examples/DirectSumMatrixExample.v @@ -0,0 +1,666 @@ +Require Import MatrixExampleBase. +Require Import MatrixPermBase. +From ViCaR Require Import ExamplesAutomation. + +#[export] Instance MxCategory : Category nat := { + morphism := Matrix; + + equiv := @mat_equiv; (* fun m n => @eq (Matrix m n); *) + equiv_rel := @mat_equiv_equivalence; + + compose := @Mmult; + compose_compat := fun n m o f g Hfg h i Hhi => + @Mmult_simplify_mat_equiv n m o f g h i Hfg Hhi; + assoc := @Mmult_assoc_mat_equiv; + + c_identity n := I n; + left_unit := Mmult_1_l_mat_eq; + right_unit := Mmult_1_r_mat_eq; +}. + +Lemma direct_sum'_id_mat_equiv : forall n m, + direct_sum' (I n) (I m) ≡ I (n + m). +Proof. + intros n m. + intros i j Hi Hj. + unfold direct_sum', I. + bdestructΩ'simp. +Qed. + +Lemma direct_sum'_id : forall n m, + direct_sum' (I n) (I m) = I (n + m). +Proof. + intros n m. + rewrite <- mat_equiv_eq_iff by auto with wf_db. + apply direct_sum'_id_mat_equiv. +Qed. + +Lemma big_sum_split : forall {G : Type} {H : Monoid G} (n k : nat) (f : nat -> G), + big_sum f (n + k) = + (big_sum f n + big_sum (fun x : nat => f (n + x)%nat) k)%G. +Proof. + intros G H n. + induction k; intros f. + - simpl; rewrite Nat.add_0_r, Gplus_0_r; easy. + - rewrite Nat.add_succ_r, <- big_sum_extend_r, IHk, <- Gplus_assoc. + reflexivity. +Qed. + +Lemma direct_sum'_Mmult : forall {n m o p q r} + (A : Matrix n m) (B : Matrix m o) (C : Matrix p q) (D : Matrix q r), + direct_sum' (A × B) (C × D) ≡ direct_sum' A C × direct_sum' B D. +Proof. + intros n m o p q r A B C D. + intros i j Hi Hj. + symmetry. + unfold direct_sum', Mmult. + bdestruct (i {| + forward := (I (n + m + o) : Matrix (n + m + o) (n + (m +o))); + reverse := (I (n + m + o) : Matrix (n + (m +o)) (n + m + o)); + id_A := ltac:(simpl; rewrite Nat.add_assoc, Mmult_1_r_mat_eq; easy); + id_B := ltac:(simpl; rewrite Nat.add_assoc, Mmult_1_r_mat_eq; easy); + |}; + + left_unitor := fun n => {| + forward := (I n : Matrix (0 + n) n); + reverse := (I n : Matrix n (0 + n)); + id_A := ltac:(simpl; rewrite Mmult_1_r_mat_eq; easy); + id_B := ltac:(simpl; rewrite Mmult_1_r_mat_eq; easy); + |}; + + right_unitor := fun n => {| + forward := (I n : Matrix (n + 0) n); + reverse := (I n : Matrix n (n + 0)); + id_A := ltac:(rewrite Nat.add_0_r, Mmult_1_r_mat_eq; easy); + id_B := ltac:(rewrite Nat.add_0_r, Mmult_1_r_mat_eq; easy); + |}; + + associator_cohere := ltac:(intros; simpl in *; + rewrite 2!Nat.add_assoc, Mmult_1_l_mat_eq, + Mmult_1_r_mat_eq, <-2!Nat.add_assoc; rewrite (direct_sum'_assoc f g h); easy); + left_unitor_cohere := ltac:(intros; simpl; + rewrite direct_sum'_0_l, Mmult_1_l_mat_eq, Mmult_1_r_mat_eq; easy); + right_unitor_cohere := ltac:(intros; simpl; rewrite direct_sum'_0_r; + rewrite 2!Nat.add_0_r, Mmult_1_r_mat_eq, Mmult_1_l_mat_eq; easy); + + pentagon := ltac:(intros; simpl in *; restore_dims; + rewrite 4!Nat.add_assoc, Mmult_1_r_mat_eq, + 2!direct_sum'_id, 2!Mmult_1_l_mat_eq, 2!Nat.add_assoc; easy); + triangle := ltac:(intros; simpl in *; + rewrite Nat.add_0_r, direct_sum'_id, Mmult_1_r_mat_eq; easy); +}. + + + + +Notation mx_braiding := (fun n m => (perm_mat (n+m) (rotr (n+m) n) : Matrix (n+m) (m+n))%nat). + +Lemma mx_braiding_compose_inv : forall n m, + (mx_braiding n m) × (mx_braiding m n) ≡ I (n + m). +Proof. + intros n m. + simpl. + rewrite (Nat.add_comm m n). + rewrite perm_mat_Mmult by auto with perm_db. + cleanup_perm. + rewrite perm_mat_I by easy. + easy. +Qed. + + + + + +Lemma mx_braiding_braids_eq : forall n m o p (A : Matrix n m) (B : Matrix o p), + (direct_sum' A B × perm_mat (m + p) (rotr (m + p) m) + = perm_mat (n + o) (rotr (n + o) n) × direct_sum' B A). +Proof. + intros n m o p A B. + apply equal_on_basis_vectors_implies_equal; + [|rewrite Nat.add_comm, (Nat.add_comm m p) |]; auto with wf_db. + intros k Hk. + rewrite <- mat_equiv_eq_iff; + [| | apply WF_mult; [apply WF_mult|]]; auto with wf_db; + [|rewrite (Nat.add_comm m p), (Nat.add_comm n o); auto with wf_db]. + rewrite Mmult_assoc. + rewrite perm_mat_permutes_basis_vectors_r, basis_vector_eq_e_i by easy. + rewrite <- (matrix_by_basis _ _ Hk). + rewrite <- matrix_by_basis. + 2: { (* TODO: replace with 'apply permutation_is_bounded; auto with perm_db' *) + assert (Hp: (permutation (m+p) (rotr (m + p) m))%nat) by auto with perm_db. + destruct Hp as [finv Hfinv]. + destruct (Hfinv k Hk); easy. + } + intros x y Hx Hy; replace y with O by lia; clear y Hy. + unfold get_vec. + rewrite Nat.eqb_refl. + rewrite perm_mat_permutes_matrix_r by auto with perm_db. + rewrite perm_inv_of_rotr by easy. + rewrite rotl_eq_rotr_sub. + bdestruct (o =? 0); + [subst o; + rewrite Nat.add_0_r, Nat.mod_same, Nat.sub_0_r, rotr_n by lia| + rewrite Nat.mod_small by lia; + replace (n + o - n)%nat with o by lia]; + unfold direct_sum', rotr, rotl; simpl; + rewrite 1?Nat.add_0_r; + + [ replace_bool_lia (x Build_Isomorphism nat MxCategory (n+m)%nat (m+n)%nat + (mx_braiding n m) (mx_braiding m n) + (mx_braiding_compose_inv n m) (mx_braiding_compose_inv m n). + +#[export] Instance MxBraidingBiIsomorphism : + NaturalBiIsomorphism MxDirectSumBiFunctor (CommuteBifunctor MxDirectSumBiFunctor) := {| + component2_iso := MxBraidingIsomorphism; + component2_iso_natural := ltac:(intros; simpl in *; + restore_dims; rewrite mx_braiding_braids_eq; easy); +|}. + +Lemma if_mult_dist_r (b : bool) (z : C) : + (if b then C1 else C0) * z = + if b then z else C0. +Proof. + destruct b; lca. +Qed. + +Lemma if_mult_dist_l (b : bool) (z : C) : + z * (if b then C1 else C0) = + if b then z else C0. +Proof. + destruct b; lca. +Qed. + +Lemma if_mult_dist_r_gen (b : bool) (z x y : C) : + (if b then x else y) * z = + if b then x*z else y*z. +Proof. + destruct b; lca. +Qed. + +Lemma if_mult_dist_l_gen (b : bool) (z x y : C) : + z * (if b then x else y) = + if b then z*x else z*y. +Proof. + destruct b; lca. +Qed. + +Lemma if_mult_and (b c : bool) : + (if b then C1 else C0) * (if c then C1 else C0) = + if (b && c) then C1 else C0. +Proof. + destruct b; destruct c; lca. +Qed. + +Lemma if_if_if_combine {A : Type} : forall (x y : A) (b c d:bool), + (if b then if c then x else y + else if d then x else y) = + if (b&&c)||((¬b) && d) then x else y. +Proof. + intros. + bdestructΩ'. +Qed. + +(* Definition unshift {A} (f : nat -> A) (k : nat) (x : A) : nat -> A := + fun i => if i if (O if (i permutation m g -> + direct_sum' (perm_mat n f) (perm_mat m g) ≡ perm_mat (n+m) (stack_perms n m f g). +Proof. + intros n m f g Hf Hg. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + rewrite perm_mat_permutes_ei_r, direct_sum'_mul_vec_r by lia. + unfold make_WF, Mplus. + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + simpl_bools. + unfold make_WF, perm_mat, stack_perms, unshift_mx, shift, e_i, Mmult. + replace_bool_lia (k apply big_sum_unique + | |- _ = C0 => rewrite big_sum_0_bounded; [easy|]; intros; bdestructΩ'simp + end. + 4: { + rewrite Nat.add_sub in *; + lia. + } + - exists k; repeat split; try lia; intros; bdestructΩ'simp. + - destruct Hf as [? Hf]. + specialize (Hf k). + lia. + - exists (k - n)%nat; repeat split; intros; bdestructΩ'simp. +Qed. + +Lemma perm_mat_idn : forall n, + perm_mat n idn = I n. +Proof. + intros n. + apply perm_mat_I; easy. +Qed. + +Lemma direct_sum'_stack_perm_I_r : forall n m f, + permutation n f -> + direct_sum' (perm_mat n f) (I m) ≡ perm_mat (n+m) (stack_perms n m f idn). +Proof. + intros n m f Hf. + rewrite <- perm_mat_idn. + rewrite direct_sum'_stack_perms; (auto with perm_db). + easy. +Qed. + +Lemma direct_sum'_stack_perm_I_l : forall n m f, + permutation m f -> + direct_sum' (I n) (perm_mat m f) ≡ perm_mat (n+m) (stack_perms n m idn f). +Proof. + intros n m f Hf. + rewrite <- perm_mat_idn. + rewrite direct_sum'_stack_perms; (auto with perm_db). + easy. +Qed. + +Lemma mx_braiding_hexagon1: forall n m o (* M B A *), + ((direct_sum' (mx_braiding n m) (I o) × I (m + n + o) + × direct_sum' (I m) (mx_braiding n o))) + ≡ (I (n + m + o) × (mx_braiding n (m + o)%nat) × I (m + o + n)). +Proof. + intros n m o. + (* replace (n + (o+m))%nat with (m+o+n)%nat by lia. *) + rewrite 2!Mmult_1_r_mat_eq, Mmult_1_l_mat_eq. + rewrite (Nat.add_comm m n), (Nat.add_comm o n). + rewrite direct_sum'_stack_perm_I_r by auto with perm_db. + replace (n+m+o)%nat with (m+(n+o))%nat by lia. + replace (m+o+n)%nat with (m+(n+o))%nat by lia. + rewrite direct_sum'_stack_perm_I_l by auto with perm_db. + rewrite perm_mat_Mmult by auto with perm_db. + apply perm_mat_equiv_of_perm_eq'; [lia|]. + intros k Hk. + unfold stack_perms, rotr, Basics.compose; + replace_bool_lia (k if (k =? j) then _ else _); shelve. + bdestruct (k =? j). + - replace j with k by easy. + rewrite Nat.eqb_refl. + unshelve (instantiate (1:=_)). + refine (if (k if (k (f i) ⊤) p. +Proof. + intros. + rewrite (big_sum_func_distr f transpose); easy. +Qed. + + + +Definition kron_comm p q : Matrix (p*q) (p*q):= + @make_WF (p*q) (p*q) (fun s t => + (* have blocks H_ij, p by q of them, and each is q by p *) + let i := (s / q)%nat in let j := (t / p)%nat in + let k := (s mod q)%nat in let l := (t mod p) in + (* let k := (s - q * i)%nat in let l := (t - p * t)%nat in *) + if (i =? l) && (j =? k) then C1 else C0 + (* s/q =? t mod p /\ t/p =? s mod q *) +). + +Lemma WF_kron_comm p q : WF_Matrix (kron_comm p q). +Proof. unfold kron_comm; auto with wf_db. Qed. +#[export] Hint Resolve WF_kron_comm : wf_db. + +(* Lemma test_kron : kron_comm 2 3 = Matrix.Zero. +Proof. + apply mat_equiv_eq; unfold kron_comm; auto with wf_db. + print_LHS_matU. +*) + +Lemma kron_comm_transpose_mat_equiv : forall p q, + (kron_comm p q) ⊤ ≡ kron_comm q p. +Proof. + intros p q. + intros i j Hi Hj. + unfold kron_comm, transpose, make_WF. + rewrite andb_comm, Nat.mul_comm. + rewrite (andb_comm (_ =? _)). + easy. +Qed. + +Lemma kron_comm_transpose : forall p q, + (kron_comm p q) ⊤ = kron_comm q p. +Proof. + intros p q. + apply mat_equiv_eq; auto with wf_db. + 1: rewrite Nat.mul_comm; apply WF_kron_comm. + apply kron_comm_transpose_mat_equiv. +Qed. + +Lemma kron_comm_1_r_mat_equiv : forall p, + (kron_comm p 1) ≡ Matrix.I p. +Proof. + intros p. + intros s t Hs Ht. + unfold kron_comm. + unfold make_WF. + unfold Matrix.I. + rewrite Nat.mul_1_r, Nat.div_1_r, Nat.mod_1_r, Nat.div_small, Nat.mod_small by lia. + bdestructΩ'. +Qed. + +Lemma kron_comm_1_r : forall p, + (kron_comm p 1) = Matrix.I p. +Proof. + intros p. + apply mat_equiv_eq; [|rewrite 1?Nat.mul_1_r|]; auto with wf_db. + apply kron_comm_1_r_mat_equiv. +Qed. + +Lemma kron_comm_1_l_mat_equiv : forall p, + (kron_comm 1 p) ≡ Matrix.I p. +Proof. + intros p. + intros s t Hs Ht. + unfold kron_comm. + unfold make_WF. + unfold Matrix.I. + rewrite Nat.mul_1_l, Nat.div_1_r, Nat.mod_1_r, Nat.div_small, Nat.mod_small by lia. + bdestructΩ'. +Qed. + +Lemma kron_comm_1_l : forall p, + (kron_comm 1 p) = Matrix.I p. +Proof. + intros p. + apply mat_equiv_eq; [|rewrite 1?Nat.mul_1_l|]; auto with wf_db. + apply kron_comm_1_l_mat_equiv. +Qed. + +Definition mx_to_vec {n m} (A : Matrix n m) : Vector (n*m) := + make_WF (fun i j => A (i mod n)%nat (i / n)%nat + (* Note: goes columnwise. Rowwise would be: + make_WF (fun i j => A (i / m)%nat (i mod n)%nat + *) +). + +Lemma WF_mx_to_vec {n m} (A : Matrix n m) : WF_Matrix (mx_to_vec A). +Proof. unfold mx_to_vec; auto with wf_db. Qed. +#[export] Hint Resolve WF_mx_to_vec : wf_db. + +(* Compute vec_to_list (mx_to_vec (Matrix.I 2)). *) +From Coq Require Import ZArith. +Ltac Zify.zify_post_hook ::= PreOmega.Z.div_mod_to_equations. + +Lemma kron_comm_mx_to_vec_helper : forall i p q, (i < p * q)%nat -> + (p * (i mod q) + i / q < p * q)%nat. +Proof. + intros i p q. + intros Hi. + assert (i / q < p)%nat by (apply Nat.div_lt_upper_bound; lia). + destruct p; [lia|]; + destruct q; [lia|]. + enough (S p * (i mod S q) <= S p * q)%nat by lia. + apply Nat.mul_le_mono; [lia | ]. + pose proof (Nat.mod_upper_bound i (S q) ltac:(easy)). + lia. +Qed. + +Lemma mx_to_vec_additive_mat_equiv {n m} (A B : Matrix n m) : + mx_to_vec (A .+ B) ≡ mx_to_vec A .+ mx_to_vec B. +Proof. + intros i j Hi Hj. + replace j with O by lia; clear dependent j. + unfold mx_to_vec, make_WF, Mplus. + bdestructΩ'. +Qed. + +Lemma mx_to_vec_additive {n m} (A B : Matrix n m) : + mx_to_vec (A .+ B) = mx_to_vec A .+ mx_to_vec B. +Proof. + apply mat_equiv_eq; auto with wf_db. + apply mx_to_vec_additive_mat_equiv. +Qed. + +Lemma if_mult_dist_r (b : bool) (z : C) : + (if b then C1 else C0) * z = + if b then z else C0. +Proof. + destruct b; lca. +Qed. + +Lemma if_mult_dist_l (b : bool) (z : C) : + z * (if b then C1 else C0) = + if b then z else C0. +Proof. + destruct b; lca. +Qed. + +Lemma if_mult_and (b c : bool) : + (if b then C1 else C0) * (if c then C1 else C0) = + if (b && c) then C1 else C0. +Proof. + destruct b; destruct c; lca. +Qed. + +Lemma kron_comm_mx_to_vec_mat_equiv : forall p q (A : Matrix p q), + kron_comm p q × mx_to_vec A ≡ mx_to_vec (A ⊤). +Proof. + intros p q A. + intros i j Hi Hj. + replace j with O by lia; clear dependent j. + unfold transpose, mx_to_vec, kron_comm, make_WF, Mmult. + rewrite (Nat.mul_comm q p). + replace_bool_lia (i . + destruct p; [lia|]. + destruct q; [lia|]. + split. + + rewrite Nat.add_comm, Nat.mul_comm. + rewrite Nat.mod_add by easy. + rewrite Nat.mod_small; [lia|]. + apply Nat.div_lt_upper_bound; lia. + + rewrite Nat.mul_comm, Nat.div_add_l by easy. + rewrite Nat.div_small; [lia|]. + apply Nat.div_lt_upper_bound; lia. + - intros [Hmodp Hdivp]. + rewrite (Nat.div_mod_eq k p). + lia. + } + apply big_sum_unique. + exists (p * (i mod q) + i / q)%nat; repeat split; + [apply kron_comm_mx_to_vec_helper; easy | rewrite Nat.eqb_refl | intros; bdestructΩ'simp]. + destruct p; [lia|]; + destruct q; [lia|]. + f_equal. + - rewrite Nat.add_comm, Nat.mul_comm, Nat.mod_add, Nat.mod_small; try easy. + apply Nat.div_lt_upper_bound; lia. + - rewrite Nat.mul_comm, Nat.div_add_l by easy. + rewrite Nat.div_small; [lia|]. + apply Nat.div_lt_upper_bound; lia. +Qed. + +Lemma kron_comm_mx_to_vec : forall p q (A : Matrix p q), + kron_comm p q × mx_to_vec A = mx_to_vec (A ⊤). +Proof. + intros p q A. + apply mat_equiv_eq; [|rewrite Nat.mul_comm|]; auto with wf_db. + apply kron_comm_mx_to_vec_mat_equiv. +Qed. + +Lemma kron_comm_ei_kron_ei_sum_mat_equiv : forall p q, + kron_comm p q ≡ + big_sum (G:=Square (p*q)) (fun i => big_sum (G:=Square (p*q)) (fun j => + (@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤)) + q) p. +Proof. + intros p q. + intros i j Hi Hj. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros l Hl. + unfold Mmult, kron, transpose, e_i. + erewrite big_sum_eq_bounded. + 2: { + intros m Hm. + (* replace m with O by lia. *) + rewrite Nat.div_1_r, Nat.mod_1_r. + replace_bool_lia (m =? 0) true; rewrite 4!andb_true_r. + rewrite 3!if_mult_and. + match goal with + |- context[if ?b then _ else _] => + replace b with ((i =? k * q + l) && (j =? l * p + k)) + end. + 1: reflexivity. (* set our new function *) + clear dependent m. + rewrite eq_iff_eq_true, 8!andb_true_iff, + 6!Nat.eqb_eq, 4!Nat.ltb_lt. + split. + - intros [Hieq Hjeq]. + subst i j. + rewrite 2!Nat.div_add_l, Nat.div_small, Nat.add_0_r by lia. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small, + Nat.div_small, Nat.add_0_r by lia. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by lia. + easy. + - intros [[[] []] [[] []]]. + split. + + rewrite (Nat.div_mod_eq i q) by lia; lia. + + rewrite (Nat.div_mod_eq j p) by lia; lia. + } + simpl; rewrite Cplus_0_l. + reflexivity. + } + apply big_sum_unique. + exists (i mod q). + split; [|split]. + - apply Nat.mod_upper_bound; lia. + - reflexivity. + - intros l Hl Hnmod. + bdestructΩ'simp. + exfalso; apply Hnmod. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by lia; lia. + } + symmetry. + apply big_sum_unique. + exists (j mod p). + repeat split. + - apply Nat.mod_upper_bound; lia. + - unfold kron_comm, make_WF. + replace_bool_lia (i + enough (H: b = c) by (rewrite H; easy) + end. + rewrite eq_iff_eq_true, 2!andb_true_iff, 4!Nat.eqb_eq. + split. + + intros [Hieq Hjeq]. + split; [rewrite Hieq | rewrite Hjeq]; + rewrite Hieq, Nat.div_add_l by lia; + (rewrite Nat.div_small; [lia|]); + apply Nat.mod_upper_bound; lia. + + intros [Hidiv Hjdiv]. + rewrite (Nat.div_mod_eq i q) at 1 by lia. + rewrite (Nat.div_mod_eq j p) at 2 by lia. + lia. + - intros k Hk Hkmod. + bdestructΩ'simp. + exfalso; apply Hkmod. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by lia; lia. +Qed. + +Lemma kron_comm_ei_kron_ei_sum : forall p q, + kron_comm p q = + big_sum (G:=Square (p*q)) (fun i => big_sum (G:=Square (p*q)) (fun j => + (@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤)) + q) p. +Proof. + intros p q. + apply mat_equiv_eq; auto with wf_db. + 1: apply WF_Msum; intros; apply WF_Msum; intros; + rewrite Nat.mul_comm; apply WF_mult; + auto with wf_db; rewrite Nat.mul_comm; + auto with wf_db. + apply kron_comm_ei_kron_ei_sum_mat_equiv. +Qed. + +Lemma kron_comm_ei_kron_ei_sum'_mat_equiv : forall p q, + kron_comm p q ≡ + big_sum (G:=Square (p*q)) (fun ij => + let i := (ij / q)%nat in let j := (ij mod q) in + ((@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤))) (p*q). +Proof. + intros p q. + rewrite kron_comm_ei_kron_ei_sum, big_sum_double_sum, Nat.mul_comm. + reflexivity. +Qed. + +(* TODO: put somewhere sensible *) +Lemma big_sum_mat_equiv_bounded : forall {o p} (f g : nat -> Matrix o p) (n : nat), + (forall x : nat, (x < n)%nat -> f x ≡ g x) -> big_sum f n ≡ big_sum g n. +Proof. + intros. + induction n. + - easy. + - simpl. + rewrite IHn, H; [easy|lia|auto]. +Qed. + +Lemma kron_comm_Hij_sum_mat_equiv : forall p q, + kron_comm p q ≡ + big_sum (G:=Square (p*q)) (fun i => big_sum (G:=Square (p*q)) (fun j => + @kron p q q p (@e_i p i × ((@e_i q j) ⊤)) + ((@Mmult p 1 q (@e_i p i) (((@e_i q j) ⊤))) ⊤)) q) p. +Proof. + intros p q. + rewrite kron_comm_ei_kron_ei_sum_mat_equiv. + apply big_sum_mat_equiv_bounded; intros i Hi. + apply big_sum_mat_equiv_bounded; intros j Hj. + rewrite kron_transpose, kron_mixed_product. + rewrite Mmult_transpose, transpose_involutive. + easy. +Qed. + +Lemma kron_comm_Hij_sum : forall p q, + kron_comm p q = + big_sum (G:=Square (p*q)) (fun i => big_sum (G:=Square (p*q)) (fun j => + @kron p q q p (@e_i p i × ((@e_i q j) ⊤)) + ((@Mmult p 1 q (@e_i p i) (((@e_i q j) ⊤))) ⊤)) q) p. +Proof. + intros p q. + apply mat_equiv_eq; [auto with wf_db| | ]. + - apply WF_Msum; intros i Hi. + apply WF_Msum; intros j Hj. + apply WF_kron; try lia; + [| apply WF_transpose]; + auto with wf_db. + - apply kron_comm_Hij_sum_mat_equiv. +Qed. + + +Lemma kron_comm_ei_kron_ei_sum' : forall p q, + kron_comm p q = + big_sum (G:=Square (p*q)) (fun ij => + let i := (ij / q)%nat in let j := (ij mod q) in + ((@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤))) (p*q). +Proof. + intros p q. + rewrite kron_comm_ei_kron_ei_sum, big_sum_double_sum, Nat.mul_comm. + reflexivity. +Qed. + +Local Notation H := (fun i j => e_i i × (e_i j)⊤). + +Lemma kron_comm_Hij_sum'_mat_equiv : forall p q, + kron_comm p q ≡ + big_sum (G:=Square (p*q)) ( fun ij => + let i := (ij / q)%nat in let j := (ij mod q) in + @kron p q q p (H i j) + ((H i j) ⊤)) (p*q). +Proof. + intros p q. + rewrite kron_comm_Hij_sum_mat_equiv, big_sum_double_sum, Nat.mul_comm. + easy. +Qed. + +Lemma kron_comm_Hij_sum' : forall p q, + kron_comm p q = + big_sum (G:=Square (p*q)) ( fun ij => + let i := (ij / q)%nat in let j := (ij mod q) in + @kron p q q p (H i j) + ((H i j) ⊤)) (p*q). +Proof. + intros p q. + rewrite kron_comm_Hij_sum, big_sum_double_sum, Nat.mul_comm. + easy. +Qed. + + +Lemma div_eq_iff : forall a b c, b <> O -> + (a / b)%nat = c <-> (b * c <= a /\ a < b * (S c))%nat. +Proof. + intros a b c Hb. + split. + intros Hadivb. + split; + subst c. + etransitivity; [ + apply Nat.div_mul_le, Hb |]. + rewrite Nat.mul_comm, Nat.div_mul; easy. + apply Nat.mul_succ_div_gt, Hb. + intros [Hge Hlt]. + symmetry. + apply (Nat.div_unique _ _ _ (a - b*c)); [lia|]. + lia. +Qed. + +Lemma kron_e_i_transpose_l : forall k n m o (A : Matrix m o), (k < n)%nat -> + (o <> O) -> (m <> O) -> + (@e_i n k)⊤ ⊗ A = (fun i j => + if (i + (o <> O) -> (m <> O) -> + (@e_i n k)⊤ ⊗ A ≡ (fun i j => + if (i + (@e_i n k)⊤ ⊗ A ≡ (fun i j => + if (i + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A = (fun i j => + if (j + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A ≡ (fun i j => + if (j + (@e_i n k) ⊗ A ≡ (fun i j => + if (j + (o <> O) -> (m <> O) -> + (@e_i n k)⊤ ⊗ A = (fun i j => + if (i ((j/o)%nat=k)) by lia; + rewrite Hrw; clear Hrw. + symmetry. + rewrite div_eq_iff by lia. + lia. + - replace (i / m =? 0) with false. + rewrite andb_false_r; easy. + symmetry. + rewrite Nat.eqb_neq. + rewrite Nat.div_small_iff; lia. +Qed. + +Lemma kron_e_i_transpose_l'_mat_equiv : forall k n m o (A : Matrix m o), (k < n)%nat -> + (o <> O) -> (m <> O) -> + (@e_i n k)⊤ ⊗ A ≡ (fun i j => + if (i + (@e_i n k)⊤ ⊗ A ≡ (fun i j => + if (i + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A = (fun i j => + if (j ((i/m)%nat=k)) by lia; + rewrite Hrw; clear Hrw. + symmetry. + rewrite div_eq_iff by lia. + lia. + - replace (j / o =? 0) with false. + rewrite andb_false_r; easy. + symmetry. + rewrite Nat.eqb_neq. + rewrite Nat.div_small_iff; lia. +Qed. + +Lemma kron_e_i_l'_mat_equiv : forall k n m o (A : Matrix m o), (k < n)%nat -> + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A ≡ (fun i j => + if (j + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A ≡ (fun i j => + if (j + (o <> O) -> (m <> O) -> + A ⊗ (@e_i n k) = (fun i j => + if (i mod n =? k) then A (i / n)%nat j else 0). +Proof. + intros k n m o A Hk Ho Hm. + apply functional_extensionality; intros i; + apply functional_extensionality; intros j. + unfold kron, e_i. + rewrite if_mult_dist_l, Nat.div_1_r. + rewrite Nat.mod_1_r, Nat.eqb_refl, andb_true_r. + replace (i mod n + (o <> O) -> (m <> O) -> + A ⊗ (@e_i n k) ≡ (fun i j => + if (i mod n =? k) then A (i / n)%nat j else 0). +Proof. + intros. + rewrite kron_e_i_r; easy. +Qed. + +Lemma kron_e_i_r_mat_equiv' : forall k n m o (A : Matrix m o), (k < n)%nat -> + A ⊗ (@e_i n k) ≡ (fun i j => + if (i mod n =? k) then A (i / n)%nat j else 0). +Proof. + intros. + destruct m; [|destruct o]; + try (intros i j Hi Hj; lia). + rewrite kron_e_i_r; easy. +Qed. + +Lemma kron_e_i_transpose_r : forall k n m o (A : Matrix m o), (k < n)%nat -> + (o <> O) -> (m <> O) -> + A ⊗ (@e_i n k) ⊤ = (fun i j => + if (j mod n =? k) then A i (j / n)%nat else 0). +Proof. + intros k n m o A Hk Ho Hm. + apply functional_extensionality; intros i; + apply functional_extensionality; intros j. + unfold kron, transpose, e_i. + rewrite if_mult_dist_l, Nat.div_1_r. + rewrite Nat.mod_1_r, Nat.eqb_refl, andb_true_r. + replace (j mod n + (o <> O) -> (m <> O) -> + A ⊗ (@e_i n k) ⊤ ≡ (fun i j => + if (j mod n =? k) then A i (j / n)%nat else 0). +Proof. + intros. + rewrite kron_e_i_transpose_r; easy. +Qed. + +Lemma kron_e_i_transpose_r_mat_equiv' : forall k n m o (A : Matrix m o), (k < n)%nat -> + A ⊗ (@e_i n k) ⊤ ≡ (fun i j => + if (j mod n =? k) then A i (j / n)%nat else 0). +Proof. + intros. + destruct m; [|destruct o]; + try (intros i j Hi Hj; lia). + rewrite kron_e_i_transpose_r; easy. +Qed. + +Lemma ei_kron_I_kron_ei : forall m n k, (k < n)%nat -> m <> O -> + (@e_i n k) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n k) = + (fun i j => if (i mod n =? k) && (j / m =? k)%nat + && (i / n =? j - k * m) && (i / n m <> O -> + (@e_i n k) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n k) ≡ + (fun i j => if (i mod n =? k) && (j / m =? k)%nat + && (i / n =? j - k * m) && (i / n + (@e_i n k) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n k) ≡ + (fun i j => if (i mod n =? k) && (j / m =? k)%nat + && (i / n =? j - k * m) && (i / n + (@e_i n j) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n j)) n. +Proof. + intros m n. + intros i j Hi Hj. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros ij Hij. + rewrite ei_kron_I_kron_ei by lia. + reflexivity. + } + unfold kron_comm, make_WF. + replace_bool_lia (i + (@e_i n j) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n j)) n. +Proof. + intros m n. + apply mat_equiv_eq; auto with wf_db. + 1: apply WF_Msum; intros; apply WF_kron; auto with wf_db arith. + apply kron_comm_kron_form_sum_mat_equiv; easy. +Qed. + +Lemma kron_comm_kron_form_sum' : forall m n, + kron_comm m n = big_sum (G:=Square (m*n)) (fun i => + (@e_i m i) ⊗ (Matrix.I n) ⊗ (@e_i m i)⊤) m. +Proof. + intros. + rewrite <- (kron_comm_transpose n m). + rewrite (kron_comm_kron_form_sum n m). + rewrite Msum_transpose. + apply big_sum_eq_bounded. + intros k Hk. + rewrite Nat.mul_1_l. + pose proof (kron_transpose _ _ _ _ ((@e_i m k) ⊤ ⊗ Matrix.I n) (@e_i m k)) as H; + rewrite Nat.mul_1_l, Nat.mul_1_r in H; + rewrite (Nat.mul_comm n m), H in *; clear H. + pose proof (kron_transpose _ _ _ _ ((@e_i m k) ⊤) (Matrix.I n)) as H; + rewrite Nat.mul_1_l in H; + rewrite H; clear H. + rewrite transpose_involutive, id_transpose_eq; easy. +Qed. + +Lemma kron_comm_kron_form_sum'_mat_equiv : forall m n, + kron_comm m n ≡ big_sum (G:=Square (m*n)) (fun i => + (@e_i m i) ⊗ (Matrix.I n) ⊗ (@e_i m i)⊤) m. +Proof. + intros. + rewrite kron_comm_kron_form_sum'; easy. +Qed. + +Lemma e_i_dot_is_component_mat_equiv : forall p k (x : Vector p), + (k < p)%nat -> + (@e_i p k) ⊤ × x ≡ x k O .* Matrix.I 1. +Proof. + intros p k x Hk. + intros i j Hi Hj; + replace i with O by lia; + replace j with O by lia; + clear i Hi; + clear j Hj. + unfold Mmult, transpose, scale, e_i, Matrix.I. + simpl_bools. + rewrite Cmult_1_r. + apply big_sum_unique. + exists k. + split; [easy|]. + bdestructΩ'simp. + rewrite Cmult_1_l. + split; [easy|]. + intros l Hl Hkl. + bdestructΩ'simp. +Qed. + +Lemma e_i_dot_is_component : forall p k (x : Vector p), + (k < p)%nat -> WF_Matrix x -> + (@e_i p k) ⊤ × x = x k O .* Matrix.I 1. +Proof. + intros p k x Hk HWF. + apply mat_equiv_eq; auto with wf_db. + apply e_i_dot_is_component_mat_equiv; easy. +Qed. + +Lemma kron_e_i_e_i : forall p q k l, + (k < p)%nat -> (l < q)%nat -> + @e_i q l ⊗ @e_i p k = @e_i (p*q) (l*p + k). +Proof. + intros p q k l Hk Hl. + apply functional_extensionality; intro i. + apply functional_extensionality; intro j. + unfold kron, e_i. + rewrite Nat.mod_1_r, Nat.div_1_r. + rewrite if_mult_and. + lazymatch goal with + |- (if ?b then _ else _) = (if ?c then _ else _) => + enough (H : b = c) by (rewrite H; easy) + end. + rewrite Nat.eqb_refl, andb_true_r. + destruct (j =? 0); [|rewrite 2!andb_false_r; easy]. + rewrite 2!andb_true_r. + rewrite eq_iff_eq_true, 4!andb_true_iff, 3!Nat.eqb_eq, 3!Nat.ltb_lt. + split. + - intros [[] []]. + rewrite (Nat.div_mod_eq i p). + split; nia. + - intros []. + subst i. + rewrite Nat.div_add_l, Nat.div_small, Nat.add_0_r, + Nat.add_comm, Nat.mod_add, Nat.mod_small by lia. + easy. +Qed. + +Lemma kron_e_i_e_i_mat_equiv : forall p q k l, + (k < p)%nat -> (l < q)%nat -> + @e_i q l ⊗ @e_i p k ≡ @e_i (p*q) (l*p + k). +Proof. + intros p q k l; intros. + rewrite (kron_e_i_e_i p q); easy. +Qed. + +Lemma kron_eq_sum_mat_equiv : forall p q (x : Vector q) (y : Vector p), + y ⊗ x ≡ big_sum (fun ij => + let i := (ij / q)%nat in let j := ij mod q in + (x j O * y i O) .* (@e_i p i ⊗ @e_i q j)) (p * q). +Proof. + intros p q x y. + erewrite big_sum_eq_bounded. + 2: { + intros ij Hij. + simpl. + rewrite (@kron_e_i_e_i q p) by + (try apply Nat.mod_upper_bound; try apply Nat.div_lt_upper_bound; lia). + rewrite (Nat.mul_comm (ij / q) q). + rewrite <- (Nat.div_mod_eq ij q). + reflexivity. + } + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + simpl. + rewrite Msum_Csum. + symmetry. + apply big_sum_unique. + exists i. + split; [lia|]. + unfold e_i; split. + - unfold scale, kron; bdestructΩ'simp. + - intros j Hj Hij. + unfold scale, kron; bdestructΩ'simp. +Qed. + +Lemma kron_eq_sum : forall p q (x : Vector q) (y : Vector p), + WF_Matrix x -> WF_Matrix y -> + y ⊗ x = big_sum (fun ij => + let i := (ij / q)%nat in let j := ij mod q in + (x j O * y i O) .* (@e_i p i ⊗ @e_i q j)) (p * q). +Proof. + intros p q x y Hwfx Hwfy. + apply mat_equiv_eq; [| |]; auto with wf_db. + apply kron_eq_sum_mat_equiv. +Qed. + +Lemma kron_comm_commutes_vectors_l_mat_equiv : forall p q (x : Vector q) (y : Vector p), + kron_comm p q × (x ⊗ y) ≡ (y ⊗ x). +Proof. + intros p q x y. + rewrite kron_comm_ei_kron_ei_sum'_mat_equiv, Mmult_Msum_distr_r. + rewrite (big_sum_mat_equiv_bounded _ + (fun k => x (k mod q) 0 * y (k / q) 0 .* (e_i (k / q) ⊗ e_i (k mod q)))%nat); + [rewrite <- kron_eq_sum_mat_equiv; easy|]. + intros k Hk. + simpl. + match goal with + |- (?A × ?B) × ?C ≡ _ => + assert (Hassoc: (A × B) × C = A × (B × C)) by apply Mmult_assoc + end. + simpl in Hassoc. + rewrite (Nat.mul_comm q p) in *. + rewrite Hassoc. clear Hassoc. + pose proof (kron_transpose _ _ _ _ (@e_i q (k mod q)) (@e_i p (k / q))) as Hrw; + rewrite (Nat.mul_comm q p) in Hrw; + simpl in Hrw; rewrite Hrw; clear Hrw. + pose proof (kron_mixed_product ((e_i (k mod q)) ⊤) ((e_i (k / q)) ⊤) x y) as Hrw; + rewrite (Nat.mul_comm q p) in Hrw; + simpl in Hrw; rewrite Hrw; clear Hrw. + rewrite 2!(e_i_dot_is_component_mat_equiv); + [ | apply Nat.div_lt_upper_bound; lia | + apply Nat.mod_upper_bound; lia]. + rewrite Mscale_kron_dist_l, Mscale_kron_dist_r, Mscale_assoc. + rewrite kron_1_l, Mscale_mult_dist_r, Mmult_1_r by auto with wf_db. + reflexivity. +Qed. + +Lemma kron_comm_commutes_vectors_l : forall p q (x : Vector q) (y : Vector p), + WF_Matrix x -> WF_Matrix y -> + kron_comm p q × (x ⊗ y) = (y ⊗ x). +Proof. + intros p q x y Hwfx Hwfy. + apply mat_equiv_eq; [apply WF_mult; restore_dims| |]; auto with wf_db. + apply kron_comm_commutes_vectors_l_mat_equiv. +Qed. + +Lemma kron_basis_vector_basis_vector : forall p q k l, + (k < p)%nat -> (l < q)%nat -> + basis_vector q l ⊗ basis_vector p k = basis_vector (p*q) (l*p + k). +Proof. + intros p q k l Hk Hl. + apply functional_extensionality; intros i. + apply functional_extensionality; intros j. + unfold kron, basis_vector. + rewrite Nat.mod_1_r, Nat.div_1_r, Nat.eqb_refl, andb_true_r, if_mult_and. + pose proof (Nat.div_mod_eq i p). + bdestructΩ'simp. + rewrite Nat.div_add_l, Nat.div_small in * by lia. + lia. +Qed. + +Lemma kron_basis_vector_basis_vector_mat_equiv : forall p q k l, + (k < p)%nat -> (l < q)%nat -> + basis_vector q l ⊗ basis_vector p k ≡ basis_vector (p*q) (l*p + k). +Proof. + intros. + rewrite (kron_basis_vector_basis_vector p q); easy. +Qed. + +Lemma kron_extensionality_mat_equiv : forall n m s t (A B : Matrix (n*m) (s*t)), + (forall (x : Vector s) (y :Vector t), + A × (x ⊗ y) ≡ B × (x ⊗ y)) -> + A ≡ B. +Proof. + intros n m s t A B Hext. + apply mat_equiv_of_equiv_on_ei. + intros i Hi. + + pose proof (Nat.div_lt_upper_bound i t s ltac:(lia) ltac:(lia)). + pose proof (Nat.mod_upper_bound i s ltac:(lia)). + pose proof (Nat.mod_upper_bound i t ltac:(lia)). + + specialize (Hext (@e_i s (i / t)) (@e_i t (i mod t))). + rewrite (kron_e_i_e_i_mat_equiv t s) in Hext by lia. + (* simpl in Hext. *) + rewrite (Nat.mul_comm (i/t) t), <- (Nat.div_mod_eq i t) in Hext. + rewrite (Nat.mul_comm t s) in Hext. easy. +Qed. + +Lemma kron_extensionality : forall n m s t (A B : Matrix (n*m) (s*t)), + WF_Matrix A -> WF_Matrix B -> + (forall (x : Vector s) (y :Vector t), + WF_Matrix x -> WF_Matrix y -> + A × (x ⊗ y) = B × (x ⊗ y)) -> + A = B. +Proof. + intros n m s t A B HwfA HwfB Hext. + apply equal_on_basis_vectors_implies_equal; try easy. + intros i Hi. + + pose proof (Nat.div_lt_upper_bound i t s ltac:(lia) ltac:(lia)). + pose proof (Nat.mod_upper_bound i s ltac:(lia)). + pose proof (Nat.mod_upper_bound i t ltac:(lia)). + + specialize (Hext (basis_vector s (i / t)) (basis_vector t (i mod t)) + ltac:(apply basis_vector_WF; easy) + ltac:(apply basis_vector_WF; easy) + ). + rewrite (kron_basis_vector_basis_vector t s) in Hext by lia. + + simpl in Hext. + rewrite (Nat.mul_comm (i/t) t), <- (Nat.div_mod_eq i t) in Hext. + rewrite (Nat.mul_comm t s) in Hext. easy. +Qed. + +Lemma kron_comm_commutes_mat_equiv : forall n s m t + (A : Matrix n s) (B : Matrix m t), + kron_comm m n × (A ⊗ B) × (kron_comm s t) ≡ (B ⊗ A). +Proof. + intros n s m t A B. + rewrite (Nat.mul_comm s t). + apply (kron_extensionality_mat_equiv _ _ t s). + intros x y. + (* simpl. *) + (* Search "assoc" in Matrix. *) + rewrite (Mmult_assoc (_ × _)). + rewrite (Nat.mul_comm t s). + rewrite kron_comm_commutes_vectors_l_mat_equiv. + rewrite Mmult_assoc, (Nat.mul_comm m n). + rewrite kron_mixed_product. + rewrite (Nat.mul_comm n m), kron_comm_commutes_vectors_l_mat_equiv. + rewrite <- kron_mixed_product. + rewrite (Nat.mul_comm t s). + easy. +Qed. + +Lemma kron_comm_commutes : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + kron_comm m n × (A ⊗ B) × (kron_comm s t) = (B ⊗ A). +Proof. + intros n s m t A B HwfA HwfB. + apply (kron_extensionality _ _ t s); [| + apply WF_kron; try easy; lia |]. + rewrite (Nat.mul_comm t s); apply WF_mult; auto with wf_db; + apply WF_mult; auto with wf_db; + rewrite Nat.mul_comm; auto with wf_db. + (* rewrite Nat.mul_comm; apply WF_mult; [rewrite Nat.mul_comm|auto with wf_db]; + apply WF_mult; auto with wf_db; rewrite Nat.mul_comm; auto with wf_db. *) + intros x y Hwfx Hwfy. + (* simpl. *) + (* Search "assoc" in Matrix. *) + rewrite (Nat.mul_comm s t). + rewrite (Mmult_assoc (_ × _)). + rewrite (Nat.mul_comm t s). + rewrite kron_comm_commutes_vectors_l by easy. + rewrite Mmult_assoc, (Nat.mul_comm m n). + rewrite kron_mixed_product. + rewrite (Nat.mul_comm n m), kron_comm_commutes_vectors_l by (auto with wf_db). + rewrite <- kron_mixed_product. + f_equal; lia. +Qed. + +Lemma commute_kron_mat_equiv : forall n s m t + (A : Matrix n s) (B : Matrix m t), + (A ⊗ B) ≡ kron_comm n m × (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B i j Hi Hj. + rewrite (kron_comm_commutes_mat_equiv m t n s B A); try easy; lia. +Qed. + + +Lemma commute_kron : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + (A ⊗ B) = kron_comm n m × (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B HA HB. + rewrite (kron_comm_commutes m t n s B A HB HA); easy. +Qed. + +Lemma kron_comm_mul_inv_mat_equiv : forall p q, + kron_comm p q × kron_comm q p ≡ Matrix.I _. +Proof. + intros p q. + intros i j Hi Hj. + unfold Mmult, kron_comm, make_WF. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + rewrite <- 2!andb_if, if_mult_and. + replace_bool_lia (k + replace b with ((i =? j) && (k =? (i mod q) * p + (j/q))) + end; + [reflexivity|]. + rewrite eq_iff_eq_true, 4!andb_true_iff, 6!Nat.eqb_eq. + split. + - intros [? ?]; subst. + destruct p; [easy|destruct q;[lia|]]. + assert (j / S q < S p)%nat by (apply Nat.div_lt_upper_bound; lia). + rewrite Nat.div_add_l, (Nat.div_small (j / (S q))), Nat.add_0_r by easy. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by easy. + easy. + - intros [[Hiqkp Hkpiq] [Hkpjq Hjqkp]]. + split. + + rewrite (Nat.div_mod_eq i q), (Nat.div_mod_eq j q). + lia. + + rewrite (Nat.div_mod_eq k p). + lia. + } + bdestruct (i =? j). + - subst. + apply big_sum_unique. + exists ((j mod q) * p + (j/q))%nat. + split; [|split]. + + rewrite Nat.mul_comm. apply kron_comm_mx_to_vec_helper; easy. + + unfold Matrix.I. + rewrite Nat.eqb_refl; bdestructΩ'simp. + + intros; bdestructΩ'simp. + - unfold Matrix.I. + replace_bool_lia (i =? j) false. + rewrite andb_false_l. + rewrite big_sum_0; [easy|]. + intros; rewrite andb_false_l; easy. +Qed. + +Lemma kron_comm_mul_inv : forall p q, + kron_comm p q × kron_comm q p = Matrix.I _. +Proof. + intros p q. + apply mat_equiv_eq; auto with wf_db. + rewrite kron_comm_mul_inv_mat_equiv; easy. +Qed. + +Lemma kron_comm_mul_transpose_r_mat_equiv : forall p q, + kron_comm p q × (kron_comm p q) ⊤ = Matrix.I _. +Proof. + intros p q. + rewrite (kron_comm_transpose p q). + apply kron_comm_mul_inv. +Qed. + +Lemma kron_comm_mul_transpose_r : forall p q, + kron_comm p q × (kron_comm p q) ⊤ = Matrix.I _. +Proof. + intros p q. + rewrite (kron_comm_transpose p q). + apply kron_comm_mul_inv. +Qed. + +Lemma kron_comm_mul_transpose_l_mat_equiv : forall p q, + (kron_comm p q) ⊤ × kron_comm p q = Matrix.I _. +Proof. + intros p q. + rewrite <- (kron_comm_transpose q p). + rewrite (Nat.mul_comm p q). + rewrite (transpose_involutive _ _ (kron_comm q p)). + apply kron_comm_mul_transpose_r_mat_equiv. +Qed. + +Lemma kron_comm_mul_transpose_l : forall p q, + (kron_comm p q) ⊤ × kron_comm p q = Matrix.I _. +Proof. + intros p q. + rewrite <- (kron_comm_transpose q p). + rewrite (Nat.mul_comm p q). + rewrite (transpose_involutive _ _ (kron_comm q p)). + apply kron_comm_mul_transpose_r. +Qed. + + + +Lemma kron_comm_commutes_l_mat_equiv : forall n s m t + (A : Matrix n s) (B : Matrix m t), + kron_comm m n × (A ⊗ B) ≡ (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B. + match goal with |- ?A ≡ ?B => + rewrite <- (Mmult_1_r_mat_eq _ _ A), <- (Mmult_1_r_mat_eq _ _ B) + end. + rewrite (Nat.mul_comm t s). + rewrite <- (kron_comm_mul_transpose_r), <- 2!Mmult_assoc. + rewrite (kron_comm_commutes_mat_equiv n s m t). + apply Mmult_simplify_mat_equiv; [|easy]. + rewrite Mmult_assoc. + (* let rec gen_patt H := + match type of H with + | ?f ?x => idtac x; uconstr:(gen_patt f -> _) + | _ => uconstr:(_) + end + in + let pat := gen_patt (kron_comm_mul_inv_mat_equiv t s) in + idtac pat; + match goal with + | |- [pat] => idtac "match" + end. + match type of (kron_comm_mul_inv_mat_equiv t s) with + | ?f ?x ≡ ?g ?y => idtac f; idtac x; idtac g; idtac y + end. *) + (* Mmult_transpose + match goal with + |- context[@Mmult ?n ?m ?o (kron_comm ?t' ?s') (kron_comm ?s'' ?t'')] => + (* idtac n m o t' s' s'' t''; *) + replace s'' with s' by lia; + replace t'' with t' by lia; + replace n with (t'*s')%nat by lia; + replace m with (t'*s')%nat by lia; + replace o with (t'*s')%nat by lia; + replace (@Mmult n m o (kron_comm t' s') (kron_comm s' t')) + with (@Mmult (t'*s') (t'*s') (t'*s') (kron_comm t' s') (kron_comm s' t')) + by (f_equal;lia); + rewrite (kron_comm_mul_inv_mat_equiv t' s') + end. *) + (* rewrite Mmult_1_r_mat_eq. *) + + rewrite (Nat.mul_comm s t), (kron_comm_mul_inv_mat_equiv t s), Mmult_1_r_mat_eq. + easy. +Qed. + +Lemma kron_comm_commutes_l : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + kron_comm m n × (A ⊗ B) = (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B HwfA HwfB. + apply mat_equiv_eq; auto with wf_db. + apply kron_comm_commutes_l_mat_equiv. +Qed. + +Lemma kron_comm_commutes_r_mat_equiv : forall n s m t + (A : Matrix n s) (B : Matrix m t), + (A ⊗ B) × kron_comm s t ≡ (kron_comm n m) × (B ⊗ A). +Proof. + intros. + rewrite kron_comm_commutes_l_mat_equiv; easy. +Qed. + +Lemma kron_comm_commutes_r : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + (A ⊗ B) × kron_comm s t = (kron_comm n m) × (B ⊗ A). +Proof. + intros n s m t A B HA HB. + rewrite kron_comm_commutes_l; easy. +Qed. + + + +(* Lemma kron_comm_commutes_r : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + kron_comm m n × (A ⊗ B) = (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B HwfA HwfB. + match goal with |- ?A = ?B => + rewrite <- (Mmult_1_r _ _ A), <- (Mmult_1_r _ _ B) ; auto with wf_db + end. + rewrite (Nat.mul_comm t s). + rewrite <- (kron_comm_mul_transpose_r), <- 2!Mmult_assoc. + rewrite (kron_comm_commutes n s m t) by easy. + apply Mmult_simplify; [|easy]. + rewrite Mmult_assoc. + rewrite (Nat.mul_comm s t), (kron_comm_mul_inv t s), Mmult_1_r by auto with wf_db. + easy. +Qed. *) + + +Lemma vector_eq_basis_comb_mat_equiv : forall n (y : Vector n), + y ≡ big_sum (G:=Vector n) (fun i => y i O .* @e_i n i) n. +Proof. + intros n y. + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + symmetry. + rewrite Msum_Csum. + apply big_sum_unique. + exists i. + repeat split; try easy. + - unfold ".*", e_i; bdestructΩ'simp. + - intros l Hl Hnk. + unfold ".*", e_i; bdestructΩ'simp. +Qed. + + +Lemma vector_eq_basis_comb : forall n (y : Vector n), + WF_Matrix y -> + y = big_sum (G:=Vector n) (fun i => y i O .* @e_i n i) n. +Proof. + intros n y Hwfy. + apply mat_equiv_eq; auto with wf_db. + apply vector_eq_basis_comb_mat_equiv. +Qed. + +(* Lemma kron_vecT_matrix_vec : forall m n o p + (P : Matrix m o) (y : Vector n) (z : Vector p), + WF_Matrix y -> WF_Matrix z -> WF_Matrix P -> + (z⊤) ⊗ P ⊗ y = @Mmult (m*n) (m*n) (o*p) (kron_comm m n) ((y × (z⊤)) ⊗ P). +Proof. + intros m n o p P y z Hwfy Hwfz HwfP. + match goal with |- ?A = ?B => + rewrite <- (Mmult_1_l _ _ A) ; auto with wf_db + end. + rewrite Nat.mul_1_l. + rewrite <- (kron_comm_mul_transpose_r), Mmult_assoc at 1. + rewrite Nat.mul_1_r, (Nat.mul_comm o p). + apply Mmult_simplify; [easy|]. + rewrite kron_comm_kron_form_sum. + rewrite Msum_transpose. + rewrite Mmult_Msum_distr_r. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + pose proof (kron_transpose _ _ _ _ ((@e_i n k) ⊤ ⊗ Matrix.I m) (@e_i n k)) as H; + rewrite Nat.mul_1_l, Nat.mul_1_r, (Nat.mul_comm m n) in *; + rewrite H; clear H. + pose proof (kron_transpose _ _ _ _ ((@e_i n k) ⊤) (Matrix.I m)) as H; + rewrite Nat.mul_1_l in *; + rewrite H; clear H. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite id_transpose_eq, Mmult_1_l by easy. + rewrite e_i_dot_is_component, transpose_involutive by easy. + (* rewrite <- Mmult_transpose. *) + rewrite Mscale_kron_dist_r, <- 2!Mscale_kron_dist_l. + rewrite kron_1_r. + rewrite <- Mscale_mult_dist_l. + reflexivity. + } + rewrite <- (kron_Msum_distr_r n _ P). + rewrite <- (Mmult_Msum_distr_r). + rewrite <- vector_eq_basis_comb by easy. + easy. +Qed. +*) + +Lemma kron_vecT_matrix_vec_mat_equiv : forall m n o p + (P : Matrix m o) (y : Vector n) (z : Vector p), + (z⊤) ⊗ P ⊗ y ≡ @Mmult (m*n) (m*n) (o*p) (kron_comm m n) ((y × (z⊤)) ⊗ P). +Proof. + intros m n o p P y z. + match goal with |- ?A ≡ ?B => + rewrite <- (Mmult_1_l_mat_eq _ _ A) + end. + rewrite Nat.mul_1_l. + rewrite <- (kron_comm_mul_transpose_r_mat_equiv), Mmult_assoc at 1. + rewrite Nat.mul_1_r, (Nat.mul_comm o p). + apply Mmult_simplify_mat_equiv; [easy|]. + rewrite kron_comm_kron_form_sum_mat_equiv. + rewrite Msum_transpose. + rewrite Mmult_Msum_distr_r. + erewrite (big_sum_mat_equiv_bounded _ _ n). + 2: { + intros k Hk. + unshelve (instantiate (1:=_)). + refine (fun k : nat => y k 0%nat .* e_i k × (z) ⊤ ⊗ P); exact n. + pose proof (kron_transpose _ _ _ _ ((@e_i n k) ⊤ ⊗ Matrix.I m) (@e_i n k)) as H; + rewrite Nat.mul_1_l, Nat.mul_1_r, (Nat.mul_comm m n) in *; + rewrite H; clear H. + pose proof (kron_transpose _ _ _ _ ((@e_i n k) ⊤) (Matrix.I m)) as H; + rewrite Nat.mul_1_l in *; + rewrite H; clear H. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite (id_transpose_eq m). + rewrite Mscale_mult_dist_l, transpose_involutive. + rewrite <- (kron_1_r _ _ P) at 2. + rewrite Mscale_kron_dist_l, <- !Mscale_kron_dist_r. + match goal with + |- (?A ⊗ ?B ⊗ ?C) ≡ _ => pose proof (kron_assoc_mat_equiv A B C) as H + end; + rewrite 4!Nat.mul_1_r in H; rewrite H by easy; clear H. + apply kron_simplify_mat_equiv; [easy|]. + epose proof (Mscale_kron_dist_r _ _ _ _ _ P (Matrix.I 1)) as H; + rewrite 2Nat.mul_1_r in H; + rewrite <- H; clear H. + match goal with + |- (?A ⊗ ?B) ≡ (?C ⊗ ?D) => pose proof (kron_simplify_mat_equiv A C B D) as H + end; + rewrite 2!Nat.mul_1_r in H. apply H. + - rewrite Mmult_1_l_mat_eq; easy. + - rewrite (e_i_dot_is_component_mat_equiv); easy. + } + rewrite <- (kron_Msum_distr_r n _ P). + rewrite <- (Mmult_Msum_distr_r). + rewrite (Nat.mul_comm m n). + rewrite <- vector_eq_basis_comb_mat_equiv by easy. + easy. +Qed. + +Lemma kron_vecT_matrix_vec : forall m n o p + (P : Matrix m o) (y : Vector n) (z : Vector p), + WF_Matrix y -> WF_Matrix z -> WF_Matrix P -> + (z⊤) ⊗ P ⊗ y = @Mmult (m*n) (m*n) (o*p) (kron_comm m n) ((y × (z⊤)) ⊗ P). +Proof. + intros m n o p P y z Hwfy Hwfz HwfP. + apply mat_equiv_eq; + [|rewrite ?Nat.mul_1_l, ?Nat.mul_1_r, (Nat.mul_comm o p); apply WF_mult|]; + auto with wf_db; + [apply WF_kron; auto with wf_db; lia|]. + apply kron_vecT_matrix_vec_mat_equiv. +Qed. + + +Lemma kron_vec_matrix_vecT : forall m n o p + (Q : Matrix n o) (x : Vector m) (z : Vector p), + WF_Matrix x -> WF_Matrix z -> WF_Matrix Q -> + x ⊗ Q ⊗ (z⊤) = @Mmult (m*n) (m*n) (o*p) (kron_comm m n) (Q ⊗ (x × z⊤)). +Proof. + intros m n o p Q x z Hwfx Hwfz HwfQ. + match goal with |- ?A = ?B => + rewrite <- (Mmult_1_l _ _ A) ; auto with wf_db + end. + rewrite Nat.mul_1_r. + rewrite <- (kron_comm_mul_transpose_r), Mmult_assoc at 1. + rewrite Nat.mul_1_l. + apply Mmult_simplify; [easy|]. + rewrite kron_comm_kron_form_sum'. + rewrite Msum_transpose. + rewrite Mmult_Msum_distr_r. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + pose proof (kron_transpose _ _ _ _ ((@e_i m k) ⊗ Matrix.I n) ((@e_i m k) ⊤)) as H; + rewrite Nat.mul_1_l, Nat.mul_1_r, (Nat.mul_comm m n) in *; + rewrite H; clear H. + pose proof (kron_transpose _ _ _ _ ((@e_i m k)) (Matrix.I n)) as H; + rewrite Nat.mul_1_l, (Nat.mul_comm m n) in *; + rewrite H; clear H. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite id_transpose_eq, Mmult_1_l by easy. + rewrite e_i_dot_is_component, transpose_involutive by easy. + (* rewrite <- Mmult_transpose. *) + rewrite 2!Mscale_kron_dist_l, kron_1_l, <-Mscale_kron_dist_r by easy. + rewrite <- Mscale_mult_dist_l. + restore_dims. + reflexivity. + } + rewrite <- (kron_Msum_distr_l m _ Q). + rewrite <- (Mmult_Msum_distr_r). + rewrite <- vector_eq_basis_comb by easy. + easy. +Qed. + +(* TODO: Relocate *) +Lemma kron_1_l_mat_equiv : forall {n m} (A : Matrix n m), + Matrix.I 1 ⊗ A ≡ A. +Proof. + intros n m A. + intros i j Hi Hj. + unfold kron, I. + rewrite 2!Nat.div_small, 2!Nat.mod_small by lia. + rewrite Cmult_1_l. + easy. +Qed. + +Lemma kron_1_r_mat_equiv : forall {n m} (A : Matrix n m), + A ⊗ Matrix.I 1 ≡ A. +Proof. + intros n m A. + intros i j Hi Hj. + unfold kron, I. + rewrite 2!Nat.div_1_r, 2!Nat.mod_1_r by lia. + rewrite Cmult_1_r. + easy. +Qed. + +Lemma kron_vec_matrix_vecT_mat_equiv : forall m n o p + (Q : Matrix n o) (x : Vector m) (z : Vector p), + x ⊗ Q ⊗ (z⊤) ≡ @Mmult (m*n) (m*n) (o*p) (kron_comm m n) (Q ⊗ (x × z⊤)). +Proof. + intros m n o p Q x z. + match goal with |- ?A ≡ ?B => + rewrite <- (Mmult_1_l_mat_eq _ _ A) + end. + rewrite Nat.mul_1_r. + rewrite <- (kron_comm_mul_transpose_r_mat_equiv), Mmult_assoc at 1. + rewrite Nat.mul_1_l. + apply Mmult_simplify_mat_equiv; [easy|]. + rewrite kron_comm_kron_form_sum'. + rewrite Msum_transpose. + rewrite Mmult_Msum_distr_r. + erewrite (big_sum_mat_equiv_bounded). + 2: { + intros k Hk. + unshelve (instantiate (1:=(fun k : nat => + @kron n o m p Q + (@Mmult m 1 p (@scale m 1 (x k 0%nat) (@e_i m k)) + (@transpose p 1 z))))). + pose proof (kron_transpose _ _ _ _ ((@e_i m k) ⊗ Matrix.I n) ((@e_i m k) ⊤)) as H; + rewrite Nat.mul_1_l, Nat.mul_1_r, (Nat.mul_comm m n) in *; + rewrite H; clear H. + pose proof (kron_transpose _ _ _ _ ((@e_i m k)) (Matrix.I n)) as H; + rewrite Nat.mul_1_l, (Nat.mul_comm m n) in *; + rewrite H; clear H. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite id_transpose_eq, transpose_involutive. + rewrite Mscale_mult_dist_l, Mscale_kron_dist_r, <- Mscale_kron_dist_l. + rewrite 2!(Nat.mul_1_l). + apply kron_simplify_mat_equiv; [|easy]. + intros i j Hi Hj. + unfold kron. + rewrite (Mmult_1_l_mat_eq _ _ Q) by (apply Nat.mod_upper_bound; lia). + (* revert i j Hi Hj. *) + rewrite (e_i_dot_is_component_mat_equiv m k x Hk) by (apply Nat.div_lt_upper_bound; lia). + set (a:= (@kron 1 1 n o ((x k 0%nat .* Matrix.I 1)) Q) i j). + match goal with + |- ?A = _ => change A with a + end. + unfold a. + clear a. + rewrite Mscale_kron_dist_l. + unfold scale. + rewrite kron_1_l_mat_equiv by lia. + easy. + } + rewrite <- (kron_Msum_distr_l m _ Q). + rewrite <- (Mmult_Msum_distr_r). + rewrite (Nat.mul_comm m n). + rewrite <- vector_eq_basis_comb_mat_equiv. + easy. +Qed. + +Lemma kron_comm_triple_cycle_mat : forall m n s t p q (A : Matrix m n) + (B : Matrix s t) (C : Matrix p q), + A ⊗ B ⊗ C ≡ (kron_comm (m*s) p) × (C ⊗ A ⊗ B) × (kron_comm q (t*n)). +Proof. + intros m n s t p q A B C. + rewrite (commute_kron_mat_equiv _ _ _ _ (A ⊗ B) C) by auto with wf_db. + rewrite (Nat.mul_comm n t), (Nat.mul_comm q (t*n)). + (* replace (q * (t * n))%nat with (t * n * q)%nat by lia. *) + apply Mmult_simplify_mat_equiv; [|easy]. + apply Mmult_simplify_mat_equiv; [easy|]. + rewrite (Nat.mul_comm t n). + intros i j Hi Hj; + rewrite <- (kron_assoc_mat_equiv C A B); + [easy|lia|lia]. +Qed. + +Lemma kron_comm_triple_cycle : forall m n s t p q (A : Matrix m n) + (B : Matrix s t) (C : Matrix p q), WF_Matrix A -> WF_Matrix B -> WF_Matrix C -> + A ⊗ B ⊗ C = (kron_comm (m*s) p) × (C ⊗ A ⊗ B) × (kron_comm q (t*n)). +Proof. + intros m n s t p q A B C HA HB HC. + rewrite (commute_kron _ _ _ _ (A ⊗ B) C) by auto with wf_db. + rewrite kron_assoc by easy. + f_equal; try lia; f_equal; lia. +Qed. + +Lemma kron_comm_triple_cycle2_mat_equiv : forall m n s t p q (A : Matrix m n) + (B : Matrix s t) (C : Matrix p q), + A ⊗ B ⊗ C ≡ (kron_comm m (s*p)) × (B ⊗ C ⊗ A) × (kron_comm (q*t) n). +Proof. + intros m n s t p q A B C. + rewrite kron_assoc_mat_equiv. + intros i j Hi Hj. + rewrite (commute_kron_mat_equiv _ _ _ _ A (B ⊗ C)) by lia. + f_equal; try lia; f_equal; lia. +Qed. + +Lemma kron_comm_triple_cycle2 : forall m n s t p q (A : Matrix m n) + (B : Matrix s t) (C : Matrix p q), WF_Matrix A -> WF_Matrix B -> WF_Matrix C -> + A ⊗ B ⊗ C = (kron_comm m (s*p)) × (B ⊗ C ⊗ A) × (kron_comm (q*t) n). +Proof. + intros m n s t p q A B C HA HB HC. + rewrite kron_assoc by easy. + rewrite (commute_kron _ _ _ _ A (B ⊗ C)) by auto with wf_db. + f_equal; try lia; f_equal; lia. +Qed. + + + + + +(* #[export] Instance big_sum_mat_equiv_morphism {n m : nat} : + Proper (pointwise_relation nat (@mat_equiv n m) + ==> pointwise_relation nat (@mat_equiv n m)) + (@big_sum (Matrix n m) (M_is_monoid n m)) := big_sum_mat_equiv. *) + +(* Instance forall_mat_equiv_morphism {A: Type} {n m : nat} {f g : A -> Matrix m n}: + pointwise_relation A mat_equiv (fun x => f x) (fun x => f x). + +Instance forall_mat_equiv_morphism `{Equivalence A eqA, Equivalence B eqB} : + Proper ((eqA ==> eqB) ==> list_equiv eqA ==> list_equiv eqB) (@map A B). + +Goal (forall_relation (fun n:nat => @mat_equiv m m)) (fun n => Matrix.I m × direct_sum' (@Zero 0 0) (Matrix.I m)) (fun n => Matrix.I m). +setoid_rewrite Mmult_1_l_mat_eq. *) + + +Lemma id_eq_sum_kron_e_is_mat_equiv : forall n, + Matrix.I n ≡ big_sum (G:=Square n) (fun i => @e_i n i ⊗ (@e_i n i) ⊤) n. +Proof. + intros n. + symmetry. + intros i j Hi Hj. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + rewrite kron_e_i_l by lia. + unfold transpose, e_i. + rewrite <- andb_if. + replace_bool_lia (j @e_i n i ⊗ (@e_i n i) ⊤) n. +Proof. + intros n. + apply mat_equiv_eq; auto with wf_db. + apply id_eq_sum_kron_e_is_mat_equiv. +Qed. + +Lemma kron_comm_cycle_indices : forall t s n, + (kron_comm (t*s) n = @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n))). +Proof. + intros t s n. + rewrite kron_comm_kron_form_sum. + erewrite big_sum_eq_bounded. + 2: { + intros j Hj. + rewrite (Nat.mul_comm t s), <- id_kron, <- kron_assoc by auto with wf_db. + restore_dims. + rewrite kron_assoc by auto with wf_db. + (* rewrite (kron_assoc ((@e_i n j)⊤ ⊗ Matrix.I t) (Matrix.I s) (@e_i n j)) by auto with wf_db. *) + lazymatch goal with + |- ?A ⊗ ?B = _ => rewrite (commute_kron _ _ _ _ A B) by auto with wf_db + end. + (* restore_dims. *) + reflexivity. + } + (* rewrite ?Nat.mul_1_r, ?Nat.mul_1_l. *) + (* rewrite <- Mmult_Msum_distr_r. *) + + rewrite <- (Mmult_Msum_distr_r n _ (kron_comm (t*1) (n*s))). + rewrite <- Mmult_Msum_distr_l. + erewrite big_sum_eq_bounded. + 2: { + intros j Hj. + rewrite <- kron_assoc, (kron_assoc (Matrix.I t)) by auto with wf_db. + restore_dims. + reflexivity. + } + (* rewrite Nat.mul_1_l *) + rewrite <- (kron_Msum_distr_r n _ (Matrix.I s)). + rewrite <- (kron_Msum_distr_l n _ (Matrix.I t)). + rewrite 2!Nat.mul_1_r, 2!Nat.mul_1_l. + rewrite <- (id_eq_sum_kron_e_is n). + rewrite 2!id_kron. + restore_dims. + rewrite Mmult_1_r by auto with wf_db. + rewrite (Nat.mul_comm t n), (Nat.mul_comm n s). + easy. +Qed. + +Lemma kron_comm_cycle_indices_mat_equiv : forall t s n, + (kron_comm (t*s) n ≡ @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n))). +Proof. + intros t s n. + rewrite kron_comm_cycle_indices; easy. +Qed. + +Lemma kron_comm_cycle_indices_rev : forall t s n, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) = kron_comm (t*s) n. +Proof. + intros. + rewrite <- kron_comm_cycle_indices. + easy. +Qed. + +Lemma kron_comm_cycle_indices_rev_mat_equiv : forall t s n, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) ≡ kron_comm (t*s) n. +Proof. + intros. + rewrite <- kron_comm_cycle_indices. + easy. +Qed. + +Lemma kron_comm_triple_id : forall t s n, + (kron_comm (t*s) n) × (kron_comm (s*n) t) × (kron_comm (n*t) s) = Matrix.I (t*s*n). +Proof. + intros t s n. + rewrite kron_comm_cycle_indices. + restore_dims. + rewrite (Mmult_assoc (kron_comm s (n*t))). + restore_dims. + rewrite (Nat.mul_comm (s*n) t). (* TODO: Fix kron_comm_mul_inv to have the + right dimensions by default (or, better yet, to be ambivalent) *) + rewrite (kron_comm_mul_inv t (s*n)). + restore_dims. + rewrite Mmult_1_r by auto with wf_db. + rewrite (Nat.mul_comm (n*t) s). + rewrite (kron_comm_mul_inv). + f_equal; lia. +Qed. + +Lemma kron_comm_triple_id_mat_equiv : forall t s n, + (kron_comm (t*s) n) × (kron_comm (s*n) t) × (kron_comm (n*t) s) ≡ Matrix.I (t*s*n). +Proof. + intros t s n. + setoid_rewrite kron_comm_triple_id; easy. +Qed. + +Lemma kron_comm_triple_id' : forall n t s, + (kron_comm n (t*s)) × (kron_comm t (s*n)) × (kron_comm s (n*t)) = Matrix.I (t*s*n). +Proof. + intros n t s. + (* rewrite kron_comm_cycle_indices. *) + apply transpose_matrices. + + rewrite 2!Mmult_transpose. + (* restore_dims. *) + rewrite (kron_comm_transpose s (n*t)). + (* restore_dims. *) + rewrite (kron_comm_transpose n (t*s)). + restore_dims. + replace (n*(t*s))%nat with (t*(s*n))%nat by lia. + replace (s*(n*t))%nat with (t*(s*n))%nat by lia. + + rewrite (kron_comm_transpose t (s*n)). + (* rewrite <- Mmult_assoc. *) + + (* rewrite (kron_comm_transpose t (s*n)). *) + restore_dims. + rewrite Nat.mul_assoc, id_transpose_eq. + replace (t*s*n)%nat with (t*n*s)%nat by lia. + rewrite <- (kron_comm_triple_id t n s). + rewrite Mmult_assoc. + replace (s*t*n)%nat with (t*n*s)%nat by lia. + replace (n*t*s)%nat with (t*n*s)%nat by lia. + apply Mmult_simplify; [f_equal; lia|]. + repeat (f_equal; try lia). +Qed. + +Lemma kron_comm_triple_id'_mat_equiv : forall t s n, + (kron_comm n (t*s)) × (kron_comm t (s*n)) × (kron_comm s (n*t)) = Matrix.I (t*s*n). +Proof. + intros t s n. + rewrite kron_comm_triple_id'. + easy. +Qed. + +Lemma kron_comm_triple_id'C : forall n t s, + (kron_comm n (s*t)) × (kron_comm t (n*s)) × (kron_comm s (t*n)) = Matrix.I (t*s*n). +Proof. + intros n t s. + rewrite (Nat.mul_comm s t), (Nat.mul_comm n s), + (Nat.mul_comm t n), kron_comm_triple_id'. + easy. +Qed. + +Lemma kron_comm_triple_id'C_mat_equiv : forall n t s, + (kron_comm n (s*t)) × (kron_comm t (n*s)) × (kron_comm s (t*n)) ≡ Matrix.I (t*s*n). +Proof. + intros n t s. + rewrite kron_comm_triple_id'C. + easy. +Qed. + +Lemma kron_comm_triple_indices_collapse_mat_equiv : forall s n t, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) + ≡ (kron_comm (t*s) n). +Proof. + intros s n t. + rewrite <- (Mmult_1_r_mat_eq _ _ (_ × _)). + replace (t*(s*n))%nat with (n*(t*s))%nat by lia. + rewrite <- (kron_comm_mul_inv_mat_equiv). + rewrite <- Mmult_assoc. + rewrite (kron_comm_triple_id'C s t n). + replace (t*n*s)%nat with (n*(t*s))%nat by lia. + replace (s*(n*t))%nat with (t*s*n)%nat by lia. + replace (n*(t*s))%nat with (t*s*n)%nat by lia. + rewrite Mmult_1_l_mat_eq. + easy. +Qed. + +Lemma kron_comm_triple_indices_collapse : forall s n t, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) + = (kron_comm (t*s) n). +Proof. + intros s n t. + apply mat_equiv_eq; auto with wf_db; + [apply_with_obligations (WF_kron_comm (t*s) n); lia|]. + apply kron_comm_triple_indices_collapse_mat_equiv. +Qed. + +Lemma kron_comm_triple_indices_collapse_mat_equivC : forall s n t, + @Mmult (s*(t*n)) (s*(t*n)) (t*(n*s)) (kron_comm s (t*n)) (kron_comm t (n*s)) + ≡ (kron_comm (t*s) n). +Proof. + intros s n t. + rewrite (Nat.mul_comm t n), (Nat.mul_comm n s). + rewrite kron_comm_triple_indices_collapse_mat_equiv. + easy. +Qed. + +Lemma kron_comm_triple_indices_collapseC : forall s n t, + @Mmult (s*(t*n)) (s*(t*n)) (t*(n*s)) (kron_comm s (t*n)) (kron_comm t (n*s)) + = (kron_comm (t*s) n). +Proof. + intros s n t. + apply mat_equiv_eq; auto with wf_db; + [apply_with_obligations (WF_kron_comm (t*s) n); lia|]. + apply kron_comm_triple_indices_collapse_mat_equivC. +Qed. + +(* +Not sure what this is, or if it's true: +Lemma kron_comm_triple_indices_commute : forall t s n, + @Mmult (s*t*n) (s*t*n) (t*(s*n)) (kron_comm (s*t) n) (kron_comm t (s*n)) = + @Mmult (t*(s*n)) (t*(s*n)) (s*t*n) (kron_comm t (s*n)) (kron_comm (s*t) n). *) +Lemma kron_comm_triple_indices_commute_mat_equiv : forall t s n, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) ≡ + @Mmult (t*(s*n)) (t*(s*n)) (s*(n*t)) (kron_comm t (s*n)) (kron_comm s (n*t)). +Proof. + intros t s n. + rewrite kron_comm_triple_indices_collapse_mat_equiv. + rewrite (Nat.mul_comm t s). + rewrite <- (kron_comm_triple_indices_collapseC t n s). + easy. +Qed. + +Lemma kron_comm_triple_indices_commute : forall t s n, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) = + @Mmult (t*(s*n)) (t*(s*n)) (s*(n*t)) (kron_comm t (s*n)) (kron_comm s (n*t)). +Proof. + intros t s n. + apply mat_equiv_eq; auto with wf_db; + [replace (s*(n*t))%nat with (t*(s*n))%nat by lia; apply WF_mult; + auto with wf_db; apply_with_obligations (WF_kron_comm s (n*t)); lia|]. + apply kron_comm_triple_indices_commute_mat_equiv. +Qed. + +Lemma kron_comm_triple_indices_commute_mat_equivC : forall t s n, + @Mmult (s*(t*n)) (s*(t*n)) (t*(n*s)) (kron_comm s (t*n)) (kron_comm t (n*s)) ≡ + @Mmult (t*(s*n)) (t*(s*n)) (s*(n*t)) (kron_comm t (s*n)) (kron_comm s (n*t)). +Proof. + intros t s n. + rewrite (Nat.mul_comm t n), (Nat.mul_comm n s). + apply kron_comm_triple_indices_commute_mat_equiv. +Qed. + +Lemma kron_comm_triple_indices_commuteC : forall t s n, + @Mmult (s*(t*n)) (s*(t*n)) (t*(n*s)) (kron_comm s (t*n)) (kron_comm t (n*s)) = + @Mmult (t*(s*n)) (t*(s*n)) (s*(n*t)) (kron_comm t (s*n)) (kron_comm s (n*t)). +Proof. + intros t s n. + rewrite (Nat.mul_comm t n), (Nat.mul_comm n s). + apply kron_comm_triple_indices_commute. +Qed. + +Lemma kron_comm_kron_of_mult_commute1_mat_equiv : forall m n p q s t + (A : Matrix m n) (B : Matrix p q) (C : Matrix q s) (D : Matrix n t), + @mat_equiv (m*p) (s*t) ((kron_comm m p) × ((B × C) ⊗ (A × D))) + ((A ⊗ B) × kron_comm n q × (C ⊗ D)). +Proof. + intros m n p q s t A B C D. + rewrite <- kron_mixed_product. + rewrite (Nat.mul_comm p m), <- Mmult_assoc. + rewrite kron_comm_commutes_r_mat_equiv. + match goal with (* TODO: Make a lemma *) + |- ?A ≡ ?B => enough (H : A = B) by (rewrite H; easy) + end. + f_equal; lia. +Qed. + +Lemma kron_comm_kron_of_mult_commute2_mat_equiv : forall m n p q s t + (A : Matrix m n) (B : Matrix p q) (C : Matrix q s) (D : Matrix n t), + ((A ⊗ B) × kron_comm n q × (C ⊗ D)) ≡ (A × D ⊗ (B × C)) × kron_comm t s. +Proof. + intros m n p q s t A B C D. + rewrite Mmult_assoc, kron_comm_commutes_l_mat_equiv, <-Mmult_assoc, + <- kron_mixed_product. + easy. +Qed. + +Lemma kron_comm_kron_of_mult_commute3_mat_equiv : forall m n p q s t + (A : Matrix m n) (B : Matrix p q) (C : Matrix q s) (D : Matrix n t), + (A × D ⊗ (B × C)) × kron_comm t s ≡ + (Matrix.I m) ⊗ (B × C) × kron_comm m s × (Matrix.I s ⊗ (A × D)). +Proof. + intros m n p q s t A B C D. + rewrite <- 2!kron_comm_commutes_l_mat_equiv, Mmult_assoc. + restore_dims. + rewrite kron_mixed_product. + rewrite (Nat.mul_comm m p), (Nat.mul_comm t s). + rewrite Mmult_1_r_mat_eq, Mmult_1_l_mat_eq. + easy. +Qed. + +Lemma kron_comm_kron_of_mult_commute4_mat_equiv : forall m n p q s t + (A : Matrix m n) (B : Matrix p q) (C : Matrix q s) (D : Matrix n t), + @mat_equiv (m*p) (s*t) + ((Matrix.I m) ⊗ (B × C) × kron_comm m s × (Matrix.I s ⊗ (A × D))) + ((A × D) ⊗ (Matrix.I p) × kron_comm t p × ((B × C) ⊗ Matrix.I t)). +Proof. + intros m n p q s t A B C D. + rewrite <- 2!kron_comm_commutes_l_mat_equiv, 2!Mmult_assoc. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite (Nat.mul_comm m p), 2!Mmult_1_r_mat_eq. + rewrite 2!Mmult_1_l_mat_eq. + easy. +Qed. + +Lemma trace_mmult_trans : forall m n (A B : Matrix m n), + trace (A⊤ × B) = Σ (fun j => Σ (fun i => A i j * B i j) m) n. +Proof. + intros m n A B. + apply big_sum_eq_bounded. + intros j Hj. + apply big_sum_eq_bounded. + intros i Hi; reflexivity. +Qed. + +Lemma trace_mmult_trans' : forall m n (A B : Matrix m n), + trace (A⊤ × B) = Σ (fun ij => let j := (ij / m)%nat in + let i := ij mod m in + A i j * B i j) (m*n). +Proof. + intros m n A B. + rewrite trace_mmult_trans, big_sum_double_sum. + reflexivity. +Qed. + +Lemma trace_0_l : forall (A : Square 0), + trace A = 0. +Proof. + intros A. + unfold trace. + easy. +Qed. + +Lemma trace_0_r : forall n, + trace (@Zero n n) = 0. +Proof. + intros A. + unfold trace. + rewrite big_sum_0; easy. +Qed. + +Lemma trace_mplus : forall n (A B : Square n), + trace (A .+ B) = trace A + trace B. +Proof. + intros n A B. + induction n. + - rewrite 3!trace_0_l; lca. + - unfold trace in *. + rewrite <- 3!big_sum_extend_r. + setoid_rewrite (IHn A B). + lca. +Qed. + +Lemma trace_big_sum : forall n k f, + trace (big_sum (G:=Square n) f k) = Σ (fun x => trace (f x)) k. +Proof. + intros n k f. + induction k. + - rewrite trace_0_r; easy. + - rewrite <- 2!big_sum_extend_r, <-IHk. + setoid_rewrite trace_mplus. + easy. +Qed. + +Lemma Hij_decomp_mat_equiv : forall n m (A : Matrix n m), + A ≡ big_sum (G:=Matrix n m) (fun ij => + let i := (ij/m)%nat in let j := ij mod m in + A i j .* H i j) (n*m). +Proof. + intros n m A. + intros i j Hi Hj. + rewrite Msum_Csum. + symmetry. + apply big_sum_unique. + exists (i*m + j)%nat. + simpl. + repeat split. + - nia. + - rewrite Nat.div_add_l, Nat.div_small, Nat.add_0_r by lia. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by lia. + unfold scale, Mmult. + erewrite big_sum_unique, Cmult_1_r; [easy|]. + exists O; repeat split; auto; + unfold transpose, e_i; + intros; + rewrite !Nat.eqb_refl; + simpl_bools; + bdestructΩ'simp. + - intros ab Hab Habneq. + unfold scale, Mmult, transpose, e_i. + simpl. + rewrite Cplus_0_l. + simpl_bools. + bdestructΩ'simp. + exfalso; apply Habneq. + symmetry. + rewrite (Nat.div_mod_eq ab m) at 1 by lia. + lia. +Qed. + +Lemma Mmult_Hij_Hij_mat_equiv : forall n m o i j k l, (j < m)%nat -> + @Mmult n m o (H i j) (H k l) ≡ (if (j =? k) then H i l else Zero). +Proof. + intros n m o i j k l Hj. + intros a b Ha Hb. + unfold Mmult, transpose, e_i. + simpl. + bdestruct (j =? k). + - subst k. + rewrite Cplus_0_l. + bdestruct (a =? i); simpl; + bdestruct (b =? l); simpl; + Csimpl. + 1: simpl_bools; + replace_bool_lia (a (j < m)%nat -> + (H i j : Matrix n m) × A ≡ big_sum (G:=Matrix n o) + (fun kl : nat => A (kl / o)%nat (kl mod o) + .* (if j =? kl / o then @e_i n i × (@e_i o (kl mod o)) ⊤ else Zero)) (m * o). +Proof. + intros n m o A i j Hi Hj. + rewrite (Hij_decomp_mat_equiv _ _ A) at 1. + rewrite Mmult_Msum_distr_l. + simpl. + set (f := fun a b => A a b .* (if j =? a then @e_i n i × (@e_i o (b)) ⊤ else Zero)). + rewrite (big_sum_mat_equiv_bounded _ (fun kl => f (kl/o)%nat (kl mod o))). + 2:{ + intros kl Hkl. + rewrite Mscale_mult_dist_r. + rewrite Mmult_Hij_Hij_mat_equiv by easy. + easy. + } + easy. +Qed. + +Lemma Hij_elem : forall n m i j k l, + ((H i j) : Matrix n m) k l = if (k=?i)&&(l=?j)&&(i (j < m)%nat -> + trace (H i j × (A⊤)) = A i j. +Proof. + intros n m A i j Hi Hj. + rewrite (Hij_decomp_mat_equiv _ _ A) at 1. + rewrite (Msum_transpose n m (n*m)). + simpl. + rewrite Mmult_Hij_l_mat_equiv by easy. + erewrite big_sum_eq_bounded. + 2: { + intros ij Hij. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + unfold scale, transpose, Mmult, e_i. + simpl; rewrite Cplus_0_l. + rewrite if_mult_and. + replace (ij / n Σ (fun j => A i j * B j i) m) n. +Proof. + reflexivity. +Qed. + +Lemma trace_mmult_eq_comm : forall {n m} (A : Matrix n m) (B : Matrix m n), + trace (A×B) = trace (B×A). +Proof. + intros n m A B. + rewrite 2!trace_mmult_eq_ptwise. + rewrite big_sum_swap_order. + do 2 (apply big_sum_eq_bounded; intros). + apply Cmult_comm. +Qed. + +Lemma trace_transpose : forall {n} (A : Square n), + trace (A ⊤) = trace A. +Proof. + reflexivity. +Qed. + +Lemma trace_mmult_transpose_Hij_l : forall {n m} (A: Matrix m n) i j, + (i < m)%nat -> (j < n)%nat -> + trace ((H i j)⊤ × A) = A i j. +Proof. + intros n m A i j Hi Hj. + rewrite trace_mmult_eq_comm, <- trace_transpose, 3!Mmult_transpose, + 2!transpose_involutive, trace_mmult_Hij_transpose_l; try easy. +Qed. + + +Lemma trace_kron : forall {n p} (A : Square n) (B : Square p), + trace (A ⊗ B) = trace A * trace B. +Proof. + intros n p A B. + destruct p; + [rewrite Nat.mul_0_r, 2!trace_0_l; lca|]. + unfold trace. + simpl_rewrite big_sum_product; [|easy]. + reflexivity. +Qed. + +Lemma trace_kron_comm_kron : forall m n (A B : Matrix m n), + trace (kron_comm m n × (A ⊤ ⊗ B)) = trace (A⊤ × B). +Proof. + intros m n A B. + rewrite kron_comm_Hij_sum'. + rewrite Mmult_Msum_distr_r. + rewrite trace_mmult_trans', trace_big_sum. + set (f:= fun a b => A a b * B a b). + erewrite big_sum_eq_bounded. + 2:{ + intros ij Hij. + simpl. + rewrite kron_mixed_product' by lia. + rewrite trace_kron, trace_mmult_Hij_transpose_l by + (try apply Nat.div_lt_upper_bound; try apply Nat.mod_upper_bound; lia). + rewrite trace_mmult_transpose_Hij_l by + (try apply Nat.div_lt_upper_bound; try apply Nat.mod_upper_bound; lia). + fold (f (ij/n)%nat (ij mod n)). + reflexivity. + } + rewrite (Nat.mul_comm m n), <- (big_sum_double_sum f). + rewrite big_sum_swap_order. + rewrite big_sum_double_sum. + rewrite Nat.mul_comm. + easy. +Qed. + + +(* TODO: put a normal place *) +Lemma kron_comm_mx_to_vec_r_mat_equiv : forall p q (A : Matrix p q), + (mx_to_vec (A ⊤)) ⊤ × kron_comm p q ≡ (mx_to_vec A) ⊤. +Proof. + intros p q A. + match goal with + |- ?B ≡ ?C => rewrite <- (transpose_involutive _ _ B), <- (transpose_involutive _ _ C) + end. + rewrite Nat.mul_comm. + apply transpose_simplify_mat_equiv. + rewrite Mmult_transpose. + rewrite Nat.mul_comm. + rewrite kron_comm_transpose_mat_equiv. + rewrite transpose_involutive. + (* rewrite Nat.mul_comm. *) + (* restore_dims. *) + apply_with_obligations (kron_comm_mx_to_vec_mat_equiv q p (A⊤)); + [f_equal|]; lia. +Qed. + +Lemma trace_mmult_eq_dot_mx_to_vec : forall {m n} (A B : Matrix m n), + trace (A⊤ × B) = mx_to_vec A ∘ mx_to_vec B. +Proof. + intros m n A B. + rewrite trace_mmult_eq_ptwise. + rewrite big_sum_double_sum. + unfold dot, mx_to_vec. + (* rewrite Nat.mul_comm. *) + apply big_sum_eq_bounded. + intros ij Hij. + unfold make_WF. + replace_bool_lia (ij enough (C ≡ D) by auto + end. + rewrite kron_comm_mx_to_vec_r_mat_equiv. + easy. +Qed. + +Lemma gcd_grow : forall n m, + Nat.gcd (S n) m = Nat.gcd (m mod S n) (S n). +Proof. reflexivity. Qed. + +Lemma gcd_le : forall n m, + (Nat.gcd (S n) (S m) <= S n /\ Nat.gcd (S n) (S m) <= S m)%nat. +Proof. + intros n m. + pose proof (Nat.gcd_divide (S n) (S m)). + split; apply Nat.divide_pos_le; try easy; lia. +Qed. + +Lemma div_mul_combine : forall a b c d, + Nat.divide b a -> Nat.divide d c -> + (a / b * (c / d) = (a * c) / (b * d))%nat. +Proof. + intros a b c d [a' Ha'] [c' Hc']. + subst a c. + destruct b; + [rewrite ?Nat.mul_0_r, ?Nat.mul_0_l; easy|]. + rewrite Nat.div_mul by easy. + destruct d; + [rewrite ?Nat.mul_0_r, ?Nat.mul_0_l; easy|]. + rewrite Nat.div_mul by easy. + rewrite <- Nat.mul_assoc, (Nat.mul_comm (S b)), <- Nat.mul_assoc, + Nat.mul_assoc, (Nat.mul_comm (S d)), Nat.div_mul by lia. + easy. +Qed. + +Lemma prod_eq_gcd_lcm : forall n m, + (S n * S m = Nat.gcd (S n) (S m) * Nat.lcm (S n) (S m))%nat. +Proof. + intros n m. + unfold Nat.lcm. + rewrite <- 2!Nat.divide_div_mul_exact, (Nat.mul_comm (Nat.gcd _ _)), + Nat.div_mul; try easy; + try (try apply Nat.divide_mul_r; apply Nat.gcd_divide; lia); + rewrite Nat.gcd_eq_0; lia. +Qed. + +Lemma gcd_eq_div_lcm : forall n m, + (Nat.gcd (S n) (S m) = (S n * S m) / (Nat.lcm (S n) (S m)))%nat. +Proof. + intros n m. + rewrite prod_eq_gcd_lcm, Nat.div_mul; try easy. + rewrite Nat.lcm_eq_0; lia. +Qed. + + + +Lemma times_n_C1 : forall n, + times_n C1 n = RtoC (INR n). +Proof. + induction n; [easy|]. + rewrite S_INR, RtoC_plus, <- IHn, Cplus_comm. + easy. +Qed. + + +Lemma div_0_r : forall n, + (n / 0 = 0)%nat. +Proof. + intros n. + easy. +Qed. + +Lemma div_divides : forall n m, + Nat.divide m n -> (n / m <> 0)%nat -> + Nat.divide (n / m) n. +Proof. + intros n m Hdiv Hnz. + assert (H: m <> O) by (intros Hfalse; subst m; rewrite div_0_r in *; lia). + exists m. + rewrite <- Nat.divide_div_mul_exact, Nat.mul_comm, Nat.div_mul; try easy. +Qed. + +Lemma div_div : forall n m, + Nat.divide m n -> (n / m <> 0)%nat -> + (n / (n / m) = m)%nat. +Proof. + intros n m Hdiv Hnz. + rewrite <- (Nat.mul_cancel_r _ _ (n/m)) by easy. + rewrite Nat.mul_comm. + + assert (H: m <> O) by (intros Hfalse; subst m; rewrite div_0_r in *; lia). + rewrite <- Nat.divide_div_mul_exact, Nat.mul_comm, Nat.div_mul; try easy. + rewrite <- Nat.divide_div_mul_exact, Nat.mul_comm, Nat.div_mul; try easy. + apply div_divides; easy. +Qed. + + + +(* Lemma gcd_prod_helper : forall n m k, + n <> O -> m <> O -> k <> O -> + (Nat.gcd (Nat.gcd n m) k * Nat.gcd (m*n) k = (Nat.gcd n k) * (Nat.gcd m k))%nat. + (* (Nat.gcd (Nat.gcd (S n) (S m)) (S k) * Nat.gcd (S n * S m) (S k) = (Nat.gcd (S n) (S k)) * (Nat.gcd (S m) (S k)))%nat. *) +Proof. + intros n m k Hn Hm Hk. + Admitted. +Lemma gcd_prod : forall n m k, + n <> O -> m <> O -> k <> O -> + (Nat.gcd (m*n) k = (Nat.gcd n k) * (Nat.gcd m k) / Nat.gcd (Nat.gcd n m) k)%nat. +Proof. + intros n m k Hn Hm Hk. + rewrite <- gcd_prod_helper by easy. + rewrite (Nat.mul_comm (Nat.gcd _ _) _), Nat.div_mul; [easy | + rewrite ?Nat.gcd_eq_0; lia]. +Qed. + +Lemma sum_if_prod_eq_one_plus_gcd : forall m n, + Σ (fun j => Σ (fun i => if j * n =? i * m then C1 else 0) (S n)) (S m) = + INR (1 + Nat.gcd m n). +Proof. + intros m n. + destruct m, n; + try replace j with O by lia; + rewrite 1?Nat.mul_0_r, 1?Nat.mul_0_l, 1?Nat.mod_0_l, 1?Nat.eqb_refl by easy. + 1-3: (erewrite big_sum_eq_bounded; [ | + intros j Hj; + erewrite big_sum_eq_bounded; [ | + intros i Hi; + rewrite 2?Nat.mul_0_r; + reflexivity + ]; + reflexivity + ]). + 1: lca. + 1: rewrite Nat.gcd_0_l; simpl big_sum. + 2: erewrite Nat.gcd_0_r, big_sum_eq_bounded; [| intros j Hj; simpl; rewrite Cplus_0_l; + reflexivity]. + 1,2: rewrite big_sum_constant, Nat.add_comm, times_n_C1, + 1?plus_INR, ?S_INR; Csimpl; rewrite 2!RtoC_plus; lca. + rewrite (big_sum_eq_bounded _ (fun j => if (j*(S n) mod (S m) =? 0) then C1 else C0)). + 2: { + intros j Hj. + bdestruct_one. + - rewrite Nat.mod_divide in H by easy. + destruct H as [k Hk]. + apply big_sum_unique. + exists k; split; [nia | split; intros; bdestructΩ'simp]. + nia. + - rewrite big_sum_0_bounded; [easy|]. + intros i Hi. + bdestructΩ'simp. + rewrite H0, Nat.mod_mul in * by easy. + easy. + } + rewrite <- big_sum_extend_l. + rewrite plus_INR, RtoC_plus. + unfold C_is_monoid, Gplus. f_equal. + 1: rewrite Nat.mul_0_l, Nat.mod_0_l; easy. + pose proof (Nat.gcd_divide (S m) (S n)) as Hgcddiv. + (* Search (Nat.divide _ (_*_)). *) + assert (Hgcdnz: Nat.gcd (S m) (S n) <> O) by (rewrite Nat.gcd_eq_0; lia). + rewrite (big_sum_eq_bounded _ + (fun j => if ((S j) mod ((S m) / (Nat.gcd (S m) (S n))) =? 0) then C1 else C0)). + 2: { + intros j Hj. + enough (H: ((S j * S n) mod S m =? 0) = (S j mod (S m / Nat.gcd (S m) (S n)) =? 0)) + by (rewrite H; easy). + rewrite eq_iff_eq_true, 2!Nat.eqb_eq. + rewrite 2!Nat.mod_divide by (try easy; + rewrite Nat.div_small_iff by easy; + pose proof (gcd_le m n); lia). + split. + - intros [k Hk]. + exists (k * S m * Nat.gcd (S m) (S n) / (S m * S n))%nat. + rewrite <- Hk. + rewrite div_mul_combine; try easy. + 1: replace (S j * S n * Nat.gcd (S m) (S n) * S m)%nat with + (S j * (S m * S n * Nat.gcd (S m) (S n)))%nat by lia. + 1: rewrite Nat.div_mul by lia; easy. + rewrite Hk. + assert (Hgcd: Nat.gcd (S j * S n) (S n) = Nat.gcd (k * S m) (S n)) by + (rewrite Hk; easy). + assert (Hmul: (forall j', Nat.gcd (S j' * S n) (S n) = S n)%nat) by + (induction j'; [rewrite Nat.mul_1_l; apply Nat.gcd_diag| + rewrite Nat.mul_succ_l, Nat.gcd_comm, Nat.gcd_add_diag_r, + Nat.gcd_comm; apply IHj']). + rewrite Hmul in Hgcd. + rewrite (Nat.mul_comm k), <- Nat.mul_assoc. + rewrite Nat.mul_divide_cancel_l by easy. + rewrite Hgcd at 1. + exists (Nat.gcd (Nat.gcd (S m) k ) (S n) * (k/Nat.gcd k (S n)))%nat. + symmetry. + rewrite Nat.mul_comm, Nat.mul_assoc, (Nat.mul_comm (Nat.gcd (_*_) _)), gcd_prod_helper by lia. + rewrite <-Nat.mul_assoc, Nat.mul_comm. + f_equal. + rewrite <- Nat.divide_div_mul_exact, Nat.mul_comm, Nat.div_mul; try easy; + try apply Nat.gcd_divide; + rewrite Nat.gcd_eq_0; lia. + - intros [k Hk]. + rewrite Hk. + exists (k * (S n / Nat.gcd (S m) (S n)))%nat. + rewrite Nat.mul_comm, Nat.mul_assoc. + rewrite <- Nat.divide_div_mul_exact by easy. + symmetry. + rewrite Nat.mul_comm, Nat.mul_assoc. + rewrite <- Nat.divide_div_mul_exact by easy. + f_equal. + lia. + } + rewrite <- (Cplus_0_l (big_sum _ _)). + set (f:= fun j => if j mod (S m / Nat.gcd (S m) (S n)) =? 0 then C1 else 0). + + erewrite <- (Cminus_diag (f O) _ ltac:(reflexivity)) at 1. + unfold Cminus. + rewrite (Cplus_comm (_) (- _)), <- Cplus_assoc. + assert (Hdiv: (S m / Nat.gcd (S m) (S n) <> 0)%nat)by (rewrite Nat.div_small_iff by easy; + enough (Nat.gcd (S m) (S n) <= S m)%nat by lia; apply gcd_le). + pose proof (big_sum_extend_l (S m) f) as H. + unfold f at 2 in H. + unfold C_is_monoid, Gplus in H. + rewrite H. + set (g := fun (i j : nat) => if j =? 0 then C1 else 0). + rewrite (big_sum_eq_bounded f + (fun j => g (j / (S m / Nat.gcd (S m) (S n))) (j mod (S m / Nat.gcd (S m) (S n)))))%nat + by easy. + rewrite <- big_sum_extend_r. + replace (S m) with (S m / Nat.gcd (S m) (S n) * Nat.gcd (S m) (S n))%nat at 1. + rewrite <- big_sum_double_sum. + rewrite (big_sum_eq_bounded _ (fun _ => C1)). + 2: { + intros j Hj. + apply big_sum_unique. + exists O; split; [lia|unfold g; split; intros; bdestructΩ'simp]. + } + rewrite big_sum_constant, times_n_C1. + rewrite Cplus_assoc, (Cplus_comm (- _)), <- Cplus_assoc. + symmetry. + rewrite <- (Cplus_0_r (INR _)). + f_equal; try lca. + rewrite div_div by easy. + pose proof (div_divides (S m) (Nat.gcd (S m) (S n)) ltac:(easy) ltac:(easy)). + erewrite (proj2 (Nat.mod_divide _ _ _)); try easy. + unfold f, g; rewrite Nat.mod_0_l, Nat.eqb_refl; try lca; try easy. + Unshelve. + 2: easy. + rewrite Nat.mul_comm, <- Nat.divide_div_mul_exact, Nat.mul_comm, Nat.div_mul; easy. +Qed. *) + +Lemma f_to_vec_split : forall (f : nat -> bool) (m n : nat), + f_to_vec (m + n) f = f_to_vec m f ⊗ f_to_vec n (VectorStates.shift f m). +Proof. + intros f m n. + rewrite f_to_vec_merge. + apply f_to_vec_eq. + intros i Hi. + unfold VectorStates.shift. + bdestructΩ'. + f_equal; lia. +Qed. + +Lemma n_top_to_bottom_semantics_eq_kron_comm : forall n o, + ⟦ n_top_to_bottom n o ⟧ = kron_comm (2^o) (2^n). +Proof. + intros n o. + rewrite zxperm_permutation_semantics by auto with zxperm_db. + unfold zxperm_to_matrix. + rewrite perm_of_n_top_to_bottom. + apply equal_on_basis_states_implies_equal; auto with wf_db. + 1: { + rewrite Nat.add_comm, Nat.pow_add_r. + auto with wf_db. + } + intros f. + pose proof (perm_to_matrix_permutes_qubits (n + o) (rotr (n+o) n) f). + unfold perm_to_matrix in H. + rewrite H by auto with perm_db. + rewrite (f_to_vec_split f). + pose proof (kron_comm_commutes_vectors_l (2^o) (2^n) + (f_to_vec n f) (f_to_vec o (@VectorStates.shift bool f n)) + ltac:(auto with wf_db) ltac:(auto with wf_db)). + replace (2^(n+o))%nat with (2^o *2^n)%nat by (rewrite Nat.pow_add_r; lia). + simpl in H0. + rewrite H0. + rewrite Nat.add_comm, f_to_vec_split. + f_equal. + - apply f_to_vec_eq. + intros i Hi. + unfold VectorStates.shift. + f_equal; unfold rotr. + rewrite Nat.mod_small by lia. + bdestructΩ'. + - apply f_to_vec_eq. + intros i Hi. + unfold VectorStates.shift, rotr. + rewrite <- Nat.add_assoc, mod_add_n_r, Nat.mod_small by lia. + bdestructΩ'. +Qed. + +Lemma n_top_to_bottom_semantics_eq_kron_comm_mat_equiv : forall n o, + ⟦ n_top_to_bottom n o ⟧ ≡ kron_comm (2^o) (2^n). +Proof. + intros n o. + rewrite n_top_to_bottom_semantics_eq_kron_comm; easy. +Qed. + +Lemma compose_semantics' : +forall {n m o : nat} (zx0 : ZX n m) (zx1 : ZX m o), +@eq (Matrix (Nat.pow 2 o) (Nat.pow 2 n)) + (@ZX_semantics n o (@Compose n m o zx0 zx1)) + (@Mmult (Nat.pow 2 o) (Nat.pow 2 m) (Nat.pow 2 n) + (@ZX_semantics m o zx1) (@ZX_semantics n m zx0)). +Proof. + intros. + rewrite (@compose_semantics n m o). + easy. +Qed. diff --git a/examples/KronComm_orig.v b/examples/KronComm_orig.v new file mode 100644 index 0000000..aed18ba --- /dev/null +++ b/examples/KronComm_orig.v @@ -0,0 +1,1213 @@ +Require Import Setoid. + +From VyZX Require Import CoreData. +From VyZX Require Import CoreRules. +From VyZX Require Import PermutationRules. +From ViCaR Require Export CategoryTypeclass. + +Lemma Msum_transpose : forall n m p f, + (big_sum (G:=Matrix n m) f p) ⊤ = + big_sum (G:=Matrix n m) (fun i => (f i) ⊤) p. +Proof. + intros. + rewrite (big_sum_func_distr f transpose); easy. +Qed. + +Ltac print_state := + try (match goal with | H : ?p |- _ => idtac H ":" p; fail end); + idtac "---------------------------------------------------------"; + match goal with |- ?P => idtac P +end. + + +Ltac is_C0 x := + assert (x = C0) by lca. + +Ltac is_C1 x := + assert (x = C1) by lca. + +Ltac print_C x := + tryif is_C0 x then idtac "0" else + tryif is_C1 x then idtac "1" else idtac "X". + +Ltac print_LHS_matU := + intros; + (let i := fresh "i" in + let j := fresh "j" in + let Hi := fresh "Hi" in + let Hj := fresh "Hj" in + intros i j Hi Hj; try solve_end; + repeat + (destruct i as [| i]; [ | apply <- Nat.succ_lt_mono in Hi ]; + try solve_end); clear Hi; + repeat + (destruct j as [| j]; [ | apply <- Nat.succ_lt_mono in Hj ]; + try solve_end); clear Hj); + match goal with |- ?x = ?y ?i ?j => autounfold with U_db; simpl; + match goal with + | |- ?x = _ => idtac i; idtac j; print_C x; idtac "" + end +end. + +Definition kron_comm p q : Matrix (p*q) (p*q):= + @make_WF (p*q) (p*q) (fun s t => + (* have blocks H_ij, p by q of them, and each is q by p *) + let i := (s / q)%nat in let j := (t / p)%nat in + let k := (s mod q)%nat in let l := (t mod p) in + (* let k := (s - q * i)%nat in let l := (t - p * t)%nat in *) + if (i =? l) && (j =? k) then C1 else C0 + (* s/q =? t mod p /\ t/p =? s mod q *) +). + +Lemma WF_kron_comm p q : WF_Matrix (kron_comm p q). +Proof. unfold kron_comm; auto with wf_db. Qed. +#[export] Hint Resolve WF_kron_comm : wf_db. + +(* Lemma test_kron : kron_comm 2 3 = Matrix.Zero. +Proof. + apply mat_equiv_eq; unfold kron_comm; auto with wf_db. + print_LHS_matU. +*) + +Lemma kron_comm_transpose : forall p q, + (kron_comm p q) ⊤ = kron_comm q p. +Proof. + intros p q. + apply mat_equiv_eq; auto with wf_db. + 1: rewrite Nat.mul_comm; apply WF_kron_comm. + intros i j Hi Hj. + unfold kron_comm, transpose, make_WF. + rewrite andb_comm, Nat.mul_comm. + rewrite (andb_comm (_ =? _)). + easy. +Qed. + +Lemma kron_comm_1_r : forall p, + (kron_comm p 1) = Matrix.I p. +Proof. + intros p. + apply mat_equiv_eq; [|rewrite 1?Nat.mul_1_r|]; auto with wf_db. + intros s t Hs Ht. + unfold kron_comm. + unfold make_WF. + unfold Matrix.I. + rewrite Nat.mul_1_r, Nat.div_1_r, Nat.mod_1_r, Nat.div_small, Nat.mod_small by lia. + bdestructΩ'. +Qed. + +Lemma kron_comm_1_l : forall p, + (kron_comm 1 p) = Matrix.I p. +Proof. + intros p. + apply mat_equiv_eq; [|rewrite 1?Nat.mul_1_l|]; auto with wf_db. + intros s t Hs Ht. + unfold kron_comm. + unfold make_WF. + unfold Matrix.I. + rewrite Nat.mul_1_l, Nat.div_1_r, Nat.mod_1_r, Nat.div_small, Nat.mod_small by lia. + bdestructΩ'. +Qed. + +Definition mx_to_vec {n m} (A : Matrix n m) : Vector (n*m) := + make_WF (fun i j => A (i mod n)%nat (i / n)%nat + (* Note: goes columnwise. Rowwise would be: + make_WF (fun i j => A (i / m)%nat (i mod n)%nat + *) +). + +Lemma WF_mx_to_vec {n m} (A : Matrix n m) : WF_Matrix (mx_to_vec A). +Proof. unfold mx_to_vec; auto with wf_db. Qed. +#[export] Hint Resolve WF_mx_to_vec : wf_db. + +(* Compute vec_to_list (mx_to_vec (Matrix.I 2)). *) +From Coq Require Import ZArith. +Ltac Zify.zify_post_hook ::= PreOmega.Z.div_mod_to_equations. + +Lemma kron_comm_vec_helper : forall i p q, (i < p * q)%nat -> + (p * (i mod q) + i / q < p * q)%nat. +Proof. + intros i p q. + intros Hi. + assert (i / q < p)%nat by (apply Nat.div_lt_upper_bound; lia). + destruct p; [lia|]; + destruct q; [lia|]. + enough (S p * (i mod S q) <= S p * q)%nat by lia. + apply Nat.mul_le_mono; [lia | ]. + pose proof (Nat.mod_upper_bound i (S q) ltac:(easy)). + lia. +Qed. + +Lemma mx_to_vec_additive {n m} (A B : Matrix n m) : + mx_to_vec (A .+ B) = mx_to_vec A .+ mx_to_vec B. +Proof. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + replace j with O by lia; clear dependent j. + unfold mx_to_vec, make_WF, Mplus. + bdestructΩ'. +Qed. + +Lemma if_mult_dist_r (b : bool) (z : C) : + (if b then C1 else C0) * z = + if b then z else C0. +Proof. + destruct b; lca. +Qed. + +Lemma if_mult_dist_l (b : bool) (z : C) : + z * (if b then C1 else C0) = + if b then z else C0. +Proof. + destruct b; lca. +Qed. + +Lemma if_mult_and (b c : bool) : + (if b then C1 else C0) * (if c then C1 else C0) = + if (b && c) then C1 else C0. +Proof. + destruct b; destruct c; lca. +Qed. + +Lemma kron_comm_vec : forall p q (A : Matrix p q), + kron_comm p q × mx_to_vec A = mx_to_vec (A ⊤). +Proof. + intros p q A. + apply mat_equiv_eq; [|rewrite Nat.mul_comm|]; auto with wf_db. + intros i j Hi Hj. + replace j with O by lia; clear dependent j. + unfold transpose, mx_to_vec, kron_comm, make_WF, Mmult. + rewrite (Nat.mul_comm q p). + replace_bool_lia (i . + destruct p; [lia|]. + destruct q; [lia|]. + split. + + rewrite Nat.add_comm, Nat.mul_comm. + rewrite Nat.mod_add by easy. + rewrite Nat.mod_small; [lia|]. + apply Nat.div_lt_upper_bound; lia. + + rewrite Nat.mul_comm, Nat.div_add_l by easy. + rewrite Nat.div_small; [lia|]. + apply Nat.div_lt_upper_bound; lia. + - intros [Hmodp Hdivp]. + rewrite (Nat.div_mod_eq k p). + lia. + } + apply big_sum_unique. + exists (p * (i mod q) + i / q)%nat; repeat split; + [apply kron_comm_vec_helper; easy | rewrite Nat.eqb_refl | intros; bdestructΩ']. + destruct p; [lia|]; + destruct q; [lia|]. + f_equal. + - rewrite Nat.add_comm, Nat.mul_comm, Nat.mod_add, Nat.mod_small; try easy. + apply Nat.div_lt_upper_bound; lia. + - rewrite Nat.mul_comm, Nat.div_add_l by easy. + rewrite Nat.div_small; [lia|]. + apply Nat.div_lt_upper_bound; lia. +Qed. + +Lemma kron_comm_ei_kron_ei_sum : forall p q, + kron_comm p q = + big_sum (G:=Square (p*q)) (fun i => big_sum (G:=Square (p*q)) (fun j => + (@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤)) + q) p. +Proof. + intros p q. + apply mat_equiv_eq; auto with wf_db. + 1: apply WF_Msum; intros; apply WF_Msum; intros; + rewrite Nat.mul_comm; apply WF_mult; + auto with wf_db; rewrite Nat.mul_comm; + auto with wf_db. + intros i j Hi Hj. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros l Hl. + unfold Mmult, kron, transpose, e_i. + erewrite big_sum_eq_bounded. + 2: { + intros m Hm. + (* replace m with O by lia. *) + rewrite Nat.div_1_r, Nat.mod_1_r. + replace_bool_lia (m =? 0) true; rewrite 4!andb_true_r. + rewrite 3!if_mult_and. + match goal with + |- context[if ?b then _ else _] => + replace b with ((i =? k * q + l) && (j =? l * p + k)) + end. + 1: reflexivity. (* set our new function *) + clear dependent m. + rewrite eq_iff_eq_true, 8!andb_true_iff, + 6!Nat.eqb_eq, 4!Nat.ltb_lt. + split. + - intros [Hieq Hjeq]. + subst i j. + rewrite 2!Nat.div_add_l, Nat.div_small, Nat.add_0_r by lia. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small, + Nat.div_small, Nat.add_0_r by lia. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by lia. + easy. + - intros [[[] []] [[] []]]. + split. + + rewrite (Nat.div_mod_eq i q) by lia; lia. + + rewrite (Nat.div_mod_eq j p) by lia; lia. + } + simpl; rewrite Cplus_0_l. + reflexivity. + } + apply big_sum_unique. + exists (i mod q). + split; [|split]. + - apply Nat.mod_upper_bound; lia. + - reflexivity. + - intros l Hl Hnmod. + bdestructΩ'. + exfalso; apply Hnmod. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by lia; lia. + } + symmetry. + apply big_sum_unique. + exists (j mod p). + repeat split. + - apply Nat.mod_upper_bound; lia. + - unfold kron_comm, make_WF. + replace_bool_lia (i + enough (b = c) by bdestructΩ' + end. + rewrite eq_iff_eq_true, 2!andb_true_iff, 4!Nat.eqb_eq. + split. + + intros [Hieq Hjeq]. + split; [rewrite Hieq | rewrite Hjeq]; + rewrite Hieq, Nat.div_add_l by lia; + (rewrite Nat.div_small; [lia|]); + apply Nat.mod_upper_bound; lia. + + intros [Hidiv Hjdiv]. + rewrite (Nat.div_mod_eq i q) at 1 by lia. + rewrite (Nat.div_mod_eq j p) at 2 by lia. + lia. + - intros k Hk Hkmod. + bdestructΩ'. + exfalso; apply Hkmod. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by lia; lia. +Qed. + +Lemma kron_comm_ei_kron_ei_sum' : forall p q, + kron_comm p q = + big_sum (G:=Square (p*q)) (fun ij => + let i := (ij / q)%nat in let j := (ij mod q) in + ((@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤))) (p*q). +Proof. + intros p q. + rewrite kron_comm_ei_kron_ei_sum, big_sum_double_sum, Nat.mul_comm. + reflexivity. +Qed. + +Lemma div_eq_iff : forall a b c, b <> O -> + (a / b)%nat = c <-> (b * c <= a /\ a < b * (S c))%nat. +Proof. + intros a b c Hb. + split. + intros Hadivb. + split; + subst c. + etransitivity; [ + apply Nat.div_mul_le, Hb |]. + rewrite Nat.mul_comm, Nat.div_mul; easy. + apply Nat.mul_succ_div_gt, Hb. + intros [Hge Hlt]. + symmetry. + apply (Nat.div_unique _ _ _ (a - b*c)); [lia|]. + lia. +Qed. + +Lemma kron_e_i_transpose_l : forall k n m o (A : Matrix m o), (k < n)%nat -> + (o <> O) -> (m <> O) -> + (@e_i n k)⊤ ⊗ A = (fun i j => + if (i + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A = (fun i j => + if (j + (o <> O) -> (m <> O) -> + (@e_i n k)⊤ ⊗ A = (fun i j => + if (i ((j/o)%nat=k)) by lia; + rewrite Hrw; clear Hrw. + symmetry. + rewrite div_eq_iff by lia. + lia. + - replace (i / m =? 0) with false. + rewrite andb_false_r; easy. + symmetry. + rewrite Nat.eqb_neq. + rewrite Nat.div_small_iff; lia. +Qed. + +Lemma kron_e_i_l' : forall k n m o (A : Matrix m o), (k < n)%nat -> + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A = (fun i j => + if (j ((i/m)%nat=k)) by lia; + rewrite Hrw; clear Hrw. + symmetry. + rewrite div_eq_iff by lia. + lia. + - replace (j / o =? 0) with false. + rewrite andb_false_r; easy. + symmetry. + rewrite Nat.eqb_neq. + rewrite Nat.div_small_iff; lia. +Qed. + +Lemma kron_e_i_r : forall k n m o (A : Matrix m o), (k < n)%nat -> + (o <> O) -> (m <> O) -> + A ⊗ (@e_i n k) = (fun i j => + if (i mod n =? k) then A (i / n)%nat j else 0). +Proof. + intros k n m o A Hk Ho Hm. + apply functional_extensionality; intros i; + apply functional_extensionality; intros j. + unfold kron, e_i. + rewrite if_mult_dist_l, Nat.div_1_r. + rewrite Nat.mod_1_r, Nat.eqb_refl, andb_true_r. + replace (i mod n + (o <> O) -> (m <> O) -> + A ⊗ (@e_i n k) ⊤ = (fun i j => + if (j mod n =? k) then A i (j / n)%nat else 0). +Proof. + intros k n m o A Hk Ho Hm. + apply functional_extensionality; intros i; + apply functional_extensionality; intros j. + unfold kron, transpose, e_i. + rewrite if_mult_dist_l, Nat.div_1_r. + rewrite Nat.mod_1_r, Nat.eqb_refl, andb_true_r. + replace (j mod n m <> O -> + (@e_i n k) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n k) = + (fun i j => if (i mod n =? k) && (j / m =? k)%nat + && (i / n =? j - k * m) && (i / n + (@e_i n j) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n j)) n. +Proof. + intros m n. + apply mat_equiv_eq; auto with wf_db. + 1: apply WF_Msum; intros; apply WF_kron; auto with wf_db arith. + intros i j Hi Hj. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros ij Hij. + rewrite ei_kron_I_kron_ei by lia. + reflexivity. + } + unfold kron_comm, make_WF. + replace_bool_lia (i + (@e_i m i) ⊗ (Matrix.I n) ⊗ (@e_i m i)⊤) m. +Proof. + intros. + rewrite <- (kron_comm_transpose n m). + rewrite (kron_comm_kron_form_sum n m). + rewrite Msum_transpose. + apply big_sum_eq_bounded. + intros k Hk. + rewrite Nat.mul_1_l. + pose proof (kron_transpose _ _ _ _ ((@e_i m k) ⊤ ⊗ Matrix.I n) (@e_i m k)) as H; + rewrite Nat.mul_1_l, Nat.mul_1_r in H; + rewrite (Nat.mul_comm n m), H in *; clear H. + pose proof (kron_transpose _ _ _ _ ((@e_i m k) ⊤) (Matrix.I n)) as H; + rewrite Nat.mul_1_l in H; + rewrite H; clear H. + rewrite transpose_involutive, id_transpose_eq; easy. +Qed. + +Lemma e_i_dot_is_component : forall p k (x : Vector p), + (k < p)%nat -> WF_Matrix x -> + (@e_i p k) ⊤ × x = x k O .* Matrix.I 1. +Proof. + intros p k x Hk HWF. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj; + replace i with O by lia; + replace j with O by lia; + clear i Hi; + clear j Hj. + unfold Mmult, transpose, scale, e_i, Matrix.I. + bdestructΩ'. + rewrite Cmult_1_r. + apply big_sum_unique. + exists k. + split; [easy|]. + bdestructΩ'. + rewrite Cmult_1_l. + split; [easy|]. + intros l Hl Hkl. + bdestructΩ'; lca. +Qed. + +Lemma kron_e_i_e_i : forall p q k l, + (k < p)%nat -> (l < q)%nat -> + @e_i q l ⊗ @e_i p k = @e_i (p*q) (l*p + k). +Proof. + intros p q k l Hk Hl. + apply functional_extensionality; intro i. + apply functional_extensionality; intro j. + unfold kron, e_i. + rewrite Nat.mod_1_r, Nat.div_1_r. + rewrite if_mult_and. + lazymatch goal with + |- (if ?b then _ else _) = (if ?c then _ else _) => + enough (H : b = c) by (rewrite H; easy) + end. + rewrite Nat.eqb_refl, andb_true_r. + destruct (j =? 0); [|rewrite 2!andb_false_r; easy]. + rewrite 2!andb_true_r. + rewrite eq_iff_eq_true, 4!andb_true_iff, 3!Nat.eqb_eq, 3!Nat.ltb_lt. + split. + - intros [[] []]. + rewrite (Nat.div_mod_eq i p). + split; nia. + - intros []. + subst i. + rewrite Nat.div_add_l, Nat.div_small, Nat.add_0_r, + Nat.add_comm, Nat.mod_add, Nat.mod_small by lia. + easy. +Qed. + +Lemma kron_eq_sum : forall p q (x : Vector q) (y : Vector p), + WF_Matrix x -> WF_Matrix y -> + y ⊗ x = big_sum (fun ij => + let i := (ij / q)%nat in let j := ij mod q in + (x j O * y i O) .* (@e_i p i ⊗ @e_i q j)) (p * q). +Proof. + intros p q x y Hwfx Hwfy. + + erewrite big_sum_eq_bounded. + 2: { + intros ij Hij. + simpl. + rewrite (@kron_e_i_e_i q p) by + (try apply Nat.mod_upper_bound; try apply Nat.div_lt_upper_bound; lia). + rewrite (Nat.mul_comm (ij / q) q). + rewrite <- (Nat.div_mod_eq ij q). + reflexivity. + } + apply mat_equiv_eq; [|rewrite Nat.mul_comm|]; auto with wf_db. + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + simpl. + rewrite Msum_Csum. + symmetry. + apply big_sum_unique. + exists i. + split; [lia|]. + unfold e_i; split. + - unfold scale, kron; bdestructΩ'. + rewrite Cmult_1_r, Cmult_comm. + easy. + - intros j Hj Hij. + unfold scale, kron; bdestructΩ'. + rewrite Cmult_0_r. + easy. +Qed. + +Lemma kron_comm_commutes_vectors_l : forall p q (x : Vector q) (y : Vector p), + WF_Matrix x -> WF_Matrix y -> + kron_comm p q × (x ⊗ y) = (y ⊗ x). +Proof. + intros p q x y Hwfx Hwfy. + rewrite kron_comm_ei_kron_ei_sum', Mmult_Msum_distr_r. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + simpl. + match goal with + |- (?A × ?B) × ?C = _ => + assert (Hassoc: (A × B) × C = A × (B × C)) by apply Mmult_assoc + end. + simpl in Hassoc. + rewrite (Nat.mul_comm q p) in *. + rewrite Hassoc. clear Hassoc. + pose proof (kron_transpose _ _ _ _ (@e_i q (k mod q)) (@e_i p (k / q))) as Hrw; + rewrite (Nat.mul_comm q p) in Hrw; + simpl in Hrw; rewrite Hrw; clear Hrw. + pose proof (kron_mixed_product ((e_i (k mod q)) ⊤) ((e_i (k / q)) ⊤) x y) as Hrw; + rewrite (Nat.mul_comm q p) in Hrw; + simpl in Hrw; rewrite Hrw; clear Hrw. + rewrite 2!e_i_dot_is_component; [| + apply Nat.div_lt_upper_bound; lia | + easy | + apply Nat.mod_upper_bound; lia | + easy]. + rewrite Mscale_kron_dist_l, Mscale_kron_dist_r, Mscale_assoc. + rewrite kron_1_l, Mscale_mult_dist_r, Mmult_1_r by auto with wf_db. + reflexivity. + } + rewrite <- kron_eq_sum; easy. +Qed. + +Lemma kron_basis_vector_basis_vector : forall p q k l, + (k < p)%nat -> (l < q)%nat -> + basis_vector q l ⊗ basis_vector p k = basis_vector (p*q) (l*p + k). +Proof. + intros p q k l Hk Hl. + apply functional_extensionality; intros i. + apply functional_extensionality; intros j. + unfold kron, basis_vector. + rewrite Nat.mod_1_r, Nat.div_1_r, Nat.eqb_refl, andb_true_r, if_mult_and. + bdestructΩ'; + try pose proof (Nat.div_mod_eq i p); + try nia. + rewrite Nat.div_add_l, Nat.div_small in * by lia. + lia. +Qed. + +Lemma kron_extensionality : forall n m s t (A B : Matrix (n*m) (s*t)), + WF_Matrix A -> WF_Matrix B -> + (forall (x : Vector s) (y :Vector t), + WF_Matrix x -> WF_Matrix y -> + A × (x ⊗ y) = B × (x ⊗ y)) -> + A = B. +Proof. + intros b n s t A B HwfA HwfB Hext. + apply equal_on_basis_vectors_implies_equal; try easy. + intros i Hi. + + pose proof (Nat.div_lt_upper_bound i t s ltac:(lia) ltac:(lia)). + pose proof (Nat.mod_upper_bound i s ltac:(lia)). + pose proof (Nat.mod_upper_bound i t ltac:(lia)). + + specialize (Hext (basis_vector s (i / t)) (basis_vector t (i mod t)) + ltac:(apply basis_vector_WF; easy) + ltac:(apply basis_vector_WF; easy) + ). + rewrite (kron_basis_vector_basis_vector t s) in Hext by lia. + + simpl in Hext. + rewrite (Nat.mul_comm (i/t) t), <- (Nat.div_mod_eq i t) in Hext. + rewrite (Nat.mul_comm t s) in Hext. easy. +Qed. + +Lemma kron_comm_commutes : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + kron_comm m n × (A ⊗ B) × (kron_comm s t) = (B ⊗ A). +Proof. + intros n s m t A B HwfA HwfB. + apply (kron_extensionality _ _ t s); [| + apply WF_kron; try easy; lia |]. + rewrite (Nat.mul_comm t s); apply WF_mult; auto with wf_db; + apply WF_mult; auto with wf_db; + rewrite Nat.mul_comm; auto with wf_db. + (* rewrite Nat.mul_comm; apply WF_mult; [rewrite Nat.mul_comm|auto with wf_db]; + apply WF_mult; auto with wf_db; rewrite Nat.mul_comm; auto with wf_db. *) + intros x y Hwfx Hwfy. + (* simpl. *) + (* Search "assoc" in Matrix. *) + rewrite (Nat.mul_comm s t). + rewrite (Mmult_assoc (_ × _)). + rewrite (Nat.mul_comm t s). + rewrite kron_comm_commutes_vectors_l by easy. + rewrite Mmult_assoc, (Nat.mul_comm m n). + rewrite kron_mixed_product. + rewrite (Nat.mul_comm n m), kron_comm_commutes_vectors_l by (auto with wf_db). + rewrite <- kron_mixed_product. + f_equal; lia. +Qed. + +Lemma commute_kron : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + (A ⊗ B) = kron_comm n m × (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B HA HB. + rewrite (kron_comm_commutes m t n s B A HB HA); easy. +Qed. + +#[export] Hint Extern 4 (WF_Matrix (@Mmult ?m ?n ?o ?A ?B)) => ( + match type of A with Matrix ?m' ?n' => + match type of B with Matrix ?n'' ?o'' => + let Hm' := fresh "Hm'" in let Hn' := fresh "Hn'" in + let Hn'' := fresh "Hn''" in let Ho'' := fresh "Hoo'" in + assert (Hm' : m = m') by lia; + assert (Hn' : n = n') by lia; + assert (Hn'' : n = n'') by lia; + assert (Ho' : o = o'') by lia; + replace (@Mmult m n o A B) with (@Mmult m' n' o A B) + by (first [try (rewrite Hm' at 1); try (rewrite Hn' at 1); reflexivity | f_equal; lia]); + apply WF_mult; [ + auto with wf_db | + apply WF_Matrix_dim_change; + auto with wf_db + ] + end end) : wf_db. + +Lemma kron_comm_mul_inv : forall p q, + kron_comm p q × kron_comm q p = Matrix.I _. +Proof. + intros p q. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + unfold Mmult, kron_comm, make_WF. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + rewrite <- 2!andb_if, if_mult_and. + replace_bool_lia (k + replace b with ((i =? j) && (k =? (i mod q) * p + (j/q))) + end; + [reflexivity|]. + rewrite eq_iff_eq_true, 4!andb_true_iff, 6!Nat.eqb_eq. + split. + - intros [? ?]; subst. + destruct p; [easy|destruct q;[lia|]]. + assert (j / S q < S p)%nat by (apply Nat.div_lt_upper_bound; lia). + rewrite Nat.div_add_l, (Nat.div_small (j / (S q))), Nat.add_0_r by easy. + rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by easy. + easy. + - intros [[Hiqkp Hkpiq] [Hkpjq Hjqkp]]. + split. + + rewrite (Nat.div_mod_eq i q), (Nat.div_mod_eq j q). + lia. + + rewrite (Nat.div_mod_eq k p). + lia. + } + bdestruct (i =? j). + - subst. + apply big_sum_unique. + exists ((j mod q) * p + (j/q))%nat. + split; [|split]. + + rewrite Nat.mul_comm. apply kron_comm_vec_helper; easy. + + unfold Matrix.I. + rewrite Nat.eqb_refl; bdestructΩ'. + + intros; bdestructΩ'. + - unfold Matrix.I. + replace_bool_lia (i =? j) false. + rewrite andb_false_l. + rewrite big_sum_0; [easy|]. + intros; rewrite andb_false_l; easy. +Qed. + +Lemma kron_comm_mul_transpose_r : forall p q, + kron_comm p q × (kron_comm p q) ⊤ = Matrix.I _. +Proof. + intros p q. + rewrite (kron_comm_transpose p q). + apply kron_comm_mul_inv. +Qed. + +Lemma kron_comm_mul_transpose_l : forall p q, + (kron_comm p q) ⊤ × kron_comm p q = Matrix.I _. +Proof. + intros p q. + rewrite <- (kron_comm_transpose q p). + rewrite (Nat.mul_comm p q). + rewrite (transpose_involutive _ _ (kron_comm q p)). + apply kron_comm_mul_transpose_r. +Qed. + +Lemma kron_comm_commutes_l : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + kron_comm m n × (A ⊗ B) = (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B HwfA HwfB. + match goal with |- ?A = ?B => + rewrite <- (Mmult_1_r _ _ A), <- (Mmult_1_r _ _ B) ; auto with wf_db + end. + rewrite (Nat.mul_comm t s). + rewrite <- (kron_comm_mul_transpose_r), <- 2!Mmult_assoc. + rewrite (kron_comm_commutes n s m t) by easy. + apply Mmult_simplify; [|easy]. + rewrite Mmult_assoc. + rewrite (Nat.mul_comm s t), (kron_comm_mul_inv t s), Mmult_1_r by auto with wf_db. + easy. +Qed. + +Lemma kron_comm_commutes_r : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + (A ⊗ B) × kron_comm s t = (kron_comm n m) × (B ⊗ A). +Proof. + intros n s m t A B HA HB. + rewrite kron_comm_commutes_l; easy. +Qed. + + + +(* Lemma kron_comm_commutes_r : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + kron_comm m n × (A ⊗ B) = (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B HwfA HwfB. + match goal with |- ?A = ?B => + rewrite <- (Mmult_1_r _ _ A), <- (Mmult_1_r _ _ B) ; auto with wf_db + end. + rewrite (Nat.mul_comm t s). + rewrite <- (kron_comm_mul_transpose_r), <- 2!Mmult_assoc. + rewrite (kron_comm_commutes n s m t) by easy. + apply Mmult_simplify; [|easy]. + rewrite Mmult_assoc. + rewrite (Nat.mul_comm s t), (kron_comm_mul_inv t s), Mmult_1_r by auto with wf_db. + easy. +Qed. *) + + + + +Lemma vector_eq_basis_comb : forall n (y : Vector n), + WF_Matrix y -> + y = big_sum (G:=Vector n) (fun i => y i O .* @e_i n i) n. +Proof. + intros n y Hwfy. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + symmetry. + rewrite Msum_Csum. + apply big_sum_unique. + exists i. + repeat split; try easy. + - unfold ".*", e_i; bdestructΩ'; lca. + - intros l Hl Hnk. + unfold ".*", e_i; bdestructΩ'; lca. +Qed. + +Lemma kron_vecT_matrix_vec : forall m n o p + (P : Matrix m o) (y : Vector n) (z : Vector p), + WF_Matrix y -> WF_Matrix z -> WF_Matrix P -> + (z⊤) ⊗ P ⊗ y = @Mmult (m*n) (m*n) (o*p) (kron_comm m n) ((y × (z⊤)) ⊗ P). +Proof. + intros m n o p P y z Hwfy Hwfz HwfP. + match goal with |- ?A = ?B => + rewrite <- (Mmult_1_l _ _ A) ; auto with wf_db + end. + rewrite Nat.mul_1_l. + rewrite <- (kron_comm_mul_transpose_r), Mmult_assoc at 1. + rewrite Nat.mul_1_r, (Nat.mul_comm o p). + apply Mmult_simplify; [easy|]. + rewrite kron_comm_kron_form_sum. + rewrite Msum_transpose. + rewrite Mmult_Msum_distr_r. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + pose proof (kron_transpose _ _ _ _ ((@e_i n k) ⊤ ⊗ Matrix.I m) (@e_i n k)) as H; + rewrite Nat.mul_1_l, Nat.mul_1_r, (Nat.mul_comm m n) in *; + rewrite H; clear H. + pose proof (kron_transpose _ _ _ _ ((@e_i n k) ⊤) (Matrix.I m)) as H; + rewrite Nat.mul_1_l in *; + rewrite H; clear H. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite id_transpose_eq, Mmult_1_l by easy. + rewrite e_i_dot_is_component, transpose_involutive by easy. + (* rewrite <- Mmult_transpose. *) + rewrite Mscale_kron_dist_r, <- 2!Mscale_kron_dist_l. + rewrite kron_1_r. + rewrite <- Mscale_mult_dist_l. + reflexivity. + } + rewrite <- (kron_Msum_distr_r n _ P). + rewrite <- (Mmult_Msum_distr_r). + rewrite <- vector_eq_basis_comb by easy. + easy. +Qed. + +Lemma kron_vec_matrix_vecT : forall m n o p + (Q : Matrix n o) (x : Vector m) (z : Vector p), + WF_Matrix x -> WF_Matrix z -> WF_Matrix Q -> + x ⊗ Q ⊗ (z⊤) = @Mmult (m*n) (m*n) (o*p) (kron_comm m n) (Q ⊗ (x × z⊤)). +Proof. + intros m n o p Q x z Hwfx Hwfz HwfQ. + match goal with |- ?A = ?B => + rewrite <- (Mmult_1_l _ _ A) ; auto with wf_db + end. + rewrite Nat.mul_1_r. + rewrite <- (kron_comm_mul_transpose_r), Mmult_assoc at 1. + rewrite Nat.mul_1_l. + apply Mmult_simplify; [easy|]. + rewrite kron_comm_kron_form_sum'. + rewrite Msum_transpose. + rewrite Mmult_Msum_distr_r. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + pose proof (kron_transpose _ _ _ _ ((@e_i m k) ⊗ Matrix.I n) ((@e_i m k) ⊤)) as H; + rewrite Nat.mul_1_l, Nat.mul_1_r, (Nat.mul_comm m n) in *; + rewrite H; clear H. + pose proof (kron_transpose _ _ _ _ ((@e_i m k)) (Matrix.I n)) as H; + rewrite Nat.mul_1_l, (Nat.mul_comm m n) in *; + rewrite H; clear H. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite id_transpose_eq, Mmult_1_l by easy. + rewrite e_i_dot_is_component, transpose_involutive by easy. + (* rewrite <- Mmult_transpose. *) + rewrite 2!Mscale_kron_dist_l, kron_1_l, <-Mscale_kron_dist_r by easy. + rewrite <- Mscale_mult_dist_l. + restore_dims. + reflexivity. + } + rewrite <- (kron_Msum_distr_l m _ Q). + rewrite <- (Mmult_Msum_distr_r). + rewrite <- vector_eq_basis_comb by easy. + easy. +Qed. + +Lemma kron_comm_triple_cycle : forall m n s t p q (A : Matrix m n) + (B : Matrix s t) (C : Matrix p q), WF_Matrix A -> WF_Matrix B -> WF_Matrix C -> + A ⊗ B ⊗ C = (kron_comm (m*s) p) × (C ⊗ A ⊗ B) × (kron_comm q (t*n)). +Proof. + intros m n s t p q A B C HA HB HC. + rewrite (commute_kron _ _ _ _ (A ⊗ B) C) by auto with wf_db. + rewrite kron_assoc by easy. + f_equal; try lia; f_equal; lia. +Qed. + +Lemma kron_comm_triple_cycle2 : forall m n s t p q (A : Matrix m n) + (B : Matrix s t) (C : Matrix p q), WF_Matrix A -> WF_Matrix B -> WF_Matrix C -> + A ⊗ B ⊗ C = (kron_comm m (s*p)) × (B ⊗ C ⊗ A) × (kron_comm (q*t) n). +Proof. + intros m n s t p q A B C HA HB HC. + rewrite kron_assoc by easy. + rewrite (commute_kron _ _ _ _ A (B ⊗ C)) by auto with wf_db. + f_equal; try lia; f_equal; lia. +Qed. + +Lemma id_eq_sum_kron_e_is : forall n, + Matrix.I n = big_sum (G:=Square n) (fun i => @e_i n i ⊗ (@e_i n i) ⊤) n. +Proof. + intros n. + symmetry. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + rewrite kron_e_i_l by lia. + unfold transpose, e_i. + rewrite <- andb_if. + replace_bool_lia (j rewrite (commute_kron _ _ _ _ A B) by auto with wf_db + end. + (* restore_dims. *) + reflexivity. + } + (* rewrite ?Nat.mul_1_r, ?Nat.mul_1_l. *) + (* rewrite <- Mmult_Msum_distr_r. *) + + rewrite <- (Mmult_Msum_distr_r n _ (kron_comm (t*1) (n*s))). + rewrite <- Mmult_Msum_distr_l. + erewrite big_sum_eq_bounded. + 2: { + intros j Hj. + rewrite <- kron_assoc, (kron_assoc (Matrix.I t)) by auto with wf_db. + restore_dims. + reflexivity. + } + (* rewrite Nat.mul_1_l *) + rewrite <- (kron_Msum_distr_r n _ (Matrix.I s)). + rewrite <- (kron_Msum_distr_l n _ (Matrix.I t)). + rewrite 2!Nat.mul_1_r, 2!Nat.mul_1_l. + rewrite <- (id_eq_sum_kron_e_is n). + rewrite 2!id_kron. + restore_dims. + rewrite Mmult_1_r by auto with wf_db. + rewrite (Nat.mul_comm t n), (Nat.mul_comm n s). + easy. +Qed. + +Lemma kron_comm_cycle_indices_rev : forall t s n, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) = kron_comm (t*s) n. +Proof. + intros. + rewrite <- kron_comm_cycle_indices. + easy. +Qed. + + +Lemma kron_comm_triple_id : forall t s n, + (kron_comm (t*s) n) × (kron_comm (s*n) t) × (kron_comm (n*t) s) = Matrix.I (t*s*n). +Proof. + intros t s n. + rewrite kron_comm_cycle_indices. + restore_dims. + rewrite (Mmult_assoc (kron_comm s (n*t))). + restore_dims. + rewrite (Nat.mul_comm (s*n) t). (* TODO: Fix kron_comm_mul_inv to have the + right dimensions by default (or, better yet, to be ambivalent) *) + rewrite (kron_comm_mul_inv t (s*n)). + restore_dims. + rewrite Mmult_1_r by auto with wf_db. + rewrite (Nat.mul_comm (n*t) s). + rewrite (kron_comm_mul_inv). + f_equal; lia. +Qed. + +Lemma kron_comm_triple_id' : forall t s n, + (kron_comm n (t*s)) × (kron_comm t (s*n)) × (kron_comm s (n*t)) = Matrix.I (t*s*n). +Proof. + intros t s n. + apply transpose_matrices. + rewrite 2!Mmult_transpose. + (* restore_dims. *) + rewrite (kron_comm_transpose s (n*t)). + restore_dims. + rewrite (kron_comm_transpose n (t*s)). + restore_dims. + replace (n*(t*s))%nat with (t*(s*n))%nat by lia. + replace (s*(n*t))%nat with (t*(s*n))%nat by lia. + rewrite (kron_comm_transpose t (s*n)). + restore_dims. + rewrite Nat.mul_assoc, id_transpose_eq. + restore_dims. + (* rewrite (kron_comm_triple_id n t s). *) + Admitted. + +Lemma kron_comm_triple_indices_commute : forall t s n, + @Mmult (s*t*n) (s*t*n) (t*(s*n)) (kron_comm (s*t) n) (kron_comm t (s*n)) = + @Mmult (t*(s*n)) (t*(s*n)) (s*t*n) (kron_comm t (s*n)) (kron_comm (s*t) n). +Proof. + intros t s n. + rewrite <- (kron_comm_cycle_indices_rev s t n). + Admitted. + (* rewrite kron_comm_triple_id. +rewrite <- (kron_comm_cycle_indices t s n). *) + + +Lemma f_to_vec_split : forall (f : nat -> bool) (m n : nat), + f_to_vec (m + n) f = f_to_vec m f ⊗ f_to_vec n (VectorStates.shift f m). +Proof. + intros f m n. + rewrite f_to_vec_merge. + apply f_to_vec_eq. + intros i Hi. + unfold VectorStates.shift. + bdestructΩ'. + f_equal; lia. +Qed. + +Lemma n_top_to_bottom_semantics_eq_kron_comm : forall n o, + ⟦ n_top_to_bottom n o ⟧ = kron_comm (2^o) (2^n). +Proof. + intros n o. + rewrite zxperm_permutation_semantics by auto with zxperm_db. + unfold zxperm_to_matrix. + rewrite perm_of_n_top_to_bottom. + apply equal_on_basis_states_implies_equal; auto with wf_db. + 1: { + rewrite Nat.add_comm, Nat.pow_add_r. + auto with wf_db. + } + intros f. + pose proof (perm_to_matrix_permutes_qubits (n + o) (rotr (n+o) n) f). + unfold perm_to_matrix in H. + rewrite H by auto with perm_db. + rewrite (f_to_vec_split f). + pose proof (kron_comm_commutes_vectors_l (2^o) (2^n) + (f_to_vec n f) (f_to_vec o (@VectorStates.shift bool f n)) + ltac:(auto with wf_db) ltac:(auto with wf_db)). + replace (2^(n+o))%nat with (2^o *2^n)%nat by (rewrite Nat.pow_add_r; lia). + simpl in H0. + rewrite H0. + rewrite Nat.add_comm, f_to_vec_split. + f_equal. + - apply f_to_vec_eq. + intros i Hi. + unfold VectorStates.shift. + f_equal; unfold rotr. + rewrite Nat.mod_small by lia. + bdestructΩ'. + - apply f_to_vec_eq. + intros i Hi. + unfold VectorStates.shift, rotr. + rewrite <- Nat.add_assoc, mod_add_n_r, Nat.mod_small by lia. + bdestructΩ'. +Qed. + + +Lemma compose_semantics' : +forall {n m o : nat} (zx0 : ZX n m) (zx1 : ZX m o), +@eq (Matrix (Nat.pow 2 o) (Nat.pow 2 n)) + (@ZX_semantics n o (@Compose n m o zx0 zx1)) + (@Mmult (Nat.pow 2 o) (Nat.pow 2 m) (Nat.pow 2 n) + (@ZX_semantics m o zx1) (@ZX_semantics n m zx0)). +Proof. + intros. + rewrite (@compose_semantics n m o). + easy. +Qed. diff --git a/examples/KronMatrixExample.v b/examples/KronMatrixExample.v new file mode 100644 index 0000000..a61658e --- /dev/null +++ b/examples/KronMatrixExample.v @@ -0,0 +1,103 @@ + +Require Import MatrixPermBase. +Require Import KronComm. +Require Export MatrixExampleBase. +From ViCaR Require Import ExamplesAutomation. + + +#[export] Instance MxCategory : Category nat := { + morphism := Matrix; + + equiv := @mat_equiv; (* fun m n => @eq (Matrix m n); *) + equiv_rel := @mat_equiv_equivalence; + + compose := @Mmult; + compose_compat := fun n m o f g Hfg h i Hhi => + @Mmult_simplify_mat_equiv n m o f g h i Hfg Hhi; + assoc := @Mmult_assoc_mat_equiv; + + c_identity n := I n; + left_unit := Mmult_1_l_mat_eq; + right_unit := Mmult_1_r_mat_eq; +}. + + +Definition MxKronBiFunctor : Bifunctor MxCategory MxCategory MxCategory := {| + obj2_map := Nat.mul; + morphism2_map := @kron; + id2_map := ltac:(intros; rewrite id_kron; easy); + compose2_map := ltac:(intros; rewrite kron_mixed_product; easy); + morphism2_compat := ltac:(intros; apply kron_mat_equiv_morph; easy); +|}. + + + +#[export] Instance MxKronMonoidalCategory : MonoidalCategory nat := { + tensor := MxKronBiFunctor; + I := 1; + + associator := fun n m o => {| + forward := (I (n * m * o) : Matrix (n * m * o) (n * (m * o))); + reverse := (I (n * m * o) : Matrix (n * (m * o)) (n * m * o)); + id_A := ltac:(simpl; rewrite Nat.mul_assoc, Mmult_1_r_mat_eq; easy); + id_B := ltac:(simpl; rewrite Nat.mul_assoc, Mmult_1_r_mat_eq; easy); + |}; + + left_unitor := fun n => {| + forward := (I n : Matrix (1 * n) n); + reverse := (I n : Matrix n (1 * n)); + id_A := ltac:(rewrite Nat.mul_1_l, Mmult_1_r_mat_eq; easy); + id_B := ltac:(rewrite Nat.mul_1_l, Mmult_1_r_mat_eq; easy); + |}; + + right_unitor := fun n => {| + forward := (I n : Matrix (n * 1) n); + reverse := (I n : Matrix n (n * 1)); + id_A := ltac:(rewrite Nat.mul_1_r, Mmult_1_r_mat_eq; easy); + id_B := ltac:(rewrite Nat.mul_1_r, Mmult_1_r_mat_eq; easy); + |}; + + associator_cohere := ltac:(intros; simpl in *; + rewrite kron_assoc_mat_equiv; + rewrite 2!Nat.mul_assoc, Mmult_1_r_mat_eq, Mmult_1_l_mat_eq; + easy + ); + left_unitor_cohere := ltac:(intros; cbn; + rewrite kron_1_l_mat_equiv, 2!Nat.add_0_r, + Mmult_1_l_mat_eq, Mmult_1_r_mat_eq; easy); + right_unitor_cohere := ltac:(intros; cbn; + rewrite kron_1_r_mat_equiv, 2!Nat.mul_1_r, + Mmult_1_l_mat_eq, Mmult_1_r_mat_eq; easy); + + pentagon := ltac:(intros; simpl in *; + rewrite ?Nat.mul_assoc, 2!id_kron, Mmult_1_l_mat_eq; + rewrite ?Nat.mul_assoc, Mmult_1_l_mat_eq; + easy + ); + triangle := ltac:(intros; + cbn; + rewrite Nat.mul_1_r, Nat.add_0_r in *; + rewrite Mmult_1_l_mat_eq; + easy + ); +}. + +Definition MxKronBraidingIsomorphism : forall n m, + Isomorphism (MxKronBiFunctor n m) ((CommuteBifunctor MxKronBiFunctor) n m) := + fun n m => Build_Isomorphism nat MxCategory (n*m)%nat (m*n)%nat + (kron_comm n m) (kron_comm m n) + ltac:(intros; simpl; + rewrite (Nat.mul_comm m n), (kron_comm_mul_inv n m); easy) + ltac:(intros; simpl; + rewrite (Nat.mul_comm n m), (kron_comm_mul_inv m n); easy). + + + +#[export] Instance MxKronBraidingBiIsomorphism : + NaturalBiIsomorphism MxKronBiFunctor (CommuteBifunctor MxKronBiFunctor) := {| + component2_iso := MxKronBraidingIsomorphism; + component2_iso_natural := ltac:(intros; simpl in *; + rewrite (Nat.mul_comm B2 B1), (Nat.mul_comm A2 A1); + rewrite (kron_comm_commutes_r_mat_equiv); + easy); +|}. \ No newline at end of file diff --git a/examples/MatrixExampleBase.v b/examples/MatrixExampleBase.v new file mode 100644 index 0000000..cd7935e --- /dev/null +++ b/examples/MatrixExampleBase.v @@ -0,0 +1,263 @@ +Require Export Setoid. +Require Export Morphisms. + +From Coq Require Export ZArith. +Ltac Zify.zify_post_hook ::= PreOmega.Z.div_mod_to_equations. + +From VyZX Require Export PermutationAutomation PermutationFacts PermutationInstances. +From ViCaR Require Export CategoryTypeclass. +From QuantumLib Require Export Matrix. +From ViCaR Require Import ExamplesAutomation. + +Open Scope matrix_scope. + +Lemma mat_equiv_sym : forall {n m : nat} (A B : Matrix n m), + A ≡ B -> B ≡ A. +Proof. + intros n m A B HAB i j Hi Hj. + rewrite HAB by easy. + easy. +Qed. + +Lemma mat_equiv_trans : forall {n m : nat} (A B C : Matrix n m), + A ≡ B -> B ≡ C -> A ≡ C. +Proof. + intros n m A B C HAB HBC i j Hi Hj. + rewrite HAB, HBC by easy. + easy. +Qed. + +Add Parametric Relation {n m} : (Matrix n m) mat_equiv + reflexivity proved by (mat_equiv_refl _ _) + symmetry proved by (mat_equiv_sym) + transitivity proved by (mat_equiv_trans) + as mat_equiv_rel. + +Lemma mat_equiv_eq_iff {n m} : forall (A B : Matrix n m), + WF_Matrix A -> WF_Matrix B -> A ≡ B <-> A = B. +Proof. + intros; split; try apply mat_equiv_eq; + intros; try subst A; easy. +Qed. + +Lemma Mmult_simplify_mat_equiv : forall {n m o} + (A B : Matrix n m) (C D : Matrix m o), + A ≡ B -> C ≡ D -> A × C ≡ B × D. +Proof. + intros n m o A B C D HAB HCD. + intros i j Hi Hj. + unfold Mmult. + apply big_sum_eq_bounded. + intros k Hk. + rewrite HAB, HCD by easy. + easy. +Qed. + +Add Parametric Morphism {n m o} : (@Mmult n m o) + with signature (@mat_equiv n m) ==> (@mat_equiv m o) ==> (@mat_equiv n o) + as mmult_mat_equiv_morph. +Proof. intros; apply Mmult_simplify_mat_equiv; easy. Qed. + +Lemma kron_simplify_mat_equiv {n m o p} : forall (A B : Matrix n m) + (C D : Matrix o p), A ≡ B -> C ≡ D -> A ⊗ C ≡ B ⊗ D. +Proof. + intros A B C D HAB HCD i j Hi Hj. + unfold kron. + rewrite HAB, HCD; try easy. + 1,2: apply Nat.mod_upper_bound; lia. + 1,2: apply Nat.div_lt_upper_bound; lia. +Qed. + +Add Parametric Morphism {n m o p} : (@kron n m o p) + with signature (@mat_equiv n m) ==> (@mat_equiv o p) + ==> (@mat_equiv (n*o) (m*p)) as kron_mat_equiv_morph. +Proof. intros; apply kron_simplify_mat_equiv; easy. Qed. + +Lemma Mplus_simplify_mat_equiv : forall {n m} + (A B C D : Matrix n m), + A ≡ B -> C ≡ D -> A .+ C ≡ B .+ D. +Proof. + intros n m A B C D HAB HCD. + intros i j Hi Hj; unfold ".+"; + rewrite HAB, HCD; try easy. +Qed. + +Add Parametric Morphism {n m} : (@Mplus n m) + with signature (@mat_equiv n m) ==> (@mat_equiv n m) ==> (@mat_equiv n m) + as Mplus_mat_equiv_morph. +Proof. intros; apply Mplus_simplify_mat_equiv; easy. Qed. + + Lemma scale_simplify_mat_equiv : forall {n m} + (x y : C) (A B : Matrix n m), + x = y -> A ≡ B -> x .* A ≡ y .* B. +Proof. + intros n m x y A B Hxy HAB i j Hi Hj. + unfold scale. + rewrite Hxy, HAB; easy. +Qed. + +Add Parametric Morphism {n m} : (@scale n m) + with signature (@eq C) ==> (@mat_equiv n m) ==> (@mat_equiv n m) + as scale_mat_equiv_morph. +Proof. intros; apply scale_simplify_mat_equiv; easy. Qed. + +Lemma Mopp_simplify_mat_equiv : forall {n m} (A B : Matrix n m), + A ≡ B -> Mopp A ≡ Mopp B. +Proof. + intros n m A B HAB i j Hi Hj. + unfold Mopp, scale. + rewrite HAB; easy. +Qed. + +Add Parametric Morphism {n m} : (@Mopp n m) + with signature (@mat_equiv n m) ==> (@mat_equiv n m) + as Mopp_mat_equiv_morph. +Proof. intros; apply Mopp_simplify_mat_equiv; easy. Qed. + +Lemma Mminus_simplify_mat_equiv : forall {n m} + (A B C D : Matrix n m), + A ≡ B -> C ≡ D -> Mminus A C ≡ Mminus B D. +Proof. + intros n m A B C D HAB HCD. + intros i j Hi Hj; unfold Mminus, Mopp, Mplus, scale; + rewrite HAB, HCD; try easy. +Qed. + +Add Parametric Morphism {n m} : (@Mminus n m) + with signature (@mat_equiv n m) ==> (@mat_equiv n m) ==> (@mat_equiv n m) + as Mminus_mat_equiv_morph. +Proof. intros; apply Mminus_simplify_mat_equiv; easy. Qed. + +Lemma dot_simplify_mat_equiv : forall {n} (A B : Vector n) + (C D : Vector n), A ≡ B -> C ≡ D -> dot A C = dot B D. +Proof. + intros n A B C D HAB HCD. + apply big_sum_eq_bounded. + intros k Hk. + rewrite HAB, HCD; unfold "<"%nat; easy. +Qed. + +Add Parametric Morphism {n} : (@dot n) + with signature (@mat_equiv n 1) ==> (@mat_equiv n 1) ==> (@eq C) + as dot_mat_equiv_morph. +Proof. intros; apply dot_simplify_mat_equiv; easy. Qed. + +Definition direct_sum' {n m o p : nat} (A : Matrix n m) (B : Matrix o p) : + Matrix (n+o) (m+p) := + (fun i j => if (i WF_Matrix B -> + direct_sum' A B = direct_sum A B. +Proof. + intros n m o p A B HA HB. + apply mat_equiv_eq; [|apply WF_direct_sum|]; auto with wf_db. + intros i j Hi Hj. + unfold direct_sum, direct_sum'. + bdestructΩ'simp; + rewrite HA by lia; easy. +Qed. + +Lemma direct_sum'_simplify_mat_equiv {n m o p} : forall (A B : Matrix n m) + (C D : Matrix o p), A ≡ B -> C ≡ D -> direct_sum' A C ≡ direct_sum' B D. +Proof. + intros A B C D HAB HCD i j Hi Hj. + unfold direct_sum'. + bdestruct (i (@mat_equiv o p) + ==> (@mat_equiv (n+o) (m+p)) as direct_sum'_mat_equiv_morph. +Proof. intros; apply direct_sum'_simplify_mat_equiv; easy. Qed. + +(* Search (Matrix ?n ?m -> ?Matrix ?n ?m). *) +Lemma transpose_simplify_mat_equiv {n m} : forall (A B : Matrix n m), + A ≡ B -> A ⊤ ≡ B ⊤. +Proof. + intros A B HAB i j Hi Hj. + unfold transpose; auto. +Qed. + +Add Parametric Morphism {n m} : (@transpose n m) + with signature (@mat_equiv n m) ==> (@mat_equiv m n) + as transpose_mat_equiv_morph. +Proof. intros; apply transpose_simplify_mat_equiv; easy. Qed. + +Lemma adjoint_simplify_mat_equiv {n m} : forall (A B : Matrix n m), + A ≡ B -> A † ≡ B †. +Proof. + intros A B HAB i j Hi Hj. + unfold adjoint; + rewrite HAB by easy; easy. +Qed. + +Add Parametric Morphism {n m} : (@adjoint n m) + with signature (@mat_equiv n m) ==> (@mat_equiv m n) + as adjoint_mat_equiv_morph. +Proof. intros; apply adjoint_simplify_mat_equiv; easy. Qed. + +Lemma trace_of_mat_equiv : forall n (A B : Square n), + A ≡ B -> trace A = trace B. +Proof. + intros n A B HAB. + (* unfold trace. *) + apply big_sum_eq_bounded; intros i Hi. + rewrite HAB; auto. +Qed. + +Add Parametric Morphism {n} : (@trace n) + with signature (@mat_equiv n n) ==> (eq) + as trace_mat_equiv_morph. +Proof. intros; apply trace_of_mat_equiv; easy. Qed. + + +Lemma Mmult_assoc_mat_equiv : forall {n m o p} + (A : Matrix n m) (B : Matrix m o) (C : Matrix o p), + A × B × C ≡ A × (B × C). +Proof. + intros n m o p A B C. + rewrite Mmult_assoc. + easy. +Qed. + +Lemma mat_equiv_equivalence : forall {n m}, + equivalence (Matrix n m) mat_equiv. +Proof. + intros n m. + constructor. + - intros A. apply (mat_equiv_refl). + - intros A; apply mat_equiv_trans. + - intros A; apply mat_equiv_sym. +Qed. + + + +Lemma big_sum_mat_equiv : forall {o p} (f g : nat -> Matrix o p) + (Eq: forall x : nat, f x ≡ g x) (n : nat), big_sum f n ≡ big_sum g n. +Proof. + intros o p f g Eq n. + induction n. + - easy. + - simpl. + rewrite IHn, Eq; easy. +Qed. + +Add Parametric Morphism {n m} : (@big_sum (Matrix n m) (M_is_monoid n m)) + with signature + (pointwise_relation nat (@mat_equiv n m)) ==> (@eq nat) ==> + (@mat_equiv n m) + as big_sum_mat_equiv_morph. +Proof. intros f g Eq k. apply big_sum_mat_equiv; easy. Qed. \ No newline at end of file diff --git a/examples/MatrixPermBase.v b/examples/MatrixPermBase.v new file mode 100644 index 0000000..79f8bcb --- /dev/null +++ b/examples/MatrixPermBase.v @@ -0,0 +1,166 @@ +Require Export MatrixExampleBase. +From ViCaR Require Import ExamplesAutomation. + +Lemma perm_mat_permutes_ei_r : forall n f k, (k < n)%nat -> + (perm_mat n f) × (e_i k) = e_i (f k). +Proof. + intros n f k Hk. + rewrite <- mat_equiv_eq_iff by auto with wf_db. + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + unfold e_i. + bdestruct (i =? f k). + - unfold perm_mat, Mmult. + bdestruct_one; [|lia]. + simpl. + apply big_sum_unique. + exists k. + repeat split; [lia | bdestructΩ'simp | ]. + intros k' Hk' Hk'k'. + bdestructΩ'simp. + - simpl. + unfold perm_mat, Mmult. + rewrite big_sum_0_bounded; [easy|]. + intros k' Hk'. + bdestructΩ'simp. +Qed. + +Lemma basis_vector_equiv_e_i : forall n k, + basis_vector n k ≡ e_i k. +Proof. + intros n k i j Hi Hj. + unfold basis_vector, e_i. + bdestructΩ'simp. +Qed. + +Lemma basis_vector_eq_e_i : forall n k, (k < n)%nat -> + basis_vector n k = e_i k. +Proof. + intros n k Hk. + rewrite <- mat_equiv_eq_iff by auto with wf_db. + apply basis_vector_equiv_e_i. +Qed. + +Lemma perm_mat_permutes_basis_vectors_r : forall n f k, (k < n)%nat -> + (perm_mat n f) × (basis_vector n k) = e_i (f k). +Proof. + intros n f k Hk. + rewrite basis_vector_eq_e_i by easy. + apply perm_mat_permutes_ei_r; easy. +Qed. + +Lemma mat_equiv_of_equiv_on_ei : forall {n m} (A B : Matrix n m), + (forall k, (k < m)%nat -> A × e_i k ≡ B × e_i k) -> + A ≡ B. +Proof. + intros n m A B Heq. + intros i j Hi Hj. + specialize (Heq j Hj). + rewrite <- 2!(matrix_by_basis _ _ Hj) in Heq. + specialize (Heq i O Hi ltac:(lia)). + unfold get_vec in Heq. + rewrite Nat.eqb_refl in Heq. + easy. +Qed. + +(* FIXME: Temp; only until pull mx stuff out of ZXExample *) +Lemma vector_eq_basis_comb : forall n (y : Vector n), + WF_Matrix y -> + y = big_sum (G:=Vector n) (fun i => y i O .* @e_i n i) n. +Proof. + intros n y Hwfy. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + symmetry. + rewrite Msum_Csum. + apply big_sum_unique. + exists i. + repeat split; try easy. + - unfold ".*", e_i; bdestructΩ'simp. + - intros l Hl Hnk. + unfold ".*", e_i; bdestructΩ'simp. +Qed. + +Lemma vector_equiv_basis_comb : forall n (y : Vector n), + y ≡ big_sum (G:=Vector n) (fun i => y i O .* @e_i n i) n. +Proof. + intros n y. + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + symmetry. + rewrite Msum_Csum. + apply big_sum_unique. + exists i. + repeat split; try easy. + - unfold ".*", e_i; bdestructΩ'simp. + - intros l Hl Hnk. + unfold ".*", e_i; bdestructΩ'simp. +Qed. + +Lemma perm_mat_permutes_matrix_r : forall n m f (A : Matrix n m), + permutation n f -> + (perm_mat n f) × A ≡ (fun i j => A (perm_inv n f i) j). +Proof. + intros n m f A Hperm. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + rewrite Mmult_assoc, <- 2(matrix_by_basis _ _ Hk). + rewrite (vector_equiv_basis_comb _ (get_vec _ _)). + rewrite Mmult_Msum_distr_l. + erewrite big_sum_eq_bounded. + 2: { + intros l Hl. + rewrite Mscale_mult_dist_r, perm_mat_permutes_ei_r by easy. + reflexivity. + } + intros i j Hi Hj; replace j with O by lia; clear j Hj. + rewrite Msum_Csum. + unfold get_vec, scale, e_i. + rewrite Nat.eqb_refl. + apply big_sum_unique. + exists (perm_inv n f i). + repeat split; auto with perm_bdd_db. + - rewrite (perm_inv_is_rinv_of_permutation n f Hperm i Hi), Nat.eqb_refl. + bdestructΩ'simp. + - intros j Hj Hjne. + bdestruct (i =? f j); [|bdestructΩ'simp]. + exfalso; apply Hjne. + apply (permutation_is_injective n f Hperm); auto with perm_bdd_db. + rewrite (perm_inv_is_rinv_of_permutation n f Hperm i Hi); easy. +Qed. + +Lemma perm_inv_of_rotr : forall n k, + forall i, (i < n)%nat -> + perm_inv n (rotr n k) i = rotl n k i. +Proof. + intros n k i Hi. + assert (Hp : permutation n (rotr n k)) by auto with perm_db. + apply (permutation_is_injective n _ Hp); auto with perm_bdd_db. + pose proof (rotr_rotl_inv n k) as H. + unfold compose in H. + rewrite perm_inv_is_rinv_of_permutation; auto with perm_db. + set (g:=(fun x : nat => rotr n k (rotl n k x))). + fold g in H. + enough (g i = idn i) by (unfold g in *; easy). + rewrite H. + easy. +Qed. + +Lemma perm_mat_equiv_of_perm_eq : forall n f g, + (forall k, (k f k = g k) -> + perm_mat n f ≡ perm_mat n g. +Proof. + intros n f g Heq. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + rewrite 2!perm_mat_permutes_ei_r, Heq by easy. + easy. +Qed. + +Lemma perm_mat_equiv_of_perm_eq' : forall n m f g, n = m -> + (forall k, (k f k = g k) -> + perm_mat n f ≡ perm_mat m g. +Proof. + intros; subst n; apply perm_mat_equiv_of_perm_eq; easy. +Qed. \ No newline at end of file diff --git a/examples/ZXExample.v b/examples/ZXExample.v index ccc699e..a882b86 100644 --- a/examples/ZXExample.v +++ b/examples/ZXExample.v @@ -5,944 +5,389 @@ From VyZX Require Import CoreRules. From VyZX Require Import PermutationRules. From ViCaR Require Export CategoryTypeclass. + Lemma proportional_equiv {n m : nat} : equivalence (ZX n m) proportional. Proof. - constructor. - unfold reflexive; apply proportional_refl. - unfold transitive; apply proportional_trans. - unfold symmetric; apply proportional_symm. + constructor. + unfold reflexive; apply proportional_refl. + unfold transitive; apply proportional_trans. + unfold symmetric; apply proportional_symm. Qed. Lemma equality_equiv : equivalence nat eq. Proof. - constructor. - unfold reflexive; easy. - unfold transitive; apply eq_trans. - unfold symmetric; apply eq_sym. + constructor. + unfold reflexive; easy. + unfold transitive; apply eq_trans. + unfold symmetric; apply eq_sym. Qed. #[export] Instance ZXCategory : Category nat := { - morphism := ZX; + morphism := ZX; - equiv := @proportional; - equiv_rel := @proportional_equiv; + equiv := @proportional; + equiv_rel := @proportional_equiv; - compose := @Compose; - compose_compat := @Proportional.compose_compat; - assoc := @ComposeRules.compose_assoc; + compose := @Compose; + compose_compat := @Proportional.compose_compat; + assoc := @ComposeRules.compose_assoc; - c_identity n := n_wire n; - left_unit _ _ := nwire_removal_l; - right_unit _ _ := nwire_removal_r; + c_identity n := n_wire n; + left_unit _ _ := nwire_removal_l; + right_unit _ _ := nwire_removal_r; }. Definition zx_associator {n m o} := - let l := (n + m + o)%nat in - let r := (n + (m + o))%nat in - let assoc := Nat.add_assoc n m o in - cast l r (eq_refl l) assoc (n_wire l). + let l := (n + m + o)%nat in + let r := (n + (m + o))%nat in + let assoc := Nat.add_assoc n m o in + cast l r (eq_refl l) assoc (n_wire l). Definition zx_inv_associator {n m o} := - let l := (n + (m + o))%nat in - let r := (n + m + o)%nat in - let assoc := Nat.add_assoc n m o in - cast l r (eq_refl l) (eq_sym assoc) (n_wire l). + let l := (n + (m + o))%nat in + let r := (n + m + o)%nat in + let assoc := Nat.add_assoc n m o in + cast l r (eq_refl l) (eq_sym assoc) (n_wire l). Lemma zx_associator_inv_compose : forall {n m o}, - zx_associator ⟷ zx_inv_associator ∝ n_wire (n + m + o). + zx_associator ⟷ zx_inv_associator ∝ n_wire (n + m + o). Proof. - intros. - unfold zx_associator. unfold zx_inv_associator. - rewrite cast_compose_r. - cleanup_zx. simpl_casts. - reflexivity. + intros. + unfold zx_associator. unfold zx_inv_associator. + rewrite cast_compose_r. + cleanup_zx. simpl_casts. + reflexivity. Qed. Lemma zx_inv_associator_compose : forall {n m o}, - zx_inv_associator ⟷ zx_associator ∝ n_wire (n + (m + o)). + zx_inv_associator ⟷ zx_associator ∝ n_wire (n + (m + o)). Proof. - intros. - unfold zx_associator. unfold zx_inv_associator. - rewrite cast_compose_l. - cleanup_zx. simpl_casts. - reflexivity. + intros. + unfold zx_associator. unfold zx_inv_associator. + rewrite cast_compose_l. + cleanup_zx. simpl_casts. + reflexivity. Qed. Lemma zx_associator_cohere : forall {n m o p q r} - (zx0 : ZX n m) (zx1 : ZX o p) (zx2 : ZX q r), - zx_associator ⟷ (zx0 ↕ (zx1 ↕ zx2)) - ∝ (zx0 ↕ zx1 ↕ zx2) ⟷ zx_associator. -Proof. - intros. - unfold zx_associator. - repeat rewrite cast_compose_r. - simpl_casts. cleanup_zx. - rewrite cast_compose_l. - cleanup_zx. simpl_casts. - rewrite stack_assoc. - reflexivity. + (zx0 : ZX n m) (zx1 : ZX o p) (zx2 : ZX q r), + zx_associator ⟷ (zx0 ↕ (zx1 ↕ zx2)) + ∝ (zx0 ↕ zx1 ↕ zx2) ⟷ zx_associator. +Proof. + intros. + unfold zx_associator. + repeat rewrite cast_compose_r. + simpl_casts. cleanup_zx. + rewrite cast_compose_l. + cleanup_zx. simpl_casts. + rewrite stack_assoc. + reflexivity. Qed. Definition zx_left_unitor {n} := - cast (0 + n) n (Nat.add_0_l n) (eq_refl n) (n_wire n). + cast (0 + n) n (Nat.add_0_l n) (eq_refl n) (n_wire n). Definition zx_inv_left_unitor {n} := - cast n (0 + n) (eq_refl n) (Nat.add_0_l n) (n_wire n). + cast n (0 + n) (eq_refl n) (Nat.add_0_l n) (n_wire n). Lemma zx_left_inv_compose : forall {n}, - zx_left_unitor ⟷ zx_inv_left_unitor ∝ n_wire (0 + n). + zx_left_unitor ⟷ zx_inv_left_unitor ∝ n_wire (0 + n). Proof. - intros. - unfold zx_left_unitor. unfold zx_inv_left_unitor. - simpl_casts. cleanup_zx. reflexivity. + intros. + unfold zx_left_unitor. unfold zx_inv_left_unitor. + simpl_casts. cleanup_zx. reflexivity. Qed. Lemma zx_inv_left_compose : forall {n}, - zx_inv_left_unitor ⟷ zx_left_unitor ∝ n_wire n. + zx_inv_left_unitor ⟷ zx_left_unitor ∝ n_wire n. Proof. - intros. - unfold zx_left_unitor. unfold zx_inv_left_unitor. - simpl_casts. cleanup_zx. reflexivity. + intros. + unfold zx_left_unitor. unfold zx_inv_left_unitor. + simpl_casts. cleanup_zx. reflexivity. Qed. Lemma zx_left_unitor_cohere : forall {n m} (zx : ZX n m), - zx_left_unitor ⟷ zx ∝ (n_wire 0) ↕ zx ⟷ zx_left_unitor. -Proof. - intros. - unfold zx_left_unitor. - simpl_casts. - rewrite nwire_removal_l. - rewrite stack_empty_l. - rewrite nwire_removal_r. - reflexivity. + zx_left_unitor ⟷ zx ∝ (n_wire 0) ↕ zx ⟷ zx_left_unitor. +Proof. + intros. + unfold zx_left_unitor. + simpl_casts. + rewrite nwire_removal_l. + rewrite stack_empty_l. + rewrite nwire_removal_r. + reflexivity. Qed. Definition zx_right_unitor {n} := - cast (n + 0) n (Nat.add_0_r n) (eq_refl n) (n_wire n). + cast (n + 0) n (Nat.add_0_r n) (eq_refl n) (n_wire n). Definition zx_inv_right_unitor {n} := - cast n (n + 0) (eq_refl n) (Nat.add_0_r n) (n_wire n). + cast n (n + 0) (eq_refl n) (Nat.add_0_r n) (n_wire n). Lemma zx_right_inv_compose : forall {n}, - zx_right_unitor ⟷ zx_inv_right_unitor ∝ n_wire (n + 0). + zx_right_unitor ⟷ zx_inv_right_unitor ∝ n_wire (n + 0). Proof. - intros. - unfold zx_right_unitor. unfold zx_inv_right_unitor. - rewrite cast_compose_l. - cleanup_zx. simpl_casts. reflexivity. + intros. + unfold zx_right_unitor. unfold zx_inv_right_unitor. + rewrite cast_compose_l. + cleanup_zx. simpl_casts. reflexivity. Qed. Lemma zx_inv_right_compose : forall {n}, - zx_inv_right_unitor ⟷ zx_right_unitor ∝ n_wire n. + zx_inv_right_unitor ⟷ zx_right_unitor ∝ n_wire n. Proof. - intros. - unfold zx_right_unitor. unfold zx_inv_right_unitor. - rewrite cast_compose_r. - cleanup_zx. simpl_casts. reflexivity. + intros. + unfold zx_right_unitor. unfold zx_inv_right_unitor. + rewrite cast_compose_r. + cleanup_zx. simpl_casts. reflexivity. Qed. Lemma zx_right_unitor_cohere : forall {n m} (zx : ZX n m), - zx_right_unitor ⟷ zx ∝ zx ↕ (n_wire 0) ⟷ zx_right_unitor. -Proof. - intros. - unfold zx_right_unitor; cleanup_zx. - rewrite <- cast_compose_mid_contract. - cleanup_zx. - rewrite cast_compose_l; simpl_casts. - rewrite nwire_removal_l. - reflexivity. - Unshelve. all: easy. + zx_right_unitor ⟷ zx ∝ zx ↕ (n_wire 0) ⟷ zx_right_unitor. +Proof. + intros. + unfold zx_right_unitor; cleanup_zx. + rewrite <- cast_compose_mid_contract. + cleanup_zx. + rewrite cast_compose_l; simpl_casts. + rewrite nwire_removal_l. + reflexivity. + Unshelve. all: easy. Qed. Lemma zx_triangle_lemma : forall {n m}, - zx_associator ⟷ (n_wire n ↕ zx_left_unitor) ∝ - zx_right_unitor ↕ n_wire m. -Proof. - intros. - unfold zx_associator. - unfold zx_right_unitor. - unfold zx_left_unitor. - simpl_casts. - repeat rewrite n_wire_stack. - cleanup_zx. - simpl_casts. - reflexivity. + zx_associator ⟷ (n_wire n ↕ zx_left_unitor) ∝ + zx_right_unitor ↕ n_wire m. +Proof. + intros. + unfold zx_associator. + unfold zx_right_unitor. + unfold zx_left_unitor. + simpl_casts. + repeat rewrite n_wire_stack. + cleanup_zx. + simpl_casts. + reflexivity. Qed. Lemma zx_pentagon_lemma : forall {n m o p}, - (zx_associator ↕ n_wire p) ⟷ zx_associator ⟷ (n_wire n ↕ zx_associator) - ∝ (@zx_associator (n + m) o p) ⟷ zx_associator. + (zx_associator ↕ n_wire p) ⟷ zx_associator ⟷ (n_wire n ↕ zx_associator) + ∝ (@zx_associator (n + m) o p) ⟷ zx_associator. Proof. - intros. - unfold zx_associator. - simpl_casts. - repeat rewrite n_wire_stack. - repeat rewrite cast_compose_l. - repeat rewrite cast_compose_r. - cleanup_zx; simpl_casts; reflexivity. + intros. + unfold zx_associator. + simpl_casts. + repeat rewrite n_wire_stack. + repeat rewrite cast_compose_l. + repeat rewrite cast_compose_r. + cleanup_zx; simpl_casts; reflexivity. Qed. Definition ZXTensorBiFunctor : Bifunctor ZXCategory ZXCategory ZXCategory := {| - obj2_map := Nat.add; - morphism2_map := @Stack; - id2_map := n_wire_stack; - compose2_map := @stack_compose_distr; - morphism2_compat := @stack_simplify; + obj2_map := Nat.add; + morphism2_map := @Stack; + id2_map := n_wire_stack; + compose2_map := @stack_compose_distr; + morphism2_compat := @stack_simplify; |}. #[export] Instance ZXMonoidalCategory : MonoidalCategory nat := { - tensor := ZXTensorBiFunctor; - - associator := fun n m o => {| - forward := @zx_associator n m o; - reverse := @zx_inv_associator n m o; - id_A := @zx_associator_inv_compose n m o; - id_B := @zx_inv_associator_compose n m o; - |}; - - left_unitor := fun n => {| - forward := @zx_left_unitor n; - reverse := @zx_inv_left_unitor n; - id_A := @zx_left_inv_compose n; - id_B := @zx_inv_left_compose n; - |}; - - right_unitor := fun n => {| - forward := @zx_right_unitor n; - reverse := @zx_inv_right_unitor n; - id_A := @zx_right_inv_compose n; - id_B := @zx_inv_right_compose n; - |}; - - associator_cohere := @zx_associator_cohere; - left_unitor_cohere := @zx_left_unitor_cohere; - right_unitor_cohere := @zx_right_unitor_cohere; - - triangle := @zx_triangle_lemma; - pentagon := @zx_pentagon_lemma; + tensor := ZXTensorBiFunctor; + + associator := fun n m o => {| + forward := @zx_associator n m o; + reverse := @zx_inv_associator n m o; + id_A := @zx_associator_inv_compose n m o; + id_B := @zx_inv_associator_compose n m o; + |}; + + left_unitor := fun n => {| + forward := @zx_left_unitor n; + reverse := @zx_inv_left_unitor n; + id_A := @zx_left_inv_compose n; + id_B := @zx_inv_left_compose n; + |}; + + right_unitor := fun n => {| + forward := @zx_right_unitor n; + reverse := @zx_inv_right_unitor n; + id_A := @zx_right_inv_compose n; + id_B := @zx_inv_right_compose n; + |}; + + associator_cohere := @zx_associator_cohere; + left_unitor_cohere := @zx_left_unitor_cohere; + right_unitor_cohere := @zx_right_unitor_cohere; + + triangle := @zx_triangle_lemma; + pentagon := @zx_pentagon_lemma; (* - tensor := Nat.add; - I := 0; + tensor := Nat.add; + I := 0; - tensor_morph _ _ _ _ := Stack; - tensor_morph_compat := stack_compat; + tensor_morph _ _ _ _ := Stack; + tensor_morph_compat := stack_compat; - associator := @zx_associator; - inv_associator := @zx_inv_associator; - associator_inv_compose := @zx_associator_inv_compose; - inv_associator_compose := @zx_inv_associator_compose; + associator := @zx_associator; + inv_associator := @zx_inv_associator; + associator_inv_compose := @zx_associator_inv_compose; + inv_associator_compose := @zx_inv_associator_compose; - left_unitor := @zx_left_unitor; - inv_left_unitor := @zx_inv_left_unitor; - left_inv_compose := @zx_left_inv_compose; - inv_left_compose := @zx_inv_left_compose; + left_unitor := @zx_left_unitor; + inv_left_unitor := @zx_inv_left_unitor; + left_inv_compose := @zx_left_inv_compose; + inv_left_compose := @zx_inv_left_compose; - right_unitor := @zx_right_unitor; - inv_right_unitor := @zx_inv_right_unitor; - right_inv_compose := @zx_right_inv_compose; - inv_right_compose := @zx_inv_right_compose; + right_unitor := @zx_right_unitor; + inv_right_unitor := @zx_inv_right_unitor; + right_inv_compose := @zx_right_inv_compose; + inv_right_compose := @zx_inv_right_compose; - bifunctor_id := n_wire_stack; - bifunctor_comp := @stack_compose_distr; + bifunctor_id := n_wire_stack; + bifunctor_comp := @stack_compose_distr; - associator_cohere := @zx_associator_cohere; - left_unitor_cohere := @zx_left_unitor_cohere; - right_unitor_cohere := @zx_right_unitor_cohere; + associator_cohere := @zx_associator_cohere; + left_unitor_cohere := @zx_left_unitor_cohere; + right_unitor_cohere := @zx_right_unitor_cohere; - triangle := @zx_triangle_lemma; - pentagon := @zx_pentagon_lemma; + triangle := @zx_triangle_lemma; + pentagon := @zx_pentagon_lemma; *) }. Definition zx_braiding {n m} := - let l := (n + m)%nat in - let r := (m + n)%nat in - cast l r (eq_refl l) (Nat.add_comm m n) (n_top_to_bottom n m). + let l := (n + m)%nat in + let r := (m + n)%nat in + cast l r (eq_refl l) (Nat.add_comm m n) (n_top_to_bottom n m). Definition zx_inv_braiding {n m} := - let l := (m + n)%nat in - let r := (n + m)%nat in - cast l r (eq_refl l) (Nat.add_comm n m) (n_bottom_to_top n m). + let l := (m + n)%nat in + let r := (n + m)%nat in + cast l r (eq_refl l) (Nat.add_comm n m) (n_bottom_to_top n m). (* Because they're not definitionally square, it's kinda useless to show zx_braiding and zx_inv_braiding are (up to cast) ZXperm's. Instead, we can hint it to unfold them automatically and let the casting wizardy take it from there: *) #[export] Hint Unfold zx_braiding zx_inv_braiding - zx_associator zx_inv_associator - zx_left_unitor zx_right_unitor - zx_inv_left_unitor zx_inv_right_unitor : zxperm_db. + zx_associator zx_inv_associator + zx_left_unitor zx_right_unitor + zx_inv_left_unitor zx_inv_right_unitor : zxperm_db. Definition n_compose_bot n m := n_compose n (bottom_to_top m). Definition n_compose_top n m := n_compose n (top_to_bottom m). Lemma zx_braiding_inv_compose : forall {n m}, - zx_braiding ⟷ zx_inv_braiding ∝ n_wire (n + m). -Proof. - intros. - prop_perm_eq. - rewrite Nat.add_comm. - cleanup_perm_of_zx; easy. - (* intros. - unfold zx_braiding. unfold zx_inv_braiding. - unfold n_top_to_bottom. - unfold n_bottom_to_top. - rewrite cast_compose_mid. - simpl_casts. - fold (n_compose_bot n (m + n)). - rewrite cast_fn_eq_dim. - rewrite n_compose_top_compose_bottom. - reflexivity. - Unshelve. - all: rewrite (Nat.add_comm n m); easy. *) + zx_braiding ⟷ zx_inv_braiding ∝ n_wire (n + m). +Proof. + intros. + prop_perm_eq. + rewrite Nat.add_comm. + cleanup_perm_of_zx; easy. + (* intros. + unfold zx_braiding. unfold zx_inv_braiding. + unfold n_top_to_bottom. + unfold n_bottom_to_top. + rewrite cast_compose_mid. + simpl_casts. + fold (n_compose_bot n (m + n)). + rewrite cast_fn_eq_dim. + rewrite n_compose_top_compose_bottom. + reflexivity. + Unshelve. + all: rewrite (Nat.add_comm n m); easy. *) Qed. Lemma zx_inv_braiding_compose : forall {n m}, - zx_inv_braiding ⟷ zx_braiding ∝ n_wire (m + n). -Proof. - intros. - prop_perm_eq. - rewrite Nat.add_comm. - cleanup_perm_of_zx; easy. - (* intros. - unfold zx_braiding. unfold zx_inv_braiding. - unfold n_top_to_bottom. - unfold n_bottom_to_top. - rewrite cast_compose_mid. - simpl_casts. - fold (n_compose_top n (n + m)). - rewrite cast_fn_eq_dim. - rewrite n_compose_bottom_compose_top. - reflexivity. - Unshelve. - all: rewrite (Nat.add_comm n m); easy. *) + zx_inv_braiding ⟷ zx_braiding ∝ n_wire (m + n). +Proof. + intros. + prop_perm_eq. + rewrite Nat.add_comm. + cleanup_perm_of_zx; easy. + (* intros. + unfold zx_braiding. unfold zx_inv_braiding. + unfold n_top_to_bottom. + unfold n_bottom_to_top. + rewrite cast_compose_mid. + simpl_casts. + fold (n_compose_top n (n + m)). + rewrite cast_fn_eq_dim. + rewrite n_compose_bottom_compose_top. + reflexivity. + Unshelve. + all: rewrite (Nat.add_comm n m); easy. *) Qed. Lemma n_top_to_bottom_split : forall {n m o o'} prf1 prf2 prf3 prf4, - n_top_to_bottom n m ↕ n_wire o - ⟷ cast (n + m + o) o' prf1 prf2 (n_wire m ↕ n_top_to_bottom n o) - ∝ cast (n + m + o) o' prf3 prf4 (n_top_to_bottom n (m + o)). + n_top_to_bottom n m ↕ n_wire o + ⟷ cast (n + m + o) o' prf1 prf2 (n_wire m ↕ n_top_to_bottom n o) + ∝ cast (n + m + o) o' prf3 prf4 (n_top_to_bottom n (m + o)). Proof. - intros. - prop_perm_eq. - solve_modular_permutation_equalities. + intros. + prop_perm_eq. + solve_modular_permutation_equalities. Qed. Lemma hexagon_lemma_1 : forall {n m o}, - (zx_braiding ↕ n_wire o) ⟷ zx_associator ⟷ (n_wire m ↕ zx_braiding) - ∝ zx_associator ⟷ (@zx_braiding n (m + o)) ⟷ zx_associator. -Proof. - intros. - prop_perm_eq. - solve_modular_permutation_equalities. - (* intros. - unfold zx_braiding. unfold zx_associator. - simpl_casts. - rewrite cast_compose_l. simpl_casts. - rewrite compose_assoc. - rewrite cast_compose_l. simpl_casts. - cleanup_zx. simpl_casts. - rewrite cast_compose_l. - simpl_casts. cleanup_zx. - rewrite cast_compose_l. simpl_casts. - rewrite (cast_compose_r _ _ _ (n_wire (m + o + n))). - cleanup_zx. simpl_casts. - rewrite n_top_to_bottom_split. - reflexivity. *) + (zx_braiding ↕ n_wire o) ⟷ zx_associator ⟷ (n_wire m ↕ zx_braiding) + ∝ zx_associator ⟷ (@zx_braiding n (m + o)) ⟷ zx_associator. +Proof. + intros. + prop_perm_eq. + solve_modular_permutation_equalities. + (* intros. + unfold zx_braiding. unfold zx_associator. + simpl_casts. + rewrite cast_compose_l. simpl_casts. + rewrite compose_assoc. + rewrite cast_compose_l. simpl_casts. + cleanup_zx. simpl_casts. + rewrite cast_compose_l. + simpl_casts. cleanup_zx. + rewrite cast_compose_l. simpl_casts. + rewrite (cast_compose_r _ _ _ (n_wire (m + o + n))). + cleanup_zx. simpl_casts. + rewrite n_top_to_bottom_split. + reflexivity. *) Qed. Lemma n_bottom_to_top_split : forall {n m o o'} prf1 prf2 prf3 prf4, - n_bottom_to_top m n ↕ n_wire o - ⟷ cast (n + m + o) o' prf1 prf2 (n_wire m ↕ n_bottom_to_top o n) - ∝ cast (n + m + o) o' prf3 prf4 (n_bottom_to_top (m + o) n). + n_bottom_to_top m n ↕ n_wire o + ⟷ cast (n + m + o) o' prf1 prf2 (n_wire m ↕ n_bottom_to_top o n) + ∝ cast (n + m + o) o' prf3 prf4 (n_bottom_to_top (m + o) n). Proof. - prop_perm_eq. - solve_modular_permutation_equalities. + prop_perm_eq. + solve_modular_permutation_equalities. Qed. Lemma hexagon_lemma_2 : forall {n m o}, - (zx_inv_braiding ↕ n_wire o) ⟷ zx_associator ⟷ (n_wire m ↕ zx_inv_braiding) - ∝ zx_associator ⟷ (@zx_inv_braiding (m + o) n) ⟷ zx_associator. -Proof. - prop_perm_eq. - solve_modular_permutation_equalities. - (* intros. - unfold zx_inv_braiding. unfold zx_associator. - simpl_casts. - rewrite cast_compose_l. simpl_casts. - rewrite compose_assoc. - rewrite cast_compose_l. simpl_casts. - cleanup_zx. simpl_casts. - rewrite cast_compose_l. - simpl_casts. cleanup_zx. - rewrite cast_compose_l. simpl_casts. - rewrite (cast_compose_r _ _ _ (n_wire (m + o + n))). - cleanup_zx. simpl_casts. - rewrite n_bottom_to_top_split. - reflexivity. *) -Qed. - -Ltac print_state := - try (match goal with | H : ?p |- _ => idtac H ":" p; fail end); - idtac "---------------------------------------------------------"; - match goal with |- ?P => idtac P -end. - - -Ltac is_C0 x := - assert (x = C0) by lca. - -Ltac is_C1 x := - assert (x = C1) by lca. - -Ltac print_C x := - tryif is_C0 x then idtac "0" else - tryif is_C1 x then idtac "1" else idtac "X". - -Ltac print_LHS_matU := - intros; - (let i := fresh "i" in - let j := fresh "j" in - let Hi := fresh "Hi" in - let Hj := fresh "Hj" in - intros i j Hi Hj; try solve_end; - repeat - (destruct i as [| i]; [ | apply <- Nat.succ_lt_mono in Hi ]; - try solve_end); clear Hi; - repeat - (destruct j as [| j]; [ | apply <- Nat.succ_lt_mono in Hj ]; - try solve_end); clear Hj); - match goal with |- ?x = ?y ?i ?j => autounfold with U_db; simpl; - match goal with - | |- ?x = _ => idtac i; idtac j; print_C x; idtac "" - end -end. - -Definition kron_comm p q : Matrix (p*q) (p*q):= - @make_WF (p*q) (p*q) (fun s t => - (* have blocks H_ij, p by q of them, and each is q by p *) - let i := (s / q)%nat in let j := (t / p)%nat in - let k := (s mod q)%nat in let l := (t mod p) in - (* let k := (s - q * i)%nat in let l := (t - p * t)%nat in *) - if (i =? l) && (j =? k) then C1 else C0 - (* s/q =? t mod p /\ t/p =? s mod q *) -). - -Lemma WF_kron_comm p q : WF_Matrix (kron_comm p q). -Proof. unfold kron_comm; auto with wf_db. Qed. -#[export] Hint Resolve WF_kron_comm : wf_db. - -(* Lemma test_kron : kron_comm 2 3 = Matrix.Zero. -Proof. - apply mat_equiv_eq; unfold kron_comm; auto with wf_db. - print_LHS_matU. -*) - -Lemma kron_comm_transpose : forall p q, - (kron_comm p q) ⊤ = kron_comm q p. -Proof. - intros p q. - apply mat_equiv_eq; auto with wf_db. - 1: rewrite Nat.mul_comm; apply WF_kron_comm. - intros i j Hi Hj. - unfold kron_comm, transpose, make_WF. - rewrite andb_comm, Nat.mul_comm. - rewrite (andb_comm (_ =? _)). - easy. + (zx_inv_braiding ↕ n_wire o) ⟷ zx_associator ⟷ (n_wire m ↕ zx_inv_braiding) + ∝ zx_associator ⟷ (@zx_inv_braiding (m + o) n) ⟷ zx_associator. +Proof. + prop_perm_eq. + solve_modular_permutation_equalities. + (* intros. + unfold zx_inv_braiding. unfold zx_associator. + simpl_casts. + rewrite cast_compose_l. simpl_casts. + rewrite compose_assoc. + rewrite cast_compose_l. simpl_casts. + cleanup_zx. simpl_casts. + rewrite cast_compose_l. + simpl_casts. cleanup_zx. + rewrite cast_compose_l. simpl_casts. + rewrite (cast_compose_r _ _ _ (n_wire (m + o + n))). + cleanup_zx. simpl_casts. + rewrite n_bottom_to_top_split. + reflexivity. *) Qed. -Lemma kron_comm_1_r : forall p, - (kron_comm p 1) = Matrix.I p. -Proof. - intros p. - apply mat_equiv_eq; [|rewrite 1?Nat.mul_1_r|]; auto with wf_db. - intros s t Hs Ht. - unfold kron_comm. - unfold make_WF. - unfold Matrix.I. - rewrite Nat.mul_1_r, Nat.div_1_r, Nat.mod_1_r, Nat.div_small, Nat.mod_small by lia. - bdestructΩ'. -Qed. +Require Export KronComm_orig. -Lemma kron_comm_1_l : forall p, - (kron_comm 1 p) = Matrix.I p. -Proof. - intros p. - apply mat_equiv_eq; [|rewrite 1?Nat.mul_1_l|]; auto with wf_db. - intros s t Hs Ht. - unfold kron_comm. - unfold make_WF. - unfold Matrix.I. - rewrite Nat.mul_1_l, Nat.div_1_r, Nat.mod_1_r, Nat.div_small, Nat.mod_small by lia. - bdestructΩ'. -Qed. -Definition mx_to_vec {n m} (A : Matrix n m) : Vector (n*m) := - make_WF (fun i j => A (i mod n)%nat (i / n)%nat - (* Note: goes columnwise. Rowwise would be: - make_WF (fun i j => A (i / m)%nat (i mod n)%nat - *) -). -Lemma WF_mx_to_vec {n m} (A : Matrix n m) : WF_Matrix (mx_to_vec A). -Proof. unfold mx_to_vec; auto with wf_db. Qed. -#[export] Hint Resolve WF_mx_to_vec : wf_db. -(* Compute vec_to_list (mx_to_vec (Matrix.I 2)). *) -From Coq Require Import ZArith. -Ltac Zify.zify_post_hook ::= PreOmega.Z.div_mod_to_equations. - -Lemma kron_comm_vec_helper : forall i p q, (i < p * q)%nat -> - (p * (i mod q) + i / q < p * q)%nat. -Proof. - intros i p q. - intros Hi. - assert (i / q < p)%nat by (apply Nat.div_lt_upper_bound; lia). - destruct p; [lia|]; - destruct q; [lia|]. - enough (S p * (i mod S q) <= S p * q)%nat by lia. - apply Nat.mul_le_mono; [lia | ]. - pose proof (Nat.mod_upper_bound i (S q) ltac:(easy)). - lia. -Qed. - -Lemma mx_to_vec_additive {n m} (A B : Matrix n m) : - mx_to_vec (A .+ B) = mx_to_vec A .+ mx_to_vec B. -Proof. - apply mat_equiv_eq; auto with wf_db. - intros i j Hi Hj. - replace j with O by lia; clear dependent j. - unfold mx_to_vec, make_WF, Mplus. - bdestructΩ'. -Qed. - -Lemma if_mult_dist_r (b : bool) (z : C) : - (if b then C1 else C0) * z = - if b then z else C0. -Proof. - destruct b; lca. -Qed. - -Lemma if_mult_and (b c : bool) : - (if b then C1 else C0) * (if c then C1 else C0) = - if (b && c) then C1 else C0. -Proof. - destruct b; destruct c; lca. -Qed. - -Lemma kron_comm_vec : forall p q (A : Matrix p q), - kron_comm p q × mx_to_vec A = mx_to_vec (A ⊤). -Proof. - intros p q A. - apply mat_equiv_eq; [|rewrite Nat.mul_comm|]; auto with wf_db. - intros i j Hi Hj. - replace j with O by lia; clear dependent j. - unfold transpose, mx_to_vec, kron_comm, make_WF, Mmult. - rewrite (Nat.mul_comm q p). - replace_bool_lia (i . - destruct p; [lia|]. - destruct q; [lia|]. - split. - + rewrite Nat.add_comm, Nat.mul_comm. - rewrite Nat.mod_add by easy. - rewrite Nat.mod_small; [lia|]. - apply Nat.div_lt_upper_bound; lia. - + rewrite Nat.mul_comm, Nat.div_add_l by easy. - rewrite Nat.div_small; [lia|]. - apply Nat.div_lt_upper_bound; lia. - - intros [Hmodp Hdivp]. - rewrite (Nat.div_mod_eq k p). - lia. - } - apply big_sum_unique. - exists (p * (i mod q) + i / q)%nat; repeat split; - [apply kron_comm_vec_helper; easy | rewrite Nat.eqb_refl | intros; bdestructΩ']. - destruct p; [lia|]; - destruct q; [lia|]. - f_equal. - - rewrite Nat.add_comm, Nat.mul_comm, Nat.mod_add, Nat.mod_small; try easy. - apply Nat.div_lt_upper_bound; lia. - - rewrite Nat.mul_comm, Nat.div_add_l by easy. - rewrite Nat.div_small; [lia|]. - apply Nat.div_lt_upper_bound; lia. -Qed. - -Lemma kron_comm_sum : forall p q, - kron_comm p q = - big_sum (G:=Square (p*q)) (fun i => big_sum (G:=Square (p*q)) (fun j => - (@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤)) - q) p. -Proof. - intros p q. - apply mat_equiv_eq; auto with wf_db. - 1: apply WF_Msum; intros; apply WF_Msum; intros; - rewrite Nat.mul_comm; apply WF_mult; - auto with wf_db; rewrite Nat.mul_comm; - auto with wf_db. - intros i j Hi Hj. - rewrite Msum_Csum. - erewrite big_sum_eq_bounded. - 2: { - intros k Hk. - rewrite Msum_Csum. - erewrite big_sum_eq_bounded. - 2: { - intros l Hl. - unfold Mmult, kron, transpose, e_i. - erewrite big_sum_eq_bounded. - 2: { - intros m Hm. - (* replace m with O by lia. *) - rewrite Nat.div_1_r, Nat.mod_1_r. - replace_bool_lia (m =? 0) true; rewrite 4!andb_true_r. - rewrite 3!if_mult_and. - match goal with - |- context[if ?b then _ else _] => - replace b with ((i =? k * q + l) && (j =? l * p + k)) - end. - 1: reflexivity. (* set our new function *) - clear dependent m. - rewrite eq_iff_eq_true, 8!andb_true_iff, - 6!Nat.eqb_eq, 4!Nat.ltb_lt. - split. - - intros [Hieq Hjeq]. - subst i j. - rewrite 2!Nat.div_add_l, Nat.div_small, Nat.add_0_r by lia. - rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small, - Nat.div_small, Nat.add_0_r by lia. - rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by lia. - easy. - - intros [[[] []] [[] []]]. - split. - + rewrite (Nat.div_mod_eq i q) by lia; lia. - + rewrite (Nat.div_mod_eq j p) by lia; lia. - } - simpl; rewrite Cplus_0_l. - reflexivity. - } - apply big_sum_unique. - exists (i mod q). - split; [|split]. - - apply Nat.mod_upper_bound; lia. - - reflexivity. - - intros l Hl Hnmod. - bdestructΩ'. - exfalso; apply Hnmod. - rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by lia; lia. - } - symmetry. - apply big_sum_unique. - exists (j mod p). - repeat split. - - apply Nat.mod_upper_bound; lia. - - unfold kron_comm, make_WF. - replace_bool_lia (i - enough (b = c) by bdestructΩ' - end. - rewrite eq_iff_eq_true, 2!andb_true_iff, 4!Nat.eqb_eq. - split. - + intros [Hieq Hjeq]. - split; [rewrite Hieq | rewrite Hjeq]; - rewrite Hieq, Nat.div_add_l by lia; - (rewrite Nat.div_small; [lia|]); - apply Nat.mod_upper_bound; lia. - + intros [Hidiv Hjdiv]. - rewrite (Nat.div_mod_eq i q) at 1 by lia. - rewrite (Nat.div_mod_eq j p) at 2 by lia. - lia. - - intros k Hk Hkmod. - bdestructΩ'. - exfalso; apply Hkmod. - rewrite Nat.add_comm, Nat.mod_add, Nat.mod_small by lia; lia. -Qed. - -Lemma kron_comm_sum' : forall p q, - kron_comm p q = - big_sum (G:=Square (p*q)) (fun ij => - let i := (ij / q)%nat in let j := (ij mod q) in - ((@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤))) (p*q). -Proof. - intros p q. - rewrite kron_comm_sum, big_sum_double_sum, Nat.mul_comm. - reflexivity. -Qed. - -Lemma e_i_dot_is_component : forall p k (x : Vector p), - (k < p)%nat -> WF_Matrix x -> - (@e_i p k) ⊤ × x = x k O .* Matrix.I 1. -Proof. - intros p k x Hk HWF. - apply mat_equiv_eq; auto with wf_db. - intros i j Hi Hj; - replace i with O by lia; - replace j with O by lia; - clear i Hi; - clear j Hj. - unfold Mmult, transpose, scale, e_i, Matrix.I. - bdestructΩ'. - rewrite Cmult_1_r. - apply big_sum_unique. - exists k. - split; [easy|]. - bdestructΩ'. - rewrite Cmult_1_l. - split; [easy|]. - intros l Hl Hkl. - bdestructΩ'; lca. -Qed. - -Lemma kron_e_i_e_i : forall p q k l, - (k < p)%nat -> (l < q)%nat -> - @e_i q l ⊗ @e_i p k = @e_i (p*q) (l*p + k). -Proof. - intros p q k l Hk Hl. - apply functional_extensionality; intro i. - apply functional_extensionality; intro j. - unfold kron, e_i. - rewrite Nat.mod_1_r, Nat.div_1_r. - rewrite if_mult_and. - lazymatch goal with - |- (if ?b then _ else _) = (if ?c then _ else _) => - enough (H : b = c) by (rewrite H; easy) - end. - rewrite Nat.eqb_refl, andb_true_r. - destruct (j =? 0); [|rewrite 2!andb_false_r; easy]. - rewrite 2!andb_true_r. - rewrite eq_iff_eq_true, 4!andb_true_iff, 3!Nat.eqb_eq, 3!Nat.ltb_lt. - split. - - intros [[] []]. - rewrite (Nat.div_mod_eq i p). - split; nia. - - intros []. - subst i. - rewrite Nat.div_add_l, Nat.div_small, Nat.add_0_r, - Nat.add_comm, Nat.mod_add, Nat.mod_small by lia. - easy. -Qed. - -Lemma kron_eq_sum : forall p q (x : Vector q) (y : Vector p), - WF_Matrix x -> WF_Matrix y -> - y ⊗ x = big_sum (fun ij => - let i := (ij / q)%nat in let j := ij mod q in - (x j O * y i O) .* (@e_i p i ⊗ @e_i q j)) (p * q). -Proof. - intros p q x y Hwfx Hwfy. - - erewrite big_sum_eq_bounded. - 2: { - intros ij Hij. - simpl. - rewrite (@kron_e_i_e_i q p) by - (try apply Nat.mod_upper_bound; try apply Nat.div_lt_upper_bound; lia). - rewrite (Nat.mul_comm (ij / q) q). - rewrite <- (Nat.div_mod_eq ij q). - reflexivity. - } - apply mat_equiv_eq; [|rewrite Nat.mul_comm|]; auto with wf_db. - intros i j Hi Hj. - replace j with O by lia; clear j Hj. - simpl. - rewrite Msum_Csum. - symmetry. - apply big_sum_unique. - exists i. - split; [lia|]. - unfold e_i; split. - - unfold scale, kron; bdestructΩ'. - rewrite Cmult_1_r, Cmult_comm. - easy. - - intros j Hj Hij. - unfold scale, kron; bdestructΩ'. - rewrite Cmult_0_r. - easy. -Qed. - -Lemma kron_comm_commutes_vectors : forall p q (x : Vector q) (y : Vector p), - WF_Matrix x -> WF_Matrix y -> - kron_comm p q × (x ⊗ y) = (y ⊗ x). -Proof. - intros p q x y Hwfx Hwfy. - rewrite kron_comm_sum', Mmult_Msum_distr_r. - erewrite big_sum_eq_bounded. - 2: { - intros k Hk. - simpl. - match goal with - |- (?A × ?B) × ?C = _ => - assert (Hassoc: (A × B) × C = A × (B × C)) by apply Mmult_assoc - end. - simpl in Hassoc. - rewrite (Nat.mul_comm q p) in *. - rewrite Hassoc. clear Hassoc. - pose proof (kron_transpose _ _ _ _ (@e_i q (k mod q)) (@e_i p (k / q))) as Hrw; - rewrite (Nat.mul_comm q p) in Hrw; - simpl in Hrw; rewrite Hrw; clear Hrw. - pose proof (kron_mixed_product ((e_i (k mod q)) ⊤) ((e_i (k / q)) ⊤) x y) as Hrw; - rewrite (Nat.mul_comm q p) in Hrw; - simpl in Hrw; rewrite Hrw; clear Hrw. - rewrite 2!e_i_dot_is_component; [| - apply Nat.div_lt_upper_bound; lia | - easy | - apply Nat.mod_upper_bound; lia | - easy]. - rewrite Mscale_kron_dist_l, Mscale_kron_dist_r, Mscale_assoc. - rewrite kron_1_l, Mscale_mult_dist_r, Mmult_1_r by auto with wf_db. - reflexivity. - } - rewrite <- kron_eq_sum; easy. -Qed. - -Lemma kron_basis_vector_basis_vector : forall p q k l, - (k < p)%nat -> (l < q)%nat -> - basis_vector q l ⊗ basis_vector p k = basis_vector (p*q) (l*p + k). -Proof. - intros p q k l Hk Hl. - apply functional_extensionality; intros i. - apply functional_extensionality; intros j. - unfold kron, basis_vector. - rewrite Nat.mod_1_r, Nat.div_1_r, Nat.eqb_refl, andb_true_r, if_mult_and. - bdestructΩ'; - try pose proof (Nat.div_mod_eq i p); - try nia. - rewrite Nat.div_add_l, Nat.div_small in * by lia. - lia. -Qed. - -Lemma kron_extensionality : forall n m s t (A B : Matrix (n*m) (s*t)), - WF_Matrix A -> WF_Matrix B -> - (forall (x : Vector s) (y :Vector t), - WF_Matrix x -> WF_Matrix y -> - A × (x ⊗ y) = B × (x ⊗ y)) -> - A = B. -Proof. - intros b n s t A B HwfA HwfB Hext. - apply equal_on_basis_vectors_implies_equal; try easy. - intros i Hi. - - pose proof (Nat.div_lt_upper_bound i t s ltac:(lia) ltac:(lia)). - pose proof (Nat.mod_upper_bound i s ltac:(lia)). - pose proof (Nat.mod_upper_bound i t ltac:(lia)). - - specialize (Hext (basis_vector s (i / t)) (basis_vector t (i mod t)) - ltac:(apply basis_vector_WF; easy) - ltac:(apply basis_vector_WF; easy) - ). - rewrite (kron_basis_vector_basis_vector t s) in Hext by lia. - - simpl in Hext. - rewrite (Nat.mul_comm (i/t) t), <- (Nat.div_mod_eq i t) in Hext. - rewrite (Nat.mul_comm t s) in Hext. easy. -Qed. - -Lemma kron_comm_commutes : forall n s m t - (A : Matrix n s) (B : Matrix m t), - WF_Matrix A -> WF_Matrix B -> - kron_comm m n × (A ⊗ B) × (kron_comm s t) = (B ⊗ A). -Proof. - intros n s m t A B HwfA HwfB. - apply (kron_extensionality _ _ t s); [| - apply WF_kron; try easy; lia |]. - rewrite (Nat.mul_comm t s); apply WF_mult; auto with wf_db; - apply WF_mult; auto with wf_db; - rewrite Nat.mul_comm; auto with wf_db. - (* rewrite Nat.mul_comm; apply WF_mult; [rewrite Nat.mul_comm|auto with wf_db]; - apply WF_mult; auto with wf_db; rewrite Nat.mul_comm; auto with wf_db. *) - intros x y Hwfx Hwfy. - (* simpl. *) - (* Search "assoc" in Matrix. *) - rewrite (Nat.mul_comm s t). - rewrite (Mmult_assoc (_ × _)). - rewrite (Nat.mul_comm t s). - rewrite kron_comm_commutes_vectors by easy. - rewrite Mmult_assoc, (Nat.mul_comm m n). - rewrite kron_mixed_product. - rewrite (Nat.mul_comm n m), kron_comm_commutes_vectors by (auto with wf_db). - rewrite <- kron_mixed_product. - f_equal; lia. -Qed. - -Lemma f_to_vec_split : forall (f : nat -> bool) (m n : nat), - f_to_vec (m + n) f = f_to_vec m f ⊗ f_to_vec n (VectorStates.shift f m). -Proof. - intros f m n. - rewrite f_to_vec_merge. - apply f_to_vec_eq. - intros i Hi. - unfold VectorStates.shift. - bdestructΩ'. - f_equal; lia. -Qed. - -Lemma n_top_to_bottom_semantics_eq_kron_comm : forall n o, - ⟦ n_top_to_bottom n o ⟧ = kron_comm (2^o) (2^n). -Proof. - intros n o. - rewrite zxperm_permutation_semantics by auto with zxperm_db. - unfold zxperm_to_matrix. - rewrite perm_of_n_top_to_bottom. - apply equal_on_basis_states_implies_equal; auto with wf_db. - 1: { - rewrite Nat.add_comm, Nat.pow_add_r. - auto with wf_db. - } - intros f. - pose proof (perm_to_matrix_permutes_qubits (n + o) (rotr (n+o) n) f). - unfold perm_to_matrix in H. - rewrite H by auto with perm_db. - rewrite (f_to_vec_split f). - pose proof (kron_comm_commutes_vectors (2^o) (2^n) - (f_to_vec n f) (f_to_vec o (@VectorStates.shift bool f n)) - ltac:(auto with wf_db) ltac:(auto with wf_db)). - replace (2^(n+o))%nat with (2^o *2^n)%nat by (rewrite Nat.pow_add_r; lia). - simpl in H0. - rewrite H0. - rewrite Nat.add_comm, f_to_vec_split. - f_equal. - - apply f_to_vec_eq. - intros i Hi. - unfold VectorStates.shift. - f_equal; unfold rotr. - rewrite Nat.mod_small by lia. - bdestructΩ'. - - apply f_to_vec_eq. - intros i Hi. - unfold VectorStates.shift, rotr. - rewrite <- Nat.add_assoc, mod_add_n_r, Nat.mod_small by lia. - bdestructΩ'. -Qed. - - -Lemma compose_semantics' : -forall {n m o : nat} (zx0 : ZX n m) (zx1 : ZX m o), -@eq (Matrix (Nat.pow 2 o) (Nat.pow 2 n)) - (@ZX_semantics n o (@Compose n m o zx0 zx1)) - (@Mmult (Nat.pow 2 o) (Nat.pow 2 m) (Nat.pow 2 n) - (@ZX_semantics m o zx1) (@ZX_semantics n m zx0)). -Proof. - intros. - rewrite (@compose_semantics n m o). - easy. -Qed. Lemma zx_braiding_commutes (A1 B1 A2 B2 : nat) (f1 : ZX A1 B1) (f2 : ZX A2 B2) : @@ -956,7 +401,7 @@ Proof. rewrite 2!n_top_to_bottom_semantics_eq_kron_comm. rewrite 2!stack_semantics, Mscale_1_l. rewrite <- (kron_comm_commutes (2^B1)%nat (2^A1)%nat (2^B2) (2^A2) (⟦ f1 ⟧) (⟦ f2 ⟧)) - by (auto with wf_db). + by (auto with wf_db). rewrite Mmult_assoc. rewrite (Nat.add_comm B1 B2), (Nat.add_comm A2 A1). rewrite 2!Nat.pow_add_r. @@ -990,438 +435,438 @@ Proof. Qed. Definition ZXBraidingIsomorphism : forall n m, - Isomorphism (ZXTensorBiFunctor n m) ((CommuteBifunctor ZXTensorBiFunctor) n m) := - fun n m => Build_Isomorphism nat ZXCategory _ _ - ((* forward := *) @zx_braiding n m) - ((* reverse := *) @zx_inv_braiding n m) - ((* id_A := *) @zx_braiding_inv_compose n m) - ((* id_B := *) @zx_inv_braiding_compose n m). + Isomorphism (ZXTensorBiFunctor n m) ((CommuteBifunctor ZXTensorBiFunctor) n m) := + fun n m => Build_Isomorphism nat ZXCategory _ _ + ((* forward := *) @zx_braiding n m) + ((* reverse := *) @zx_inv_braiding n m) + ((* id_A := *) @zx_braiding_inv_compose n m) + ((* id_B := *) @zx_inv_braiding_compose n m). #[export] Instance ZXBraidingBiIsomorphism : - NaturalBiIsomorphism ZXTensorBiFunctor (CommuteBifunctor ZXTensorBiFunctor) := {| - component2_iso := ZXBraidingIsomorphism; - component2_iso_natural := zx_braiding_iso_natural; + NaturalBiIsomorphism ZXTensorBiFunctor (CommuteBifunctor ZXTensorBiFunctor) := {| + component2_iso := ZXBraidingIsomorphism; + component2_iso_natural := zx_braiding_iso_natural; |}. #[export] Instance ZXBraidedMonoidalCategory : BraidedMonoidalCategory nat := { - braiding := ZXBraidingBiIsomorphism; + braiding := ZXBraidingBiIsomorphism; - hexagon_1 := @hexagon_lemma_1; - hexagon_2 := @hexagon_lemma_2; + hexagon_1 := @hexagon_lemma_1; + hexagon_2 := @hexagon_lemma_2; (* - braiding := @zx_braiding; - inv_braiding := @zx_inv_braiding; - braiding_inv_compose := @zx_braiding_inv_compose; - inv_braiding_compose := @zx_inv_braiding_compose; + braiding := @zx_braiding; + inv_braiding := @zx_inv_braiding; + braiding_inv_compose := @zx_braiding_inv_compose; + inv_braiding_compose := @zx_inv_braiding_compose; - hexagon_1 := @hexagon_lemma_1; - hexagon_2 := @hexagon_lemma_2; + hexagon_1 := @hexagon_lemma_1; + hexagon_2 := @hexagon_lemma_2; *) }. Lemma n_top_to_bottom_is_bottom_to_top : forall {n m}, - n_top_to_bottom n m ∝ n_bottom_to_top m n. + n_top_to_bottom n m ∝ n_bottom_to_top m n. Proof. - prop_perm_eq. - solve_modular_permutation_equalities. + prop_perm_eq. + solve_modular_permutation_equalities. (* - intros. - unfold n_bottom_to_top. - unfold bottom_to_top. - unfold n_top_to_bottom. - induction n. - - intros. - rewrite n_compose_0. - simpl. - rewrite <- n_compose_transpose. - rewrite n_compose_n_top_to_bottom. - rewrite n_wire_transpose. - reflexivity. - - intros. - rewrite n_compose_grow_l. - assert ((S n + m)%nat = (n + S m)%nat) by lia. - assert (top_to_bottom (S n + m) - ∝ cast (S n + m) (S n + m) H H (top_to_bottom (n + S m))) - by (rewrite cast_fn_eq_dim; reflexivity). - rewrite H0. - rewrite cast_n_compose. - rewrite IHn. - rewrite <- cast_n_compose. - rewrite <- H0. - rewrite n_compose_grow_l. - rewrite <- cast_transpose. - rewrite <- H0. - fold (bottom_to_top (S n + m)). - rewrite <- compose_assoc. - rewrite top_to_bottom_to_top. cleanup_zx. - reflexivity. *) + intros. + unfold n_bottom_to_top. + unfold bottom_to_top. + unfold n_top_to_bottom. + induction n. + - intros. + rewrite n_compose_0. + simpl. + rewrite <- n_compose_transpose. + rewrite n_compose_n_top_to_bottom. + rewrite n_wire_transpose. + reflexivity. + - intros. + rewrite n_compose_grow_l. + assert ((S n + m)%nat = (n + S m)%nat) by lia. + assert (top_to_bottom (S n + m) + ∝ cast (S n + m) (S n + m) H H (top_to_bottom (n + S m))) + by (rewrite cast_fn_eq_dim; reflexivity). + rewrite H0. + rewrite cast_n_compose. + rewrite IHn. + rewrite <- cast_n_compose. + rewrite <- H0. + rewrite n_compose_grow_l. + rewrite <- cast_transpose. + rewrite <- H0. + fold (bottom_to_top (S n + m)). + rewrite <- compose_assoc. + rewrite top_to_bottom_to_top. cleanup_zx. + reflexivity. *) Qed. Lemma braiding_symmetry : forall n m, - @zx_braiding n m ∝ @zx_inv_braiding m n. + @zx_braiding n m ∝ @zx_inv_braiding m n. Proof. - prop_perm_eq. - solve_modular_permutation_equalities. - (* intros. - unfold zx_braiding. unfold zx_inv_braiding. - apply cast_compat. - rewrite n_top_to_bottom_is_bottom_to_top. - reflexivity. *) + prop_perm_eq. + solve_modular_permutation_equalities. +(* intros. + unfold zx_braiding. unfold zx_inv_braiding. + apply cast_compat. + rewrite n_top_to_bottom_is_bottom_to_top. + reflexivity. *) Qed. #[export] Instance ZXSymmetricMonoidalCategory : SymmetricMonoidalCategory nat := { - symmetry := braiding_symmetry; + symmetry := braiding_symmetry; }. Lemma nwire_adjoint : forall n, (n_wire n) †%ZX ∝ n_wire n. Proof. - intros. - induction n; try easy. - - intros. - unfold ZXCore.adjoint. - simpl. - unfold ZXCore.adjoint in IHn. - rewrite IHn. - reflexivity. + intros. + induction n; try easy. + - intros. + unfold ZXCore.adjoint. + simpl. + unfold ZXCore.adjoint in IHn. + rewrite IHn. + reflexivity. Qed. Lemma compose_adjoint : forall {n m o} - (zx0 : ZX n m) (zx1 : ZX m o), - (zx0 ⟷ zx1) †%ZX ∝ zx1 †%ZX ⟷ zx0 †%ZX. + (zx0 : ZX n m) (zx1 : ZX m o), + (zx0 ⟷ zx1) †%ZX ∝ zx1 †%ZX ⟷ zx0 †%ZX. Proof. - intros; easy. + intros; easy. Qed. Lemma stack_adjoint : forall {n n' m m'} - (zx0: ZX n m) (zx1 : ZX n' m'), - (zx0 ↕ zx1) †%ZX ∝ zx0 †%ZX ↕ zx1 †%ZX. + (zx0: ZX n m) (zx1 : ZX n' m'), + (zx0 ↕ zx1) †%ZX ∝ zx0 †%ZX ↕ zx1 †%ZX. Proof. - intros. - unfold ZXCore.adjoint. - simpl. - easy. + intros. + unfold ZXCore.adjoint. + simpl. + easy. Qed. #[export] Instance ZXDaggerCategory : DaggerCategory nat := { - adjoint := @ZXCore.adjoint; - involutive := @Proportional.adjoint_involutive; - preserves_id := nwire_adjoint; - contravariant := @compose_adjoint; + adjoint := @ZXCore.adjoint; + involutive := @Proportional.adjoint_involutive; + preserves_id := nwire_adjoint; + contravariant := @compose_adjoint; }. Lemma zx_dagger_compat : forall {n n' m m'} - (zx0: ZX n m) (zx1 : ZX n' m'), - zx0 †%ZX ↕ zx1 †%ZX ∝ (zx0 ↕ zx1) †%ZX. + (zx0: ZX n m) (zx1 : ZX n' m'), + zx0 †%ZX ↕ zx1 †%ZX ∝ (zx0 ↕ zx1) †%ZX. Proof. - intros. - rewrite stack_adjoint. - easy. + intros. + rewrite stack_adjoint. + easy. Qed. Lemma zx_associator_unitary_r : forall {n m o}, - zx_associator ⟷ zx_associator †%ZX ∝ n_wire (n + m + o). + zx_associator ⟷ zx_associator †%ZX ∝ n_wire (n + m + o). Proof. - intros. - unfold zx_associator. - rewrite cast_adj. - rewrite nwire_adjoint. - simpl_permlike_zx. - reflexivity. + intros. + unfold zx_associator. + rewrite cast_adj. + rewrite nwire_adjoint. + simpl_permlike_zx. + reflexivity. Qed. Lemma zx_associator_unitary_l : forall {n m o}, - zx_associator †%ZX ⟷ zx_associator ∝ n_wire (n + (m + o)). -Proof. - intros. - unfold zx_associator. - rewrite cast_adj. - rewrite nwire_adjoint. - simpl_permlike_zx. - rewrite cast_n_wire. - reflexivity. + zx_associator †%ZX ⟷ zx_associator ∝ n_wire (n + (m + o)). +Proof. + intros. + unfold zx_associator. + rewrite cast_adj. + rewrite nwire_adjoint. + simpl_permlike_zx. + rewrite cast_n_wire. + reflexivity. Qed. Lemma zx_left_unitor_unitary_r : forall {n}, - zx_left_unitor ⟷ zx_left_unitor †%ZX ∝ n_wire (0 + n). + zx_left_unitor ⟷ zx_left_unitor †%ZX ∝ n_wire (0 + n). Proof. - intros. unfold zx_left_unitor. - simpl_permlike_zx. - rewrite nwire_adjoint. - reflexivity. + intros. unfold zx_left_unitor. + simpl_permlike_zx. + rewrite nwire_adjoint. + reflexivity. Qed. Lemma zx_left_unitor_unitary_l : forall {n}, - zx_left_unitor †%ZX ⟷ zx_left_unitor ∝ n_wire n. + zx_left_unitor †%ZX ⟷ zx_left_unitor ∝ n_wire n. Proof. - intros. unfold zx_left_unitor. - simpl_permlike_zx. - rewrite nwire_adjoint. - reflexivity. + intros. unfold zx_left_unitor. + simpl_permlike_zx. + rewrite nwire_adjoint. + reflexivity. Qed. Lemma zx_right_unitor_unitary_r : forall {n}, - zx_right_unitor ⟷ zx_right_unitor †%ZX ∝ n_wire (n + 0). + zx_right_unitor ⟷ zx_right_unitor †%ZX ∝ n_wire (n + 0). Proof. - intros. unfold zx_right_unitor. - simpl_permlike_zx. - rewrite cast_adj, nwire_adjoint. - simpl_permlike_zx. - rewrite cast_n_wire. - reflexivity. + intros. unfold zx_right_unitor. + simpl_permlike_zx. + rewrite cast_adj, nwire_adjoint. + simpl_permlike_zx. + rewrite cast_n_wire. + reflexivity. Qed. Lemma zx_right_unitor_unitary_l : forall {n}, - zx_right_unitor †%ZX ⟷ zx_right_unitor ∝ n_wire n. + zx_right_unitor †%ZX ⟷ zx_right_unitor ∝ n_wire n. Proof. - intros. unfold zx_right_unitor. - simpl_permlike_zx. - rewrite cast_adj, nwire_adjoint. - simpl_permlike_zx. - reflexivity. + intros. unfold zx_right_unitor. + simpl_permlike_zx. + rewrite cast_adj, nwire_adjoint. + simpl_permlike_zx. + reflexivity. Qed. Lemma helper_eq: forall n m (prf: n = m), (n + n = m + m)%nat. Proof. intros. subst. reflexivity. Qed. Lemma cast_n_cup_unswapped : forall n m prf1 prf2, - cast (n + n) 0 (helper_eq _ _ prf1) prf2 (n_cup_unswapped m) ∝ n_cup_unswapped n. + cast (n + n) 0 (helper_eq _ _ prf1) prf2 (n_cup_unswapped m) ∝ n_cup_unswapped n. Proof. - intros. - subst. - rewrite cast_id_eq. - easy. + intros. + subst. + rewrite cast_id_eq. + easy. Qed. Lemma wire_stack_distr_compose_l : forall n m o (zx0 : ZX n m) (zx1 : ZX m o), - — ↕ (zx0 ⟷ zx1) ∝ (— ↕ zx0) ⟷ (— ↕ zx1). + — ↕ (zx0 ⟷ zx1) ∝ (— ↕ zx0) ⟷ (— ↕ zx1). Proof. - intros. - rewrite <- stack_compose_distr. - cleanup_zx. - easy. + intros. + rewrite <- stack_compose_distr. + cleanup_zx. + easy. Qed. Lemma wire_stack_distr_compose_r : forall n m o (zx0 : ZX n m) (zx1 : ZX m o), - (zx0 ⟷ zx1) ↕ — ∝ (zx0 ↕ —) ⟷ (zx1 ↕ —). + (zx0 ⟷ zx1) ↕ — ∝ (zx0 ↕ —) ⟷ (zx1 ↕ —). Proof. - intros. - rewrite <- stack_compose_distr. - cleanup_zx. - easy. + intros. + rewrite <- stack_compose_distr. + cleanup_zx. + easy. Qed. Lemma n_cup_unswapped_grow_r : forall n prf1 prf2, - n_cup_unswapped (S n) ∝ - cast _ _ prf1 prf2 (— ↕ n_cup_unswapped n ↕ —) ⟷ ⊃. -Proof. - intros. - induction n. - - simpl. cleanup_zx. - apply compose_simplify; [|easy]. - prop_perm_eq. - - rewrite n_cup_unswapped_grow_l. - rewrite IHn at 1. - rewrite n_cup_unswapped_grow_l. - bundle_wires. - rewrite <- compose_assoc. - apply compose_simplify; [|easy]. - rewrite wire_stack_distr_compose_l, wire_stack_distr_compose_r. - rewrite (prop_iff_double_cast _ (1 + 0 + 1) _ _ eq_refl). - simpl_casts. - rewrite (cast_compose_mid_contract _ (1 + (n + n) + 1)). - rewrite 2!cast_contract, cast_id. - apply compose_simplify; [|easy]. - simpl_casts. - simpl (n_wire S n). - rewrite 4!stack_assoc. - rewrite (stack_assoc — (n_wire n) _). - simpl_casts. - repeat (rewrite cast_stack_distribute; apply stack_simplify; try prop_perm_eq). - rewrite cast_id; easy. - Unshelve. - all: lia. + n_cup_unswapped (S n) ∝ + cast _ _ prf1 prf2 (— ↕ n_cup_unswapped n ↕ —) ⟷ ⊃. +Proof. + intros. + induction n. + - simpl. cleanup_zx. + apply compose_simplify; [|easy]. + prop_perm_eq. + - rewrite n_cup_unswapped_grow_l. + rewrite IHn at 1. + rewrite n_cup_unswapped_grow_l. + bundle_wires. + rewrite <- compose_assoc. + apply compose_simplify; [|easy]. + rewrite wire_stack_distr_compose_l, wire_stack_distr_compose_r. + rewrite (prop_iff_double_cast _ (1 + 0 + 1) _ _ eq_refl). + simpl_casts. + rewrite (cast_compose_mid_contract _ (1 + (n + n) + 1)). + rewrite 2!cast_contract, cast_id. + apply compose_simplify; [|easy]. + simpl_casts. + simpl (n_wire S n). + rewrite 4!stack_assoc. + rewrite (stack_assoc — (n_wire n) _). + simpl_casts. + repeat (rewrite cast_stack_distribute; apply stack_simplify; try prop_perm_eq). + rewrite cast_id; easy. + Unshelve. + all: lia. Qed. Lemma nwire_stack_distr_compose_l : forall k n m o (zx0 : ZX n m) (zx1 : ZX m o), - n_wire k ↕ (zx0 ⟷ zx1) ∝ (n_wire k ↕ zx0) ⟷ (n_wire k ↕ zx1). + n_wire k ↕ (zx0 ⟷ zx1) ∝ (n_wire k ↕ zx0) ⟷ (n_wire k ↕ zx1). Proof. - intros. - rewrite <- stack_compose_distr. - cleanup_zx. - easy. + intros. + rewrite <- stack_compose_distr. + cleanup_zx. + easy. Qed. Lemma nwire_stack_distr_compose_r : forall k n m o (zx0 : ZX n m) (zx1 : ZX m o), - (zx0 ⟷ zx1) ↕ n_wire k ∝ (zx0 ↕ n_wire k) ⟷ (zx1 ↕ n_wire k). + (zx0 ⟷ zx1) ↕ n_wire k ∝ (zx0 ↕ n_wire k) ⟷ (zx1 ↕ n_wire k). Proof. - intros. - rewrite <- stack_compose_distr. - cleanup_zx. - easy. + intros. + rewrite <- stack_compose_distr. + cleanup_zx. + easy. Qed. Lemma n_cup_unswapped_comm_1 : forall k prf1 prf2 prf3 prf4, - cast (S k + (S k)) _ prf1 prf2 (n_wire k ↕ ⊃ ↕ n_wire k) ⟷ (n_cup_unswapped k) - ∝ cast _ _ prf3 prf4 (— ↕ n_cup_unswapped k ↕ —) ⟷ ⊃. + cast (S k + (S k)) _ prf1 prf2 (n_wire k ↕ ⊃ ↕ n_wire k) ⟷ (n_cup_unswapped k) + ∝ cast _ _ prf3 prf4 (— ↕ n_cup_unswapped k ↕ —) ⟷ ⊃. Proof. - intros. - rewrite <- n_cup_unswapped_grow_l. - rewrite <- n_cup_unswapped_grow_r. - easy. + intros. + rewrite <- n_cup_unswapped_grow_l. + rewrite <- n_cup_unswapped_grow_r. + easy. Qed. Lemma n_wire_add_stack : forall n k, - n_wire (n + k) ∝ n_wire n ↕ n_wire k. + n_wire (n + k) ∝ n_wire n ↕ n_wire k. Proof. prop_perm_eq. Qed. Lemma n_wire_add_stack_rev : forall n k prf1 prf2, - n_wire (n + k) ∝ cast _ _ prf1 prf2 (n_wire k ↕ n_wire n). + n_wire (n + k) ∝ cast _ _ prf1 prf2 (n_wire k ↕ n_wire n). Proof. prop_perm_eq. Qed. Lemma stack_assoc' : forall {n0 n1 n2 m0 m1 m2} (zx0 : ZX n0 m0) - (zx1 : ZX n1 m1) (zx2 : ZX n2 m2) prfn prfm, - zx0 ↕ (zx1 ↕ zx2) ∝ cast _ _ prfn prfm ((zx0 ↕ zx1) ↕ zx2). + (zx1 : ZX n1 m1) (zx2 : ZX n2 m2) prfn prfm, + zx0 ↕ (zx1 ↕ zx2) ∝ cast _ _ prfn prfm ((zx0 ↕ zx1) ↕ zx2). Proof. - intros. - rewrite stack_assoc. - rewrite cast_cast_eq, cast_id. - easy. - Unshelve. - all: lia. + intros. + rewrite stack_assoc. + rewrite cast_cast_eq, cast_id. + easy. + Unshelve. + all: lia. Qed. Lemma n_cup_unswapped_comm : forall n k prf1 prf2 prf3 prf4, - cast (S n + k + (S n + k)) _ prf1 prf2 (n_wire (n + k) ↕ ⊃ ↕ n_wire (n + k)) ⟷ (n_wire n ↕ n_cup_unswapped k ↕ n_wire n) - ∝ cast _ _ prf3 prf4 (n_wire (S n) ↕ n_cup_unswapped k ↕ n_wire (S n)) ⟷ (n_wire n ↕ ⊃ ↕ n_wire n). -Proof. - intros. - rewrite n_wire_add_stack at 1. - rewrite n_wire_add_stack_rev at 1. - rewrite (n_wire_add_stack_rev 1 n) at 1. - rewrite (n_wire_add_stack 1 n) at 1. - rewrite 5!stack_assoc. - repeat rewrite cast_cast_eq. - simpl_casts. - rewrite stack_assoc. - rewrite (prop_iff_double_cast (n + (k + 2 + k + n)) (n + (0 + n))). - rewrite (cast_compose_mid_contract _ (n + (k + 0 + k + n))). - repeat rewrite cast_cast_eq. - rewrite (cast_compose_mid_contract _ (n + (2 + n))). - rewrite 2!cast_cast_eq. - erewrite 3!(cast_stack_distribute _ _ _ _ _ _ (n_wire n)). - rewrite 4!cast_id_eq. - rewrite <- 2!nwire_stack_distr_compose_l. - apply stack_simplify; [easy|]. - rewrite <- wire_to_n_wire. - rewrite 3!stack_assoc'. - rewrite (stack_assoc' (_ ↕ _) (—) (n_wire n)). - repeat rewrite cast_cast_eq. - rewrite 3!(cast_stack_distribute (o':=n)). - rewrite (cast_id (n:=n)). - rewrite <- 2!nwire_stack_distr_compose_r. - apply stack_simplify; [|easy]. - rewrite (prop_iff_double_cast ((S k) + (S k)) (0)). - simpl_permlike_zx. - symmetry. - rewrite (cast_compose_mid_contract _ 2). - symmetry. - rewrite cast_id_eq. - rewrite cast_cast_eq. - rewrite <- n_cup_unswapped_comm_1. - rewrite (prop_iff_double_cast ((S k) + (S k)) (0)). - rewrite (cast_compose_mid_contract _ (k + k)). - simpl_casts. - easy. - Unshelve. - all: lia. + cast (S n + k + (S n + k)) _ prf1 prf2 (n_wire (n + k) ↕ ⊃ ↕ n_wire (n + k)) ⟷ (n_wire n ↕ n_cup_unswapped k ↕ n_wire n) + ∝ cast _ _ prf3 prf4 (n_wire (S n) ↕ n_cup_unswapped k ↕ n_wire (S n)) ⟷ (n_wire n ↕ ⊃ ↕ n_wire n). +Proof. + intros. + rewrite n_wire_add_stack at 1. + rewrite n_wire_add_stack_rev at 1. + rewrite (n_wire_add_stack_rev 1 n) at 1. + rewrite (n_wire_add_stack 1 n) at 1. + rewrite 5!stack_assoc. + repeat rewrite cast_cast_eq. + simpl_casts. + rewrite stack_assoc. + rewrite (prop_iff_double_cast (n + (k + 2 + k + n)) (n + (0 + n))). + rewrite (cast_compose_mid_contract _ (n + (k + 0 + k + n))). + repeat rewrite cast_cast_eq. + rewrite (cast_compose_mid_contract _ (n + (2 + n))). + rewrite 2!cast_cast_eq. + erewrite 3!(cast_stack_distribute _ _ _ _ _ _ (n_wire n)). + rewrite 4!cast_id_eq. + rewrite <- 2!nwire_stack_distr_compose_l. + apply stack_simplify; [easy|]. + rewrite <- wire_to_n_wire. + rewrite 3!stack_assoc'. + rewrite (stack_assoc' (_ ↕ _) (—) (n_wire n)). + repeat rewrite cast_cast_eq. + rewrite 3!(cast_stack_distribute (o':=n)). + rewrite (cast_id (n:=n)). + rewrite <- 2!nwire_stack_distr_compose_r. + apply stack_simplify; [|easy]. + rewrite (prop_iff_double_cast ((S k) + (S k)) (0)). + simpl_permlike_zx. + symmetry. + rewrite (cast_compose_mid_contract _ 2). + symmetry. + rewrite cast_id_eq. + rewrite cast_cast_eq. + rewrite <- n_cup_unswapped_comm_1. + rewrite (prop_iff_double_cast ((S k) + (S k)) (0)). + rewrite (cast_compose_mid_contract _ (k + k)). + simpl_casts. + easy. + Unshelve. + all: lia. Qed. Lemma n_cup_unswapped_grow_k_l : forall n k prf1 prf2, - n_cup_unswapped (n + k) ∝ - cast _ _ prf1 prf2 (n_wire n ↕ (n_cup_unswapped k) ↕ n_wire n) ⟷ n_cup_unswapped n. -Proof. - intros. - induction n. - - rewrite stack_empty_l, stack_empty_r, cast_cast_eq, - cast_id_eq, compose_empty_r. - easy. - - rewrite (prop_iff_double_cast (S (n + k) + S (n + k)) 0). - rewrite cast_n_cup_unswapped. - rewrite n_cup_unswapped_grow_l, IHn. - rewrite n_cup_unswapped_grow_l. - simpl_permlike_zx. - (* simpl_casts. *) - repeat rewrite <- compose_assoc. - apply compose_simplify; [|easy]. - symmetry. - rewrite (cast_compose_mid (n + 2 + n)). - erewrite (prop_iff_double_cast ((S n + k) + (S n + k)) (n + 0 + n) _ _ eq_refl). - rewrite 2!cast_contract. - rewrite cast_compose_mid_contract, 2!cast_contract, cast_id. - rewrite cast_compose_mid_contract, 2!cast_contract, cast_id. - rewrite n_cup_unswapped_comm. - easy. - Unshelve. - all: try easy; auto with arith; lia. + n_cup_unswapped (n + k) ∝ + cast _ _ prf1 prf2 (n_wire n ↕ (n_cup_unswapped k) ↕ n_wire n) ⟷ n_cup_unswapped n. +Proof. + intros. + induction n. + - rewrite stack_empty_l, stack_empty_r, cast_cast_eq, + cast_id_eq, compose_empty_r. + easy. + - rewrite (prop_iff_double_cast (S (n + k) + S (n + k)) 0). + rewrite cast_n_cup_unswapped. + rewrite n_cup_unswapped_grow_l, IHn. + rewrite n_cup_unswapped_grow_l. + simpl_permlike_zx. + (* simpl_casts. *) + repeat rewrite <- compose_assoc. + apply compose_simplify; [|easy]. + symmetry. + rewrite (cast_compose_mid (n + 2 + n)). + erewrite (prop_iff_double_cast ((S n + k) + (S n + k)) (n + 0 + n) _ _ eq_refl). + rewrite 2!cast_contract. + rewrite cast_compose_mid_contract, 2!cast_contract, cast_id. + rewrite cast_compose_mid_contract, 2!cast_contract, cast_id. + rewrite n_cup_unswapped_comm. + easy. + Unshelve. + all: try easy; auto with arith; lia. Qed. Lemma n_cup_unswapped_add_comm : forall n k prf1 prf2, - n_cup_unswapped (n + k) ∝ cast _ _ prf1 prf2 (n_cup_unswapped (k + n)). -Proof. - intros. - assert (H: (k + n = n + k)%nat) by lia. - generalize dependent (k + n)%nat. - generalize dependent (n + k)%nat. - intros; subst. - rewrite cast_id_eq. - easy. + n_cup_unswapped (n + k) ∝ cast _ _ prf1 prf2 (n_cup_unswapped (k + n)). +Proof. + intros. + assert (H: (k + n = n + k)%nat) by lia. + generalize dependent (k + n)%nat. + generalize dependent (n + k)%nat. + intros; subst. + rewrite cast_id_eq. + easy. Qed. Lemma n_cup_unswapped_grow_k_r : forall n k prf1 prf2, - n_cup_unswapped (n + k) ∝ - cast _ _ prf1 prf2 (n_wire k ↕ (n_cup_unswapped n) ↕ n_wire k) ⟷ n_cup_unswapped k. -Proof. - intros. - rewrite n_cup_unswapped_add_comm. - rewrite n_cup_unswapped_grow_k_l. - rewrite (cast_compose_mid_contract _ (k + k)%nat). - simpl_casts. - apply compose_simplify; [|easy]. - erewrite cast_proof_independence. - reflexivity. - Unshelve. all:lia. + n_cup_unswapped (n + k) ∝ + cast _ _ prf1 prf2 (n_wire k ↕ (n_cup_unswapped n) ↕ n_wire k) ⟷ n_cup_unswapped k. +Proof. + intros. + rewrite n_cup_unswapped_add_comm. + rewrite n_cup_unswapped_grow_k_l. + rewrite (cast_compose_mid_contract _ (k + k)%nat). + simpl_casts. + apply compose_simplify; [|easy]. + erewrite cast_proof_independence. + reflexivity. + Unshelve. all:lia. Qed. Lemma stack_ncup_unswapped_split : forall {n0t n0b n1t n1b} n m (zx0top : ZX n0t m) (zx0bot : ZX n0b m) - (zx1top : ZX n1t n) (zx1bot : ZX n1b n) prf1 prf2 prf3 prf4 prf5 prf6, - (zx1top ↕ zx0top) ↕ (zx0bot ↕ zx1bot) - ⟷ cast ((n + m) + (m + n)) 0 prf1 prf2 (n_cup_unswapped (n + m)) - ∝ cast _ _ prf5 prf6 (zx1top ↕ ((zx0top ↕ zx0bot) ⟷ n_cup_unswapped m) ↕ zx1bot - ⟷ cast (n + 0 + n) 0 prf3 prf4 (n_cup_unswapped n)). -Proof. - intros. - rewrite cast_compose_r. - simpl_permlike_zx. - rewrite n_cup_unswapped_grow_k_l, <- compose_assoc. - rewrite (cast_compose_mid_contract _ (n + n)). - simpl_permlike_zx. - apply compose_simplify; [|easy]. - rewrite stack_assoc, (stack_assoc' zx0top). - simpl_casts. - rewrite stack_assoc'. - rewrite cast_cast_eq. - rewrite (prop_iff_double_cast (n1t + (n0t + n0b) + n1b) (n + 0 + n)). - rewrite (cast_compose_mid_contract _ (n + (m + m) + n)). - simpl_permlike_zx. - rewrite <- 2!stack_compose_distr, 2!nwire_removal_r. - easy. - Unshelve. - all: lia. + (zx1top : ZX n1t n) (zx1bot : ZX n1b n) prf1 prf2 prf3 prf4 prf5 prf6, + (zx1top ↕ zx0top) ↕ (zx0bot ↕ zx1bot) + ⟷ cast ((n + m) + (m + n)) 0 prf1 prf2 (n_cup_unswapped (n + m)) + ∝ cast _ _ prf5 prf6 (zx1top ↕ ((zx0top ↕ zx0bot) ⟷ n_cup_unswapped m) ↕ zx1bot + ⟷ cast (n + 0 + n) 0 prf3 prf4 (n_cup_unswapped n)). +Proof. + intros. + rewrite cast_compose_r. + simpl_permlike_zx. + rewrite n_cup_unswapped_grow_k_l, <- compose_assoc. + rewrite (cast_compose_mid_contract _ (n + n)). + simpl_permlike_zx. + apply compose_simplify; [|easy]. + rewrite stack_assoc, (stack_assoc' zx0top). + simpl_casts. + rewrite stack_assoc'. + rewrite cast_cast_eq. + rewrite (prop_iff_double_cast (n1t + (n0t + n0b) + n1b) (n + 0 + n)). + rewrite (cast_compose_mid_contract _ (n + (m + m) + n)). + simpl_permlike_zx. + rewrite <- 2!stack_compose_distr, 2!nwire_removal_r. + easy. + Unshelve. + all: lia. Qed. (* Local Open Scope matrix_scope. *) @@ -1429,16 +874,16 @@ Qed. Lemma sem_n_cup_unswapped_2 : - ⟦ n_cup_unswapped 2 ⟧ = - fun x y => if (x=?0) && ((y=?0) || (y=?6) || (y=?9) || (y=?15)) then C1 else C0. + ⟦ n_cup_unswapped 2 ⟧ = + fun x y => if (x=?0) && ((y=?0) || (y=?6) || (y=?9) || (y=?15)) then C1 else C0. Proof. - unfold n_cup_unswapped. - repeat (simpl; - rewrite cast_semantics; simpl). - rewrite 2!id_kron. - replace (list2D_to_matrix [[C1; C0; C0; C1]]) with - (fun x y => if (x =? 0)&&((y =? 0) || (y =? 3)) then C1 else C0) by solve_matrix. - solve_matrix. + unfold n_cup_unswapped. + repeat (simpl; + rewrite cast_semantics; simpl). + rewrite 2!id_kron. + replace (list2D_to_matrix [[C1; C0; C0; C1]]) with + (fun x y => if (x =? 0)&&((y =? 0) || (y =? 3)) then C1 else C0) by solve_matrix. + solve_matrix. Qed. @@ -1446,379 +891,379 @@ Qed. Lemma swap_2cup_transport : - ⟦ n_cup_unswapped 2 ⟧ × (swap ⊗ (Matrix.I (2^2))) - = ⟦ n_cup_unswapped 2 ⟧ × ((Matrix.I (2^2)) ⊗ swap). + ⟦ n_cup_unswapped 2 ⟧ × (swap ⊗ (Matrix.I (2^2))) + = ⟦ n_cup_unswapped 2 ⟧ × ((Matrix.I (2^2)) ⊗ swap). Proof. - apply mat_equiv_eq; auto with wf_db. - rewrite sem_n_cup_unswapped_2. - by_cell; try lca. + apply mat_equiv_eq; auto with wf_db. + rewrite sem_n_cup_unswapped_2. + by_cell; try lca. Qed. Lemma swap_2cup_flip : - ⨉ ↕ n_wire 2 ⟷ n_cup_unswapped 2 ∝ n_wire 2 ↕ ⨉ ⟷ n_cup_unswapped 2. + ⨉ ↕ n_wire 2 ⟷ n_cup_unswapped 2 ∝ n_wire 2 ↕ ⨉ ⟷ n_cup_unswapped 2. Proof. - prop_exists_nonzero 1. - rewrite (compose_semantics (⨉ ↕ n_wire 2) (n_cup_unswapped 2)). (* For some reason, it needs *) - rewrite (compose_semantics (n_wire 2 ↕ ⨉) (n_cup_unswapped 2)). (* its arguments explicitly *) - rewrite Mscale_1_l. - rewrite 2!stack_semantics, n_wire_semantics. - simpl (⟦ ⨉ ⟧). - apply swap_2cup_transport. + prop_exists_nonzero 1. + rewrite (compose_semantics (⨉ ↕ n_wire 2) (n_cup_unswapped 2)). (* For some reason, it needs *) + rewrite (compose_semantics (n_wire 2 ↕ ⨉) (n_cup_unswapped 2)). (* its arguments explicitly *) + rewrite Mscale_1_l. + rewrite 2!stack_semantics, n_wire_semantics. + simpl (⟦ ⨉ ⟧). + apply swap_2cup_transport. Qed. Tactic Notation "preplace" constr(zx0) "with" constr(zx1) := - (let H := fresh "H" in - enough (H: zx0 ∝ zx1); - [rewrite H; clear H| ]). + (let H := fresh "H" in + enough (H: zx0 ∝ zx1); + [rewrite H; clear H| ]). Tactic Notation "preplace" constr(zx0) "with" constr(zx1) "in" hyp(target) := - (let H := fresh "H" in - enough (H: zx0 ∝ zx1); - [rewrite H in target; clear H| ]). + (let H := fresh "H" in + enough (H: zx0 ∝ zx1); + [rewrite H in target; clear H| ]). Tactic Notation "preplace" constr(zx0) "with" constr(zx1) "in" "*" := - (let H := fresh "H" in - enough (H: zx0 ∝ zx1); - [rewrite H in *; clear H| ]). + (let H := fresh "H" in + enough (H: zx0 ∝ zx1); + [rewrite H in *; clear H| ]). Tactic Notation "preplace" constr(zx0) "with" constr(zx1) "by" tactic(slv) := - (let H := fresh "H" in - assert (H: zx0 ∝ zx1) by slv; - rewrite H; clear H). + (let H := fresh "H" in + assert (H: zx0 ∝ zx1) by slv; + rewrite H; clear H). Tactic Notation "preplace" constr(zx0) "with" constr(zx1) "in" hyp(target) "by" tactic(slv) := - (let H := fresh "H" in - assert (H: zx0 ∝ zx1) by slv; - rewrite H in target; clear H). + (let H := fresh "H" in + assert (H: zx0 ∝ zx1) by slv; + rewrite H in target; clear H). Tactic Notation "preplace" constr(zx0) "with" constr(zx1) "in" "*" "by" tactic(slv) := - (let H := fresh "H" in - assert (H: zx0 ∝ zx1) by slv; - rewrite H in *; clear H). + (let H := fresh "H" in + assert (H: zx0 ∝ zx1) by slv; + rewrite H in *; clear H). Lemma n_cup_unswapped_split_stack : forall {n00 m0 n01 m1 n10 n11 nbot} - (zx00 : ZX n00 m0) (zx01 : ZX n01 m1) (zx10 : ZX n10 m1) (zx11 : ZX n11 m0) - prf1 prf2 prf3 prf4 prf5 prf6, - (zx00 ↕ zx01) ↕ (cast nbot (m0 + m1) prf1 prf2 (zx10 ↕ zx11)) ⟷ n_cup_unswapped (m0 + m1) - ∝ - cast _ 0 prf5 prf6 (n_wire n00 - ↕ (zx01 ↕ zx10 ⟷ n_cup_unswapped m1) ↕ n_wire n11 ⟷ - cast _ 0 prf3 prf4 (zx00 ↕ zx11 ⟷ n_cup_unswapped m0)). -Proof. - intros. - rewrite cast_stack_r. - rewrite n_cup_unswapped_grow_k_l. - rewrite <- compose_assoc. - rewrite stack_assoc, (stack_assoc' zx01). - simpl_casts. - rewrite stack_assoc', cast_cast_eq. - rewrite <- cast_compose_mid_contract. - rewrite <- 2!stack_compose_distr, 2!nwire_removal_r. - rewrite pull_out_bot, <- (nwire_removal_l zx11), stack_compose_distr. - rewrite cast_compose_distribute. - rewrite (cast_compose_mid_contract _ (n00 + 0 + n11)). - rewrite compose_assoc. - apply compose_simplify; [easy | ]. - simpl_casts. - rewrite cast_compose_mid_contract, cast_id_eq. - apply compose_simplify; [|easy]. - rewrite nwire_removal_l. - easy. - Unshelve. - all: lia. + (zx00 : ZX n00 m0) (zx01 : ZX n01 m1) (zx10 : ZX n10 m1) (zx11 : ZX n11 m0) + prf1 prf2 prf3 prf4 prf5 prf6, + (zx00 ↕ zx01) ↕ (cast nbot (m0 + m1) prf1 prf2 (zx10 ↕ zx11)) ⟷ n_cup_unswapped (m0 + m1) + ∝ + cast _ 0 prf5 prf6 (n_wire n00 + ↕ (zx01 ↕ zx10 ⟷ n_cup_unswapped m1) ↕ n_wire n11 ⟷ + cast _ 0 prf3 prf4 (zx00 ↕ zx11 ⟷ n_cup_unswapped m0)). +Proof. + intros. + rewrite cast_stack_r. + rewrite n_cup_unswapped_grow_k_l. + rewrite <- compose_assoc. + rewrite stack_assoc, (stack_assoc' zx01). + simpl_casts. + rewrite stack_assoc', cast_cast_eq. + rewrite <- cast_compose_mid_contract. + rewrite <- 2!stack_compose_distr, 2!nwire_removal_r. + rewrite pull_out_bot, <- (nwire_removal_l zx11), stack_compose_distr. + rewrite cast_compose_distribute. + rewrite (cast_compose_mid_contract _ (n00 + 0 + n11)). + rewrite compose_assoc. + apply compose_simplify; [easy | ]. + simpl_casts. + rewrite cast_compose_mid_contract, cast_id_eq. + apply compose_simplify; [|easy]. + rewrite nwire_removal_l. + easy. + Unshelve. + all: lia. Qed. Lemma n_cup_unswapped_split_stack' : forall {n00 m0 n01 m1 n10 n11 ntop} - (zx00 : ZX n00 m0) (zx01 : ZX n01 m1) (zx10 : ZX n10 m1) (zx11 : ZX n11 m0) - prf1 prf2 prf3 prf4 prf5 prf6, - (cast ntop (m1 + m0) prf1 prf2 (zx00 ↕ zx01)) ↕ (zx10 ↕ zx11) ⟷ n_cup_unswapped (m1 + m0) - ∝ - cast _ 0 prf5 prf6 (n_wire n00 - ↕ (zx01 ↕ zx10 ⟷ n_cup_unswapped m1) ↕ n_wire n11 ⟷ - cast _ 0 prf3 prf4 (zx00 ↕ zx11 ⟷ n_cup_unswapped m0)). -Proof. - intros. - rewrite (prop_iff_double_cast (ntop + (n10 + n11)) 0). - rewrite (cast_compose_mid_contract _ (m0 + m1 + (m0 + m1))). - rewrite cast_n_cup_unswapped by lia. - subst ntop. - rewrite cast_stack_distribute, cast_cast_eq, cast_id_eq. - rewrite n_cup_unswapped_split_stack. - simpl_casts. - easy. - Unshelve. - all: lia. + (zx00 : ZX n00 m0) (zx01 : ZX n01 m1) (zx10 : ZX n10 m1) (zx11 : ZX n11 m0) + prf1 prf2 prf3 prf4 prf5 prf6, + (cast ntop (m1 + m0) prf1 prf2 (zx00 ↕ zx01)) ↕ (zx10 ↕ zx11) ⟷ n_cup_unswapped (m1 + m0) + ∝ + cast _ 0 prf5 prf6 (n_wire n00 + ↕ (zx01 ↕ zx10 ⟷ n_cup_unswapped m1) ↕ n_wire n11 ⟷ + cast _ 0 prf3 prf4 (zx00 ↕ zx11 ⟷ n_cup_unswapped m0)). +Proof. + intros. + rewrite (prop_iff_double_cast (ntop + (n10 + n11)) 0). + rewrite (cast_compose_mid_contract _ (m0 + m1 + (m0 + m1))). + rewrite cast_n_cup_unswapped by lia. + subst ntop. + rewrite cast_stack_distribute, cast_cast_eq, cast_id_eq. + rewrite n_cup_unswapped_split_stack. + simpl_casts. + easy. + Unshelve. + all: lia. Qed. Lemma n_cup_unswapped_split_stack_cast : forall {n00 m0 n01 m1 n10 n11 ntop nbot} - (zx00 : ZX n00 m0) (zx01 : ZX n01 m1) (zx10 : ZX n10 m1) (zx11 : ZX n11 m0) - prf1 prf2 prf3 prf4 prf5 prf6 prf7 prf8, - (cast ntop (m0 + m1) prf1 prf2 (zx00 ↕ zx01)) - ↕ (cast nbot (m0 + m1) prf3 prf4 (zx10 ↕ zx11)) - ⟷ n_cup_unswapped (m0 + m1) ∝ - cast _ 0 prf5 prf6 ( - n_wire n00 ↕ (zx01 ↕ zx10 ⟷ n_cup_unswapped m1) ↕ n_wire n11 - ⟷ cast _ 0 prf7 prf8 (zx00 ↕ zx11 ⟷ n_cup_unswapped m0)). + (zx00 : ZX n00 m0) (zx01 : ZX n01 m1) (zx10 : ZX n10 m1) (zx11 : ZX n11 m0) + prf1 prf2 prf3 prf4 prf5 prf6 prf7 prf8, + (cast ntop (m0 + m1) prf1 prf2 (zx00 ↕ zx01)) + ↕ (cast nbot (m0 + m1) prf3 prf4 (zx10 ↕ zx11)) + ⟷ n_cup_unswapped (m0 + m1) ∝ + cast _ 0 prf5 prf6 ( + n_wire n00 ↕ (zx01 ↕ zx10 ⟷ n_cup_unswapped m1) ↕ n_wire n11 + ⟷ cast _ 0 prf7 prf8 (zx00 ↕ zx11 ⟷ n_cup_unswapped m0)). Proof. - intros. - subst ntop. - rewrite cast_id_eq. - apply n_cup_unswapped_split_stack. + intros. + subst ntop. + rewrite cast_id_eq. + apply n_cup_unswapped_split_stack. Qed. Lemma n_cup_unswapped_split_stack_n_wire_bot : forall {n0 n1} - (zx0 : ZX n0 n0) (zx1 : ZX n1 n1) prf1 prf2 prf3 prf4, - zx0 ↕ zx1 ↕ n_wire (n0 + n1) ⟷ n_cup_unswapped (n0 + n1) - ∝ cast _ 0 prf1 prf2 ( - n_wire n0 ↕ (zx1 ↕ n_wire n1 ⟷ n_cup_unswapped n1) ↕ n_wire n0 ⟷ - cast _ 0 prf3 prf4 (zx0 ↕ n_wire n0 ⟷ n_cup_unswapped n0)). -Proof. - intros. - rewrite n_wire_add_stack_rev. - rewrite n_cup_unswapped_split_stack. - easy. - Unshelve. - all: lia. + (zx0 : ZX n0 n0) (zx1 : ZX n1 n1) prf1 prf2 prf3 prf4, + zx0 ↕ zx1 ↕ n_wire (n0 + n1) ⟷ n_cup_unswapped (n0 + n1) + ∝ cast _ 0 prf1 prf2 ( + n_wire n0 ↕ (zx1 ↕ n_wire n1 ⟷ n_cup_unswapped n1) ↕ n_wire n0 ⟷ + cast _ 0 prf3 prf4 (zx0 ↕ n_wire n0 ⟷ n_cup_unswapped n0)). +Proof. + intros. + rewrite n_wire_add_stack_rev. + rewrite n_cup_unswapped_split_stack. + easy. + Unshelve. + all: lia. Qed. Lemma n_cup_unswapped_split_stack_n_wire_top : forall {n0 n1} - (zx0 : ZX n0 n0) (zx1 : ZX n1 n1) prf1 prf2 prf3 prf4, - n_wire (n0 + n1) ↕ (zx0 ↕ zx1) ⟷ n_cup_unswapped (n0 + n1) - ∝ cast _ 0 prf1 prf2 ( - n_wire n1 ↕ (n_wire n0 ↕ zx0 ⟷ n_cup_unswapped n0) ↕ n_wire n1 ⟷ - cast _ 0 prf3 prf4 (n_wire n1 ↕ zx1 ⟷ n_cup_unswapped n1)). -Proof. - intros. - rewrite n_wire_add_stack_rev. - rewrite n_cup_unswapped_split_stack'. - easy. - Unshelve. - all: lia. + (zx0 : ZX n0 n0) (zx1 : ZX n1 n1) prf1 prf2 prf3 prf4, + n_wire (n0 + n1) ↕ (zx0 ↕ zx1) ⟷ n_cup_unswapped (n0 + n1) + ∝ cast _ 0 prf1 prf2 ( + n_wire n1 ↕ (n_wire n0 ↕ zx0 ⟷ n_cup_unswapped n0) ↕ n_wire n1 ⟷ + cast _ 0 prf3 prf4 (n_wire n1 ↕ zx1 ⟷ n_cup_unswapped n1)). +Proof. + intros. + rewrite n_wire_add_stack_rev. + rewrite n_cup_unswapped_split_stack'. + easy. + Unshelve. + all: lia. Qed. Lemma n_cup_unswapped_split_stack_n_wire_bot' : forall {n0 n1 ntop} - (zx0 : ZX n0 n0) (zx1 : ZX n1 n1) prf1 prf2 prf3 prf4 prf5 prf6, - cast ntop _ prf1 prf2 (zx0 ↕ zx1) ↕ n_wire (n1 + n0) ⟷ n_cup_unswapped (n1 + n0) - ∝ cast _ 0 prf3 prf4 ( - n_wire n0 ↕ (zx1 ↕ n_wire n1 ⟷ n_cup_unswapped n1) ↕ n_wire n0 ⟷ - cast _ 0 prf5 prf6 (zx0 ↕ n_wire n0 ⟷ n_cup_unswapped n0)). -Proof. - intros. - rewrite n_wire_add_stack. - rewrite n_cup_unswapped_split_stack'. - easy. - Unshelve. - all: lia. + (zx0 : ZX n0 n0) (zx1 : ZX n1 n1) prf1 prf2 prf3 prf4 prf5 prf6, + cast ntop _ prf1 prf2 (zx0 ↕ zx1) ↕ n_wire (n1 + n0) ⟷ n_cup_unswapped (n1 + n0) + ∝ cast _ 0 prf3 prf4 ( + n_wire n0 ↕ (zx1 ↕ n_wire n1 ⟷ n_cup_unswapped n1) ↕ n_wire n0 ⟷ + cast _ 0 prf5 prf6 (zx0 ↕ n_wire n0 ⟷ n_cup_unswapped n0)). +Proof. + intros. + rewrite n_wire_add_stack. + rewrite n_cup_unswapped_split_stack'. + easy. + Unshelve. + all: lia. Qed. Lemma n_cup_unswapped_split_stack_n_wire_top' : forall {n0 n1 ntop} - (zx0 : ZX n0 n0) (zx1 : ZX n1 n1) prf1 prf2 prf3 prf4 prf5 prf6, - n_wire (n1 + n0) ↕ cast ntop _ prf1 prf2 (zx0 ↕ zx1) ⟷ n_cup_unswapped (n1 + n0) - ∝ cast _ 0 prf3 prf4 ( - n_wire n1 ↕ (n_wire n0 ↕ zx0 ⟷ n_cup_unswapped n0) ↕ n_wire n1 ⟷ - cast _ 0 prf5 prf6 (n_wire n1 ↕ zx1 ⟷ n_cup_unswapped n1)). -Proof. - intros. - rewrite n_wire_add_stack. - rewrite n_cup_unswapped_split_stack. - easy. - Unshelve. - all: lia. + (zx0 : ZX n0 n0) (zx1 : ZX n1 n1) prf1 prf2 prf3 prf4 prf5 prf6, + n_wire (n1 + n0) ↕ cast ntop _ prf1 prf2 (zx0 ↕ zx1) ⟷ n_cup_unswapped (n1 + n0) + ∝ cast _ 0 prf3 prf4 ( + n_wire n1 ↕ (n_wire n0 ↕ zx0 ⟷ n_cup_unswapped n0) ↕ n_wire n1 ⟷ + cast _ 0 prf5 prf6 (n_wire n1 ↕ zx1 ⟷ n_cup_unswapped n1)). +Proof. + intros. + rewrite n_wire_add_stack. + rewrite n_cup_unswapped_split_stack. + easy. + Unshelve. + all: lia. Qed. Lemma n_cup_unswapped_grow_r_back : forall n prf1 prf2, - (— ↕ n_cup_unswapped (n) ↕ — ⟷ ⊃) - ∝ cast _ _ prf1 prf2 (n_cup_unswapped (S n)). + (— ↕ n_cup_unswapped (n) ↕ — ⟷ ⊃) + ∝ cast _ _ prf1 prf2 (n_cup_unswapped (S n)). Proof. - intros. - rewrite (n_cup_unswapped_grow_r n). - rewrite cast_compose_l. - rewrite cast_cast_eq, 2!cast_id_eq. - easy. - Unshelve. - all: lia. + intros. + rewrite (n_cup_unswapped_grow_r n). + rewrite cast_compose_l. + rewrite cast_cast_eq, 2!cast_id_eq. + easy. + Unshelve. + all: lia. Qed. Lemma n_cup_unswapped_grow_k_r_back : forall n k prf1 prf2 prf3 prf4, - (n_wire k ↕ n_cup_unswapped n ↕ n_wire k) - ⟷ cast (k + 0 + k) 0 prf1 prf2 (n_cup_unswapped k) - ∝ cast _ 0 prf3 prf4 (n_cup_unswapped (n + k)). -Proof. - intros. - rewrite n_cup_unswapped_grow_k_r. - rewrite (cast_compose_mid_contract _ (k + 0 + k)). - rewrite cast_cast_eq, cast_id_eq. - easy. - Unshelve. - all: lia. + (n_wire k ↕ n_cup_unswapped n ↕ n_wire k) + ⟷ cast (k + 0 + k) 0 prf1 prf2 (n_cup_unswapped k) + ∝ cast _ 0 prf3 prf4 (n_cup_unswapped (n + k)). +Proof. + intros. + rewrite n_cup_unswapped_grow_k_r. + rewrite (cast_compose_mid_contract _ (k + 0 + k)). + rewrite cast_cast_eq, cast_id_eq. + easy. + Unshelve. + all: lia. Qed. Lemma compose_n_wire_comm : forall {n m} (zx : ZX n m), - n_wire n ⟷ zx ∝ zx ⟷ n_wire m. + n_wire n ⟷ zx ∝ zx ⟷ n_wire m. Proof. - intros. - cleanup_zx; easy. + intros. + cleanup_zx; easy. Qed. Lemma compose_stack_n_wire_comm : forall {n0 m0 n1 m1} (zx0 : ZX n0 m0) (zx1 : ZX n1 m1), - zx0 ↕ n_wire n1 ⟷ (n_wire m0 ↕ zx1) ∝ n_wire n0 ↕ zx1 ⟷ (zx0 ↕ n_wire m1). + zx0 ↕ n_wire n1 ⟷ (n_wire m0 ↕ zx1) ∝ n_wire n0 ↕ zx1 ⟷ (zx0 ↕ n_wire m1). Proof. - intros. - cleanup_zx. - easy. + intros. + cleanup_zx. + easy. Qed. Lemma top_to_bottom_cup_flip : forall k, - top_to_bottom k ↕ n_wire k ⟷ n_cup_unswapped k - ∝ n_wire k ↕ top_to_bottom k ⟷ n_cup_unswapped k. -Proof. - destruct k; [prop_perm_eq|]. - induction k; - [ apply compose_simplify; [prop_perm_eq | easy] | ]. - rewrite top_to_bottom_grow_r at 1. - rewrite nwire_stack_distr_compose_r. - rewrite compose_assoc. - rewrite (n_wire_add_stack 2 k) at 2. - rewrite (n_cup_unswapped_split_stack' (n_wire k) ⨉ (n_wire 2) (n_wire k)). - rewrite swap_2cup_flip. - bundle_wires; cleanup_zx. - rewrite nwire_stack_distr_compose_l, nwire_stack_distr_compose_r. - rewrite compose_assoc. - rewrite n_cup_unswapped_grow_k_r_back. - rewrite (stack_assoc' (n_wire k)). - rewrite 2!cast_stack_l. - rewrite (stack_assoc _ ⨉), cast_cast_eq. - rewrite (cast_compose_mid_contract _ (S (S k) + S (S k))). - simpl_casts. - rewrite 2!cast_stack_distribute. - simpl_permlike_zx. - bundle_wires; rewrite cast_n_wire. - rewrite <- compose_assoc, compose_stack_n_wire_comm. - rewrite (n_wire_add_stack 1 (S k)) at 2. - rewrite compose_assoc. - rewrite (n_cup_unswapped_split_stack' (top_to_bottom (S k)) — (n_wire 1) (n_wire S k)). - rewrite IHk. - rewrite <- n_cup_unswapped_split_stack'. - rewrite <- compose_assoc. - rewrite <- stack_compose_distr. - rewrite <- wire_to_n_wire. - rewrite <- (top_to_bottom_grow_l k). - apply compose_simplify; [|easy]. - apply stack_simplify; [|easy]. - prop_perm_eq. - Unshelve. - all: lia. + top_to_bottom k ↕ n_wire k ⟷ n_cup_unswapped k + ∝ n_wire k ↕ top_to_bottom k ⟷ n_cup_unswapped k. +Proof. + destruct k; [prop_perm_eq|]. + induction k; + [ apply compose_simplify; [prop_perm_eq | easy] | ]. + rewrite top_to_bottom_grow_r at 1. + rewrite nwire_stack_distr_compose_r. + rewrite compose_assoc. + rewrite (n_wire_add_stack 2 k) at 2. + rewrite (n_cup_unswapped_split_stack' (n_wire k) ⨉ (n_wire 2) (n_wire k)). + rewrite swap_2cup_flip. + bundle_wires; cleanup_zx. + rewrite nwire_stack_distr_compose_l, nwire_stack_distr_compose_r. + rewrite compose_assoc. + rewrite n_cup_unswapped_grow_k_r_back. + rewrite (stack_assoc' (n_wire k)). + rewrite 2!cast_stack_l. + rewrite (stack_assoc _ ⨉), cast_cast_eq. + rewrite (cast_compose_mid_contract _ (S (S k) + S (S k))). + simpl_casts. + rewrite 2!cast_stack_distribute. + simpl_permlike_zx. + bundle_wires; rewrite cast_n_wire. + rewrite <- compose_assoc, compose_stack_n_wire_comm. + rewrite (n_wire_add_stack 1 (S k)) at 2. + rewrite compose_assoc. + rewrite (n_cup_unswapped_split_stack' (top_to_bottom (S k)) — (n_wire 1) (n_wire S k)). + rewrite IHk. + rewrite <- n_cup_unswapped_split_stack'. + rewrite <- compose_assoc. + rewrite <- stack_compose_distr. + rewrite <- wire_to_n_wire. + rewrite <- (top_to_bottom_grow_l k). + apply compose_simplify; [|easy]. + apply stack_simplify; [|easy]. + prop_perm_eq. + Unshelve. + all: lia. Qed. Lemma bottom_to_top_cup_flip : forall k, - bottom_to_top k ↕ n_wire k ⟷ n_cup_unswapped k - ∝ n_wire k ↕ bottom_to_top k ⟷ n_cup_unswapped k. -Proof. - intros k. - destruct k; [prop_perm_eq | ]. - induction k; - [ apply compose_simplify; [prop_perm_eq | easy] | ]. - rewrite bottom_to_top_grow_r at 1. - rewrite nwire_stack_distr_compose_r. - rewrite compose_assoc. - rewrite (n_wire_add_stack_rev 2 k) at 2. - rewrite (n_cup_unswapped_split_stack ⨉ (n_wire k) (n_wire k) (n_wire 2)). - rewrite swap_2cup_flip. - rewrite <- n_cup_unswapped_split_stack. - rewrite <- compose_assoc, <- n_wire_add_stack, compose_stack_n_wire_comm. - preplace (n_wire (2 + k)) with (n_wire (1 + (S k))) by easy. - rewrite (n_wire_add_stack_rev 1 (S k)). - rewrite compose_assoc. - rewrite (n_cup_unswapped_split_stack — (bottom_to_top (S k)) (n_wire S k) (n_wire 1)). - rewrite IHk. - rewrite <- n_cup_unswapped_split_stack. - rewrite <- compose_assoc. - apply compose_simplify; [|easy]. - rewrite <- stack_compose_distr. - apply stack_simplify; [prop_perm_eq | ]. - rewrite <- wire_to_n_wire. - rewrite cast_compose_r, cast_cast_eq. - rewrite (cast_compose_partial_contract_l _ (S k + 1)). - rewrite <- (bottom_to_top_grow_l k). - easy. - Unshelve. - all: lia. + bottom_to_top k ↕ n_wire k ⟷ n_cup_unswapped k + ∝ n_wire k ↕ bottom_to_top k ⟷ n_cup_unswapped k. +Proof. + intros k. + destruct k; [prop_perm_eq | ]. + induction k; + [ apply compose_simplify; [prop_perm_eq | easy] | ]. + rewrite bottom_to_top_grow_r at 1. + rewrite nwire_stack_distr_compose_r. + rewrite compose_assoc. + rewrite (n_wire_add_stack_rev 2 k) at 2. + rewrite (n_cup_unswapped_split_stack ⨉ (n_wire k) (n_wire k) (n_wire 2)). + rewrite swap_2cup_flip. + rewrite <- n_cup_unswapped_split_stack. + rewrite <- compose_assoc, <- n_wire_add_stack, compose_stack_n_wire_comm. + preplace (n_wire (2 + k)) with (n_wire (1 + (S k))) by easy. + rewrite (n_wire_add_stack_rev 1 (S k)). + rewrite compose_assoc. + rewrite (n_cup_unswapped_split_stack — (bottom_to_top (S k)) (n_wire S k) (n_wire 1)). + rewrite IHk. + rewrite <- n_cup_unswapped_split_stack. + rewrite <- compose_assoc. + apply compose_simplify; [|easy]. + rewrite <- stack_compose_distr. + apply stack_simplify; [prop_perm_eq | ]. + rewrite <- wire_to_n_wire. + rewrite cast_compose_r, cast_cast_eq. + rewrite (cast_compose_partial_contract_l _ (S k + 1)). + rewrite <- (bottom_to_top_grow_l k). + easy. + Unshelve. + all: lia. Qed. Lemma a_swap_cup_flip : forall k, - a_swap k ↕ n_wire k ⟷ n_cup_unswapped k - ∝ n_wire k ↕ a_swap k ⟷ n_cup_unswapped k. -Proof. - intros k. - destruct k; [prop_perm_eq|]. - simpl a_swap. - rewrite nwire_stack_distr_compose_r, compose_assoc. - rewrite (n_cup_unswapped_split_stack_n_wire_bot —). - rewrite top_to_bottom_cup_flip. - rewrite <- (cast_cast_eq _ _ (k + 1 + (k + 1)) 0 (1 + k + (1 + k)) 0). - preplace (— ↕ n_wire 1) with (n_wire 1 ↕ —) by prop_perm_eq. - rewrite <- (n_cup_unswapped_split_stack_n_wire_top (top_to_bottom k) —). - rewrite (cast_compose_mid_contract _ (1 + k + (1 + k))). - rewrite cast_n_cup_unswapped, cast_stack_distribute, cast_n_wire by lia. - rewrite <- compose_assoc, compose_stack_n_wire_comm, compose_assoc. - rewrite bottom_to_top_cup_flip. - rewrite <- compose_assoc. - apply compose_simplify; [|easy]. - rewrite <- stack_compose_distr. - rewrite nwire_removal_r. - apply stack_simplify; prop_perm_eq. - solve_modular_permutation_equalities. - Unshelve. - all: lia. + a_swap k ↕ n_wire k ⟷ n_cup_unswapped k + ∝ n_wire k ↕ a_swap k ⟷ n_cup_unswapped k. +Proof. + intros k. + destruct k; [prop_perm_eq|]. + simpl a_swap. + rewrite nwire_stack_distr_compose_r, compose_assoc. + rewrite (n_cup_unswapped_split_stack_n_wire_bot —). + rewrite top_to_bottom_cup_flip. + rewrite <- (cast_cast_eq _ _ (k + 1 + (k + 1)) 0 (1 + k + (1 + k)) 0). + preplace (— ↕ n_wire 1) with (n_wire 1 ↕ —) by prop_perm_eq. + rewrite <- (n_cup_unswapped_split_stack_n_wire_top (top_to_bottom k) —). + rewrite (cast_compose_mid_contract _ (1 + k + (1 + k))). + rewrite cast_n_cup_unswapped, cast_stack_distribute, cast_n_wire by lia. + rewrite <- compose_assoc, compose_stack_n_wire_comm, compose_assoc. + rewrite bottom_to_top_cup_flip. + rewrite <- compose_assoc. + apply compose_simplify; [|easy]. + rewrite <- stack_compose_distr. + rewrite nwire_removal_r. + apply stack_simplify; prop_perm_eq. + solve_modular_permutation_equalities. + Unshelve. + all: lia. Qed. Lemma n_swap_zxperm : forall n, - ZXperm n (n_swap n). + ZXperm n (n_swap n). Proof. - induction n; simpl; auto with zxperm_db. + induction n; simpl; auto with zxperm_db. Qed. #[export] Hint Resolve n_swap_zxperm : zxperm_db. Lemma perm_of_n_swap : forall n, - perm_of_zx (n_swap n) = fun k => if n <=? k then k else (n - S k)%nat. + perm_of_zx (n_swap n) = fun k => if n <=? k then k else (n - S k)%nat. Proof. - (* destruct n; [simpl; solve_modular_permutation_equalities|]. *) - induction n; simpl perm_of_zx; cleanup_perm_of_zx; - [|rewrite IHn]; solve_modular_permutation_equalities. + (* destruct n; [simpl; solve_modular_permutation_equalities|]. *) + induction n; simpl perm_of_zx; cleanup_perm_of_zx; + [|rewrite IHn]; solve_modular_permutation_equalities. Qed. #[export] Hint Rewrite perm_of_n_swap : perm_of_zx_cleanup_db. Lemma n_swap_cup_flip : forall k, - n_swap k ↕ n_wire k ⟷ n_cup_unswapped k - ∝ n_wire k ↕ n_swap k ⟷ n_cup_unswapped k. -Proof. - intros k. - (* destruct k; [prop_perm_eq|]. *) - induction k; - [prop_perm_eq | ]. - (* [apply compose_simplify; [prop_perm_eq | easy] | ]. *) - simpl (n_swap (S k)). - rewrite nwire_stack_distr_compose_r, compose_assoc. - rewrite (n_cup_unswapped_split_stack_n_wire_bot —), IHk. - preplace (— ↕ n_wire 1) with (n_wire 1 ↕ —) by prop_perm_eq. - rewrite <- (n_cup_unswapped_split_stack_n_wire_top' (n_swap k) —). - rewrite <- compose_assoc, compose_stack_n_wire_comm, compose_assoc. - rewrite bottom_to_top_cup_flip, <- compose_assoc. - apply compose_simplify; [|easy]. - rewrite <- stack_compose_distr, nwire_removal_l. - apply stack_simplify; [easy|]. - prop_perm_eq. - solve_modular_permutation_equalities. - Unshelve. - all: lia. + n_swap k ↕ n_wire k ⟷ n_cup_unswapped k + ∝ n_wire k ↕ n_swap k ⟷ n_cup_unswapped k. +Proof. + intros k. + (* destruct k; [prop_perm_eq|]. *) + induction k; + [prop_perm_eq | ]. + (* [apply compose_simplify; [prop_perm_eq | easy] | ]. *) + simpl (n_swap (S k)). + rewrite nwire_stack_distr_compose_r, compose_assoc. + rewrite (n_cup_unswapped_split_stack_n_wire_bot —), IHk. + preplace (— ↕ n_wire 1) with (n_wire 1 ↕ —) by prop_perm_eq. + rewrite <- (n_cup_unswapped_split_stack_n_wire_top' (n_swap k) —). + rewrite <- compose_assoc, compose_stack_n_wire_comm, compose_assoc. + rewrite bottom_to_top_cup_flip, <- compose_assoc. + apply compose_simplify; [|easy]. + rewrite <- stack_compose_distr, nwire_removal_l. + apply stack_simplify; [easy|]. + prop_perm_eq. + solve_modular_permutation_equalities. + Unshelve. + all: lia. Qed. @@ -1831,189 +1276,189 @@ Local Open Scope ZX_scope. Lemma n_yank_1_l_helper_helper : forall n, - (⊃) ⊤ ↕ n_wire n ⟷ (— ↕ n_wire (1 + n)) ∝ n_wire n ⟷ (⊂ ↕ n_wire n). + (⊃) ⊤ ↕ n_wire n ⟷ (— ↕ n_wire (1 + n)) ∝ n_wire n ⟷ (⊂ ↕ n_wire n). Proof. - intros n. - simpl. - rewrite nwire_removal_l, nwire_removal_r. - easy. + intros n. + simpl. + rewrite nwire_removal_l, nwire_removal_r. + easy. Qed. Lemma n_yank_1_l_helper : forall n prf1 prf2 prf3 prf4 prf5 prf6 prf7 prf8, - cast (S n + (n + n)) (1 + n + S n + S n) prf1 prf2 (n_wire (1 + n) ↕ (n_wire n ↕ ⊃ ↕ n_wire n) ⊤) - ⟷ (((cast (S n + S n) 2 prf3 prf4 (— ↕ n_cup_unswapped n ↕ —)) ⟷ ⊃) ↕ (— ↕ n_wire n)) - ∝ cast _ _ prf7 prf8 ( - (— ↕ (n_wire n ↕ n_wire n) ↕ n_wire n) ⟷ (cast (1 + (n + n) + n) (1 + 0 + n) prf5 prf6 (— ↕ (n_cup_unswapped n) ↕ n_wire n) ) - ⟷ (— ↕ ⊂ ↕ n_wire n) ⟷ (⊃ ↕ — ↕ n_wire n) - ). -Proof. - intros. - rewrite nwire_stack_distr_compose_r. - rewrite <- compose_assoc. - rewrite (cast_compose_mid_contract _ (2 + 1 + n)). - apply compose_simplify; [| - rewrite stack_assoc, cast_cast_eq, cast_id; easy]. - rewrite 2!stack_transpose, n_wire_transpose. - rewrite 2!stack_assoc', (stack_assoc — (n_wire n) (n_wire n)). - simpl_permlike_zx. - rewrite cast_stack_l, cast_cast_eq. - rewrite (stack_assoc _ ((⊃) ⊤) (n_wire n)). - rewrite (prop_iff_double_cast (S (n + n) + (0 + n)) (1 + (1 + (1 + n)))). - rewrite (cast_compose_mid_contract _ ((n + n) + 1 + (1 + (1 + n)))). - - rewrite 2!cast_cast_eq. - rewrite cast_stack_distribute. - rewrite (cast_stack_l (mTop':=2)). - rewrite (stack_assoc _ —). - rewrite 2!cast_cast_eq. - rewrite (cast_stack_distribute (o':= 1 + (1 + n))). - rewrite <- stack_compose_distr. + cast (S n + (n + n)) (1 + n + S n + S n) prf1 prf2 (n_wire (1 + n) ↕ (n_wire n ↕ ⊃ ↕ n_wire n) ⊤) + ⟷ (((cast (S n + S n) 2 prf3 prf4 (— ↕ n_cup_unswapped n ↕ —)) ⟷ ⊃) ↕ (— ↕ n_wire n)) + ∝ cast _ _ prf7 prf8 ( + (— ↕ (n_wire n ↕ n_wire n) ↕ n_wire n) ⟷ (cast (1 + (n + n) + n) (1 + 0 + n) prf5 prf6 (— ↕ (n_cup_unswapped n) ↕ n_wire n) ) + ⟷ (— ↕ ⊂ ↕ n_wire n) ⟷ (⊃ ↕ — ↕ n_wire n) + ). +Proof. + intros. + rewrite nwire_stack_distr_compose_r. + rewrite <- compose_assoc. + rewrite (cast_compose_mid_contract _ (2 + 1 + n)). + apply compose_simplify; [| + rewrite stack_assoc, cast_cast_eq, cast_id; easy]. + rewrite 2!stack_transpose, n_wire_transpose. + rewrite 2!stack_assoc', (stack_assoc — (n_wire n) (n_wire n)). + simpl_permlike_zx. + rewrite cast_stack_l, cast_cast_eq. + rewrite (stack_assoc _ ((⊃) ⊤) (n_wire n)). + rewrite (prop_iff_double_cast (S (n + n) + (0 + n)) (1 + (1 + (1 + n)))). + rewrite (cast_compose_mid_contract _ ((n + n) + 1 + (1 + (1 + n)))). + + rewrite 2!cast_cast_eq. + rewrite cast_stack_distribute. + rewrite (cast_stack_l (mTop':=2)). + rewrite (stack_assoc _ —). + rewrite 2!cast_cast_eq. + rewrite (cast_stack_distribute (o':= 1 + (1 + n))). + rewrite <- stack_compose_distr. + + simpl_permlike_zx. + rewrite n_yank_1_l_helper_helper. + rewrite <- (nwire_removal_r (— ↕ n_cup_unswapped n)). + rewrite stack_compose_distr. + rewrite compose_assoc. + apply compose_simplify; [ + bundle_wires; prop_perm_eq|]. + cleanup_zx. + enough (Hrw : — ↕ n_cup_unswapped n ↕ n_wire n ⟷ (— ↕ ⊂ ↕ n_wire n) + ∝ @cast (1 + (n + n) + n) (1 + 2 + n) (1 + (n + n) + 0 + n) (1 + 0 + 2 + n) + (ltac:(lia)) (ltac:(lia)) ( + — ↕ n_cup_unswapped n ↕ ⦰ ↕ n_wire n ⟷ (— ↕ ⦰ ↕ ⊂ ↕ n_wire n))). + - rewrite Hrw. + repeat rewrite <- stack_compose_distr. + rewrite 2!nwire_removal_l. + rewrite (stack_assoc' _ ⊂ (n_wire n)). + simpl_casts. + do 2 (apply stack_simplify; [|easy]). + apply stack_simplify; cleanup_zx; easy. - simpl_permlike_zx. - rewrite n_yank_1_l_helper_helper. - rewrite <- (nwire_removal_r (— ↕ n_cup_unswapped n)). - rewrite stack_compose_distr. - rewrite compose_assoc. - apply compose_simplify; [ - bundle_wires; prop_perm_eq|]. - cleanup_zx. - enough (Hrw : — ↕ n_cup_unswapped n ↕ n_wire n ⟷ (— ↕ ⊂ ↕ n_wire n) - ∝ @cast (1 + (n + n) + n) (1 + 2 + n) (1 + (n + n) + 0 + n) (1 + 0 + 2 + n) - (ltac:(lia)) (ltac:(lia)) ( - — ↕ n_cup_unswapped n ↕ ⦰ ↕ n_wire n ⟷ (— ↕ ⦰ ↕ ⊂ ↕ n_wire n))). - - rewrite Hrw. - repeat rewrite <- stack_compose_distr. - rewrite 2!nwire_removal_l. - rewrite (stack_assoc' _ ⊂ (n_wire n)). - simpl_casts. - do 2 (apply stack_simplify; [|easy]). - apply stack_simplify; cleanup_zx; easy. - - - cleanup_zx. simpl_permlike_zx. - rewrite (cast_compose_mid_contract _ (1 + 0 + n)), cast_id_eq. - apply compose_simplify; [|easy]. - rewrite cast_stack_l, cast_cast_eq, cast_id_eq. - easy. - - Unshelve. all: lia. -Qed. - + - cleanup_zx. simpl_permlike_zx. + rewrite (cast_compose_mid_contract _ (1 + 0 + n)), cast_id_eq. + apply compose_simplify; [|easy]. + rewrite cast_stack_l, cast_cast_eq, cast_id_eq. + easy. + + Unshelve. all: lia. +Qed. + Lemma n_yank_1_l : forall n, - (— ↕ n_wire n) ↕ ((n_cup_unswapped (S n)) ⊤) ⟷ zx_inv_associator ⟷ ((n_cup_unswapped (S n)) ↕ (— ↕ n_wire n)) - ∝ — ↕ (n_wire n ↕ ((n_cup_unswapped n) ⊤) ⟷ zx_inv_associator ⟷ ((n_cup_unswapped n) ↕ n_wire n)). -Proof. - intros n. - rewrite n_cup_unswapped_grow_l at 1. - rewrite compose_transpose. - rewrite n_cup_unswapped_grow_r. - unfold zx_inv_associator. - simpl_permlike_zx. - (* bundle_wires. *) - rewrite nwire_stack_distr_compose_l. - rewrite (cast_compose_mid_contract _ ((S n + (n + n)))). - rewrite (cast_stack_distribute). - simpl_permlike_zx. - rewrite compose_assoc. - simpl_casts. - rewrite n_yank_1_l_helper, cast_id_eq. - rewrite compose_assoc. - rewrite <- 3!stack_compose_distr. - rewrite yank_l. - simpl (n_wire (1 + n)). - rewrite (stack_assoc —). - rewrite cast_id_eq. - rewrite wire_to_n_wire, nwire_removal_r, <- wire_to_n_wire. - rewrite (stack_assoc —). - rewrite cast_id_eq. - rewrite <- n_wire_add_stack. - rewrite 2!nwire_removal_l. - (* rewrite <- (wire_stack_distr_compose_l _ _ _ (n_cup_unswapped n ↕ n_wire n) (n_wire n)). *) - replace (S n + (n + n))%nat with (1 + (n + (n + n)))%nat by easy. - rewrite (wire_stack_distr_compose_l (n + 0)). - simpl_casts. - rewrite wire_to_n_wire at 5. - rewrite <- n_wire_add_stack, nwire_removal_r. - rewrite (prop_iff_double_cast (1 + (n + 0)) (0 + (1 + n))). - rewrite 2!(cast_compose_mid_contract _ (1 + (n + n + n))). - rewrite 2!cast_cast_eq, 2!cast_id_eq. - easy. - Unshelve. all: try easy; auto with arith; lia. + (— ↕ n_wire n) ↕ ((n_cup_unswapped (S n)) ⊤) ⟷ zx_inv_associator ⟷ ((n_cup_unswapped (S n)) ↕ (— ↕ n_wire n)) + ∝ — ↕ (n_wire n ↕ ((n_cup_unswapped n) ⊤) ⟷ zx_inv_associator ⟷ ((n_cup_unswapped n) ↕ n_wire n)). +Proof. + intros n. + rewrite n_cup_unswapped_grow_l at 1. + rewrite compose_transpose. + rewrite n_cup_unswapped_grow_r. + unfold zx_inv_associator. + simpl_permlike_zx. + (* bundle_wires. *) + rewrite nwire_stack_distr_compose_l. + rewrite (cast_compose_mid_contract _ ((S n + (n + n)))). + rewrite (cast_stack_distribute). + simpl_permlike_zx. + rewrite compose_assoc. + simpl_casts. + rewrite n_yank_1_l_helper, cast_id_eq. + rewrite compose_assoc. + rewrite <- 3!stack_compose_distr. + rewrite yank_l. + simpl (n_wire (1 + n)). + rewrite (stack_assoc —). + rewrite cast_id_eq. + rewrite wire_to_n_wire, nwire_removal_r, <- wire_to_n_wire. + rewrite (stack_assoc —). + rewrite cast_id_eq. + rewrite <- n_wire_add_stack. + rewrite 2!nwire_removal_l. + (* rewrite <- (wire_stack_distr_compose_l _ _ _ (n_cup_unswapped n ↕ n_wire n) (n_wire n)). *) + replace (S n + (n + n))%nat with (1 + (n + (n + n)))%nat by easy. + rewrite (wire_stack_distr_compose_l (n + 0)). + simpl_casts. + rewrite wire_to_n_wire at 5. + rewrite <- n_wire_add_stack, nwire_removal_r. + rewrite (prop_iff_double_cast (1 + (n + 0)) (0 + (1 + n))). + rewrite 2!(cast_compose_mid_contract _ (1 + (n + n + n))). + rewrite 2!cast_cast_eq, 2!cast_id_eq. + easy. + Unshelve. all: try easy; auto with arith; lia. Qed. Lemma n_yank_l_unswapped : forall n {prf1 prf2}, - (n_wire n) ↕ ((n_cup_unswapped n) ⊤) ⟷ zx_inv_associator ⟷ ((n_cup_unswapped n) ↕ (n_wire n)) - ∝ cast _ _ prf1 prf2 (n_wire n). + (n_wire n) ↕ ((n_cup_unswapped n) ⊤) ⟷ zx_inv_associator ⟷ ((n_cup_unswapped n) ↕ (n_wire n)) + ∝ cast _ _ prf1 prf2 (n_wire n). Proof. - intros. - induction n. - - unfold zx_inv_associator, n_cup_unswapped. - prop_perm_eq. - - simpl (n_wire S n). - rewrite n_yank_1_l, IHn. - prop_perm_eq. - Unshelve. - all: auto with arith. + intros. + induction n. + - unfold zx_inv_associator, n_cup_unswapped. + prop_perm_eq. + - simpl (n_wire S n). + rewrite n_yank_1_l, IHn. + prop_perm_eq. + Unshelve. + all: auto with arith. Qed. Lemma compose_zx_inv_associator_r : forall {n0 n m o} (zx : ZX n0 (n + (m + o))) prf1 prf2, - zx ⟷ zx_inv_associator ∝ cast n0 (n + m + o) prf1 prf2 zx. + zx ⟷ zx_inv_associator ∝ cast n0 (n + m + o) prf1 prf2 zx. Proof. - intros. - rewrite (prop_iff_double_cast n0 (n + (m + o))). - rewrite (cast_compose_mid_contract _ (n + (m + o))). - rewrite cast_cast_eq, 2!cast_id_eq. - rewrite <- (nwire_removal_r zx) at 2. - apply compose_simplify; [easy | prop_perm_eq]. - Unshelve. - all: auto with arith. + intros. + rewrite (prop_iff_double_cast n0 (n + (m + o))). + rewrite (cast_compose_mid_contract _ (n + (m + o))). + rewrite cast_cast_eq, 2!cast_id_eq. + rewrite <- (nwire_removal_r zx) at 2. + apply compose_simplify; [easy | prop_perm_eq]. + Unshelve. + all: auto with arith. Qed. Lemma n_yank_l_unswapped': forall n {prf1 prf2}, - cast n (n+n+n) prf1 prf2 (n_wire n ↕ (n_cup_unswapped n) ⊤) ⟷ ((n_cup_unswapped n) ↕ (n_wire n)) - ∝ n_wire n. -Proof. - intros. - rewrite (prop_iff_double_cast (n + 0) (0 + n)). - rewrite <- n_yank_l_unswapped. - rewrite compose_zx_inv_associator_r. - rewrite (cast_compose_mid_contract _ (n + n + n)), cast_id_eq, cast_cast_eq. - easy. - Unshelve. - all: auto with arith. + cast n (n+n+n) prf1 prf2 (n_wire n ↕ (n_cup_unswapped n) ⊤) ⟷ ((n_cup_unswapped n) ↕ (n_wire n)) + ∝ n_wire n. +Proof. + intros. + rewrite (prop_iff_double_cast (n + 0) (0 + n)). + rewrite <- n_yank_l_unswapped. + rewrite compose_zx_inv_associator_r. + rewrite (cast_compose_mid_contract _ (n + n + n)), cast_id_eq, cast_cast_eq. + easy. + Unshelve. + all: auto with arith. Qed. Lemma n_swap_grow_r' : forall n prf1 prf2, - n_swap (S n) ∝ top_to_bottom (S n) ⟷ cast (S n) (S n) prf1 prf2 (n_swap n ↕ —). + n_swap (S n) ∝ top_to_bottom (S n) ⟷ cast (S n) (S n) prf1 prf2 (n_swap n ↕ —). Proof. - intros. - prop_perm_eq. - solve_modular_permutation_equalities. + intros. + prop_perm_eq. + solve_modular_permutation_equalities. Qed. Lemma n_swap_n_swap : forall n, - n_swap n ⟷ n_swap n ∝ n_wire n. + n_swap n ⟷ n_swap n ∝ n_wire n. Proof. - prop_perm_eq. - solve_modular_permutation_equalities. + prop_perm_eq. + solve_modular_permutation_equalities. Qed. #[export] Hint Rewrite n_swap_n_swap : perm_inv_db. Lemma n_cup_unswapped_n_swap_n_swap : forall n, - n_swap n ↕ n_swap n ⟷ n_cup_unswapped n ∝ n_cup_unswapped n. -Proof. - intros n. - rewrite <- (nwire_removal_l (n_swap n)) at 1. - rewrite <- (nwire_removal_r (n_swap n)) at 2. - rewrite stack_compose_distr. - rewrite compose_assoc, n_swap_cup_flip, <- compose_assoc. - rewrite <- stack_compose_distr, n_swap_n_swap. - rewrite nwire_removal_l, <- n_wire_add_stack, nwire_removal_l. - easy. + n_swap n ↕ n_swap n ⟷ n_cup_unswapped n ∝ n_cup_unswapped n. +Proof. + intros n. + rewrite <- (nwire_removal_l (n_swap n)) at 1. + rewrite <- (nwire_removal_r (n_swap n)) at 2. + rewrite stack_compose_distr. + rewrite compose_assoc, n_swap_cup_flip, <- compose_assoc. + rewrite <- stack_compose_distr, n_swap_n_swap. + rewrite nwire_removal_l, <- n_wire_add_stack, nwire_removal_l. + easy. Qed. Lemma n_cup_inv_n_swap_n_wire' : forall n, n_cup n ∝ n_wire n ↕ n_swap n ⟷ n_cup_unswapped n. @@ -2024,115 +1469,115 @@ Proof. Qed. Lemma n_swap_transpose : forall n, - (n_swap n) ⊤ ∝ n_swap n. + (n_swap n) ⊤ ∝ n_swap n. Proof. - intros n. - prop_perm_eq. - perm_eq_by_WF_inv_inj (perm_of_zx (n_swap n)) n; - [apply zxperm_permutation | apply perm_of_zx_WF - | rewrite perm_of_transpose_is_linv | ]; - auto with zxperm_db. - cleanup_perm_of_zx. - solve_modular_permutation_equalities. + intros n. + prop_perm_eq. + perm_eq_by_WF_inv_inj (perm_of_zx (n_swap n)) n; + [apply zxperm_permutation | apply perm_of_zx_WF + | rewrite perm_of_transpose_is_linv | ]; + auto with zxperm_db. + cleanup_perm_of_zx. + solve_modular_permutation_equalities. Qed. Lemma n_yank_l : forall n {prf1 prf2}, - cast n (n + n + n) prf1 prf2 (n_wire n ↕ n_cap n) - ⟷ (n_cup n ↕ n_wire n) ∝ n_wire n. -Proof. - intros. - unfold n_cap. - rewrite n_cup_inv_n_swap_n_wire' at 2. - unfold n_cup. - (* rewrite n_cup_inv_n_swap_n_wire. *) - simpl. - rewrite n_wire_transpose, n_swap_transpose. - rewrite nwire_stack_distr_compose_l. - rewrite stack_assoc'. - rewrite (cast_compose_mid_contract n (n + n + n) (n + n + n)). - rewrite cast_cast_eq, cast_id_eq. - rewrite nwire_stack_distr_compose_r. - rewrite <- compose_assoc, (compose_assoc _ _ (_ ↕ n_swap n ↕ _)). - rewrite <- 2!stack_compose_distr, n_swap_n_swap, nwire_removal_l. - rewrite <- 2!n_wire_add_stack, nwire_removal_r. - apply n_yank_l_unswapped'. - Unshelve. - all: lia. + cast n (n + n + n) prf1 prf2 (n_wire n ↕ n_cap n) + ⟷ (n_cup n ↕ n_wire n) ∝ n_wire n. +Proof. + intros. + unfold n_cap. + rewrite n_cup_inv_n_swap_n_wire' at 2. + unfold n_cup. + (* rewrite n_cup_inv_n_swap_n_wire. *) + simpl. + rewrite n_wire_transpose, n_swap_transpose. + rewrite nwire_stack_distr_compose_l. + rewrite stack_assoc'. + rewrite (cast_compose_mid_contract n (n + n + n) (n + n + n)). + rewrite cast_cast_eq, cast_id_eq. + rewrite nwire_stack_distr_compose_r. + rewrite <- compose_assoc, (compose_assoc _ _ (_ ↕ n_swap n ↕ _)). + rewrite <- 2!stack_compose_distr, n_swap_n_swap, nwire_removal_l. + rewrite <- 2!n_wire_add_stack, nwire_removal_r. + apply n_yank_l_unswapped'. + Unshelve. + all: lia. Qed. Lemma n_yank_r : forall n {prf1 prf2 prf3 prf4}, - cast n n prf3 prf4 (cast n (n + (n + n)) prf1 prf2 (n_cap n ↕ n_wire n) - ⟷ (n_wire n ↕ n_cup n)) ∝ n_wire n. -Proof. - intros. - apply transpose_diagrams. - rewrite cast_transpose, compose_transpose. - rewrite (cast_compose_mid_contract _ (n+n+n)). - rewrite cast_transpose, cast_cast_eq, cast_id_eq. - rewrite 2!stack_transpose. - rewrite n_wire_transpose. - unfold n_cap. - rewrite Proportional.transpose_involutive. - apply n_yank_l. - Unshelve. - all: lia. + cast n n prf3 prf4 (cast n (n + (n + n)) prf1 prf2 (n_cap n ↕ n_wire n) + ⟷ (n_wire n ↕ n_cup n)) ∝ n_wire n. +Proof. + intros. + apply transpose_diagrams. + rewrite cast_transpose, compose_transpose. + rewrite (cast_compose_mid_contract _ (n+n+n)). + rewrite cast_transpose, cast_cast_eq, cast_id_eq. + rewrite 2!stack_transpose. + rewrite n_wire_transpose. + unfold n_cap. + rewrite Proportional.transpose_involutive. + apply n_yank_l. + Unshelve. + all: lia. Qed. Lemma zx_triangle_1 : forall n, - zx_inv_right_unitor ⟷ (n_wire n ↕ n_cap n) ⟷ zx_inv_associator - ⟷ (n_cup n ↕ n_wire n) ⟷ zx_left_unitor ∝ n_wire n. -Proof. - intros. - unfold zx_inv_right_unitor. - unfold zx_inv_associator. - unfold zx_left_unitor. - simpl_casts. cleanup_zx. - rewrite cast_compose_l. - simpl_casts. cleanup_zx. - rewrite cast_compose_r. - cleanup_zx. simpl_casts. - rewrite n_yank_l. - reflexivity. + zx_inv_right_unitor ⟷ (n_wire n ↕ n_cap n) ⟷ zx_inv_associator + ⟷ (n_cup n ↕ n_wire n) ⟷ zx_left_unitor ∝ n_wire n. +Proof. + intros. + unfold zx_inv_right_unitor. + unfold zx_inv_associator. + unfold zx_left_unitor. + simpl_casts. cleanup_zx. + rewrite cast_compose_l. + simpl_casts. cleanup_zx. + rewrite cast_compose_r. + cleanup_zx. simpl_casts. + rewrite n_yank_l. + reflexivity. Qed. Lemma zx_triangle_2 : forall n, - zx_inv_left_unitor ⟷ (n_cap n ↕ n_wire n) ⟷ zx_associator - ⟷ (n_wire n ↕ n_cup n) ⟷ zx_right_unitor ∝ n_wire n. -Proof. - intros. - unfold zx_inv_left_unitor. - unfold zx_associator. - unfold zx_right_unitor. - simpl_casts. cleanup_zx. - rewrite cast_compose_r. - simpl_casts. cleanup_zx. - rewrite cast_compose_r. - simpl_casts. cleanup_zx. - rewrite n_yank_r. - reflexivity. + zx_inv_left_unitor ⟷ (n_cap n ↕ n_wire n) ⟷ zx_associator + ⟷ (n_wire n ↕ n_cup n) ⟷ zx_right_unitor ∝ n_wire n. +Proof. + intros. + unfold zx_inv_left_unitor. + unfold zx_associator. + unfold zx_right_unitor. + simpl_casts. cleanup_zx. + rewrite cast_compose_r. + simpl_casts. cleanup_zx. + rewrite cast_compose_r. + simpl_casts. cleanup_zx. + rewrite n_yank_r. + reflexivity. Qed. #[export] Instance ZXCompactClosedCategory : CompactClosedCategory nat := { - dual n := n; - unit n := n_cap n; - counit n := n_cup n; - triangle_1 := zx_triangle_1; - triangle_2 := zx_triangle_2; + dual n := n; + unit n := n_cap n; + counit n := n_cup n; + triangle_1 := zx_triangle_1; + triangle_2 := zx_triangle_2; }. #[export] Instance ZXDaggerMonoidalCategory : DaggerMonoidalCategory nat := { - dagger_compat := @zx_dagger_compat; - - associator_unitary_r := @zx_associator_unitary_r; - associator_unitary_l := @zx_associator_unitary_l; - left_unitor_unitary_r := @zx_left_unitor_unitary_r; - left_unitor_unitary_l := @zx_left_unitor_unitary_l; - right_unitor_unitary_r := @zx_right_unitor_unitary_r; - right_unitor_unitary_l := @zx_right_unitor_unitary_l; + dagger_compat := @zx_dagger_compat; + + associator_unitary_r := @zx_associator_unitary_r; + associator_unitary_l := @zx_associator_unitary_l; + left_unitor_unitary_r := @zx_left_unitor_unitary_r; + left_unitor_unitary_l := @zx_left_unitor_unitary_l; + right_unitor_unitary_r := @zx_right_unitor_unitary_r; + right_unitor_unitary_l := @zx_right_unitor_unitary_l; }. #[export] Instance ZXDaggerBraidedMonoidalCategory : DaggerBraidedMonoidalCategory nat := {}. -#[export] Instance ZXDaggerSymmetricMonoidalCategory : DaggerSymmetricMonoidalCategory nat := {}. \ No newline at end of file +#[export] Instance ZXDaggerSymmetricMonoidalCategory : DaggerSymmetricMonoidalCategory nat := {}. \ No newline at end of file