From 9081860ff5f064f6e866e0e817532d4ac349df22 Mon Sep 17 00:00:00 2001 From: William Spencer Date: Mon, 20 Jan 2025 22:02:27 -0600 Subject: [PATCH] Fixes for QuantumLib 1.6.0 and performance improvements (#61) --- SQIR/Equivalences.v | 471 +++++++++------- SQIR/ExtractionGateSet.v | 12 +- SQIR/GateDecompositions.v | 400 +++++++------- SQIR/UnitaryOps.v | 892 +++++++++++++++++++------------ VOQC/ConnectivityGraph.v | 2 +- VOQC/Layouts.v | 6 + VOQC/Main.v | 22 +- VOQC/MappingValidation.v | 33 +- VOQC/SwapRoute.v | 42 +- VOQC/UnitaryListRepresentation.v | 58 +- coq-sqir.opam | 4 +- coq-voqc.opam | 12 +- 12 files changed, 1087 insertions(+), 867 deletions(-) diff --git a/SQIR/Equivalences.v b/SQIR/Equivalences.v index 517bec3..0b8fc18 100644 --- a/SQIR/Equivalences.v +++ b/SQIR/Equivalences.v @@ -12,9 +12,9 @@ Lemma ID_equiv_SKIP : forall dim n, n < dim -> @ID dim n ≡ SKIP. Proof. intros dim n WT. unfold uc_equiv. - autorewrite with eval_db. - gridify. reflexivity. - lia. + autorewrite with eval_db; [|lia]. + bdestruct_all. + rewrite 2!id_kron; f_equal; unify_pows_two. Qed. Lemma SKIP_id_l : forall {dim} (c : base_ucom dim), SKIP; c ≡ c. @@ -71,7 +71,7 @@ Proof. unfold uc_equiv. simpl; autorewrite with eval_db. gridify. - replace (σx × σx) with (I 2) by solve_matrix. + rewrite MmultXX. reflexivity. Qed. @@ -79,9 +79,9 @@ Lemma H_H_id : forall {dim} q, H q; H q ≡ @ID dim q. Proof. intros dim q. unfold uc_equiv. - simpl; autorewrite with eval_db. + simpl; autorewrite with eval_db. gridify. - replace (hadamard × hadamard) with (I 2) by solve_matrix. + rewrite MmultHH. reflexivity. Qed. @@ -112,13 +112,16 @@ Lemma CNOT_CNOT_id : forall {dim} m n, Proof. intros dim m n Hm Hn Hneq. unfold uc_equiv. - simpl; autorewrite with eval_db. - 2: lia. - gridify. - all: Qsimpl. - all: repeat rewrite <- kron_plus_distr_r; - repeat rewrite <- kron_plus_distr_l. - all: rewrite Mplus10; reflexivity. + apply equal_on_basis_states_implies_equal; [auto_wf..|]. + intros f. + simpl. + rewrite denote_SKIP by lia. + rewrite Mmult_1_l by auto_wf. + rewrite Mmult_assoc. + rewrite 2!f_to_vec_CNOT by easy. + rewrite update_index_eq, update_index_neq, update_twice_eq by congruence. + rewrite xorb_assoc, xorb_nilpotent, xorb_false_r. + now rewrite update_same by easy. Qed. Lemma U_V_comm : forall {dim} (m n : nat) (U V : base_Unitary 1), @@ -156,51 +159,86 @@ Proof. rewrite pad_ctrl_ctrl_commutes; auto with wf_db. Qed. +(* FIXME: Move *) + Lemma H_comm_Z : forall {dim} q, @H dim q; SQIR.Z q ≡ X q; H q. Proof. - intros. + intros. unfold uc_equiv. - simpl; autorewrite with eval_db. - gridify. - replace (σz × hadamard) with (hadamard × σx) by solve_matrix. - reflexivity. + simpl. + bdestruct (q q3 -> @CNOT dim q1 q2; H q2; CNOT q2 q3; H q2 ≡ H q2; CNOT q2 q3; H q2; CNOT q1 q2. Proof. - intros dim q1 q2 q3 H. + intros dim q1 q2 q3 H13. unfold uc_equiv. - simpl; autorewrite with eval_db. - gridify; trivial. (* slow *) - all: replace (hadamard × (∣1⟩⟨1∣ × (hadamard × σx))) with - (σx × (hadamard × (∣1⟩⟨1∣ × hadamard))) by solve_matrix; - replace (hadamard × (∣0⟩⟨0∣ × (hadamard × σx))) with - (σx × (hadamard × (∣0⟩⟨0∣ × hadamard))) by solve_matrix; - reflexivity. + simpl. + bdestruct (q1 (f n). + now destruct (f q), (g q), (g n). + - rewrite (f_to_vec_proj_neq) by (rewrite ?update_index_neq; congruence). + rewrite (f_to_vec_proj_eq) by easy. + rewrite f_to_vec_CNOT by easy. + rewrite f_to_vec_proj_neq by (rewrite ?update_index_neq; congruence). + now Msimpl. + - rewrite 2!f_to_vec_proj_neq by (rewrite ?update_index_neq; congruence). + now Msimpl. Qed. diff --git a/SQIR/ExtractionGateSet.v b/SQIR/ExtractionGateSet.v index 6886042..f3ede90 100644 --- a/SQIR/ExtractionGateSet.v +++ b/SQIR/ExtractionGateSet.v @@ -294,13 +294,11 @@ Proof. unfold uc_eval. simpl. rewrite Ropp_0. apply f_equal. - unfold rotation. - solve_matrix; autorewrite with Cexp_db trig_db R_db; lca. + lma'; autorewrite with Cexp_db trig_db R_db; lca. (* U_U2 *) unfold uc_eval. simpl. apply f_equal. - unfold rotation. - solve_matrix; autorewrite with Cexp_db trig_db R_db; lca. + lma'; autorewrite with Cexp_db trig_db R_db; lca. (* U_CU1 *) rewrite invert_control. unfold uc_eval. simpl. @@ -308,8 +306,7 @@ Proof. unfold uc_equiv. simpl. rewrite Ropp_0. apply f_equal. - unfold rotation. - solve_matrix; autorewrite with Cexp_db trig_db R_db; lca. + lma'; autorewrite with Cexp_db trig_db R_db; lca. split; intro; invert_is_fresh; repeat constructor; auto. (* U_CH *) rewrite invert_control. @@ -333,8 +330,7 @@ Proof. unfold uc_equiv. simpl. rewrite Ropp_0. apply f_equal. - unfold rotation. - solve_matrix; autorewrite with Cexp_db trig_db R_db; lca. + lma'; autorewrite with Cexp_db trig_db R_db; lca. split; intro; invert_is_fresh; repeat constructor; auto. rewrite <- is_fresh_invert. rewrite <- 2 fresh_control. diff --git a/SQIR/GateDecompositions.v b/SQIR/GateDecompositions.v index 26654cf..2718b17 100644 --- a/SQIR/GateDecompositions.v +++ b/SQIR/GateDecompositions.v @@ -73,70 +73,89 @@ Definition C4X {dim} (a b c d e : nat) : base_ucom dim := Local Transparent H U3. Lemma CH_is_control_H : forall dim a b, @CH dim a b ≡ control a (H b). Proof. - assert (aux1 : rotation (- (PI / 4)) 0 0 × (σx × rotation (PI / 4) 0 0) = - Cexp (PI / 2) .* (rotation (PI / 2 / 2) 0 0 × - (σx × (rotation (- (PI / 2) / 2) 0 (- PI / 2) × - (σx × phase_shift (PI / 2)))))). - { (* messy :) should make better automation -KH *) - solve_matrix; repeat rewrite RIneq.Ropp_div; autorewrite with Cexp_db trig_db; - repeat rewrite RtoC_opp; field_simplify_eq; try nonzero. - replace (((R1 + R1)%R, (R0 + R0)%R) * cos (PI / 4 / 2) * sin (PI / 4 / 2)) - with (2 * sin (PI / 4 / 2) * cos (PI / 4 / 2)) by lca. - 2: replace (((- (R1 + R1))%R, (- (R0 + R0))%R) * Ci * Ci * - cos (PI / 2 / 2 / 2) * sin (PI / 2 / 2 / 2)) - with (2 * sin (PI / 2 / 2 / 2) * cos (PI / 2 / 2 / 2)) by lca. - 3: replace (- sin (PI / 4 / 2) * sin (PI / 4 / 2) + - cos (PI / 4 / 2) * cos (PI / 4 / 2)) - with (cos (PI / 4 / 2) * cos (PI / 4 / 2) - - sin (PI / 4 / 2) * sin (PI / 4 / 2)) by lca. - 3: replace ((R1 + R1)%R, (R0 + R0)%R) with (RtoC 2) by lca. - 4: replace (((- (R1 + R1))%R, (- (R0 + R0))%R) * sin (PI / 4 / 2) * - cos (PI / 4 / 2)) - with (- (2 * sin (PI / 4 / 2) * cos (PI / 4 / 2))) by lca. - 4: replace (- Ci * Ci * sin (PI / 2 / 2 / 2) * sin (PI / 2 / 2 / 2) + - Ci * Ci * cos (PI / 2 / 2 / 2) * cos (PI / 2 / 2 / 2)) - with (- (cos (PI / 2 / 2 / 2) * cos (PI / 2 / 2 / 2) - - sin (PI / 2 / 2 / 2) * sin (PI / 2 / 2 / 2))) by lca. - all: autorewrite with RtoC_db; apply c_proj_eq; simpl; autorewrite with R_db; - try rewrite <- Rminus_unfold; try rewrite <- cos_2a; try rewrite <- sin_2a; - replace (2 * (PI * / 4 * / 2))%R with (PI * / 4)%R by lra; - replace (2 * (PI * / 2 * / 2 * / 2))%R with (PI * / 4)%R by lra; - try symmetry; try apply sin_cos_PI4; try reflexivity; - try rewrite Ropp_eq_compat with (cos (PI * / 4)) (sin (PI * / 4)); - try reflexivity; try symmetry; try apply sin_cos_PI4; try reflexivity. - all: autorewrite with RtoC_db; rewrite <- sin_2a; rewrite <- cos_2a; - replace (2 * (PI / 4 / 2))%R with (PI / 4)%R by lra; - replace (2 * (PI / 2 / 2 / 2))%R with (PI / 4)%R by lra; - autorewrite with trig_db; reflexivity. } - assert (aux2 : rotation (- (PI / 4)) 0 0 × rotation (PI / 4) 0 0 = - rotation (PI / 2 / 2) 0 0 × - (rotation (- (PI / 2) / 2) 0 (- PI / 2) × phase_shift (PI / 2))). - { assert (aux: forall x, (x * x = x²)%R) by (unfold Rsqr; reflexivity). - solve_matrix; repeat rewrite RIneq.Ropp_div; autorewrite with Cexp_db trig_db; - repeat rewrite RtoC_opp; field_simplify_eq; try nonzero; - autorewrite with RtoC_db; repeat rewrite aux; - replace (PI / 2 / 2 / 2)%R with (PI / 4 / 2)%R by lra; reflexivity. } - intros dim a b. - unfold H, CH, uc_equiv. + intros. simpl. - autorewrite with eval_db. - gridify; trivial; autorewrite with ket_db bra_db. (* slow! *) - - rewrite Rminus_0_r, Rplus_0_l, Rplus_0_r. - apply f_equal2. - + rewrite <- Mscale_kron_dist_l. - rewrite <- Mscale_kron_dist_r. - do 2 (apply f_equal2; try reflexivity). - apply aux1. - + rewrite aux2. - reflexivity. - - rewrite Rminus_0_r, Rplus_0_l, Rplus_0_r. - apply f_equal2. - + rewrite <- 3 Mscale_kron_dist_l. - rewrite <- Mscale_kron_dist_r. - do 4 (apply f_equal2; try reflexivity). - apply aux1. - + rewrite aux2. - reflexivity. + unfold uc_equiv. + cbn. + rewrite !Ry_rotation. + (* rewrite <- !denote_Ry. *) + bdestruct (a compute_matrix A; compute_matrix B end. + rewrite Rdiv_0_l, sin_0, Cexp_0, cos_0. + Csimpl. + rewrite !Rplus_0_l, Cexp_0. + lma'. + } + rewrite pad_u_mmult by auto_wf. + rewrite Ry_rotation(* , <- denote_Ry *). + rewrite Rplus_0_l. + replace (- PI / 2)%R with (- (PI / 2))%R by lra. + rewrite phase_shift_rotation, <- denote_Rz. + apply equal_on_basis_states_implies_equal; [auto_wf..|]. + intros f. + rewrite !Mmult_assoc. + rewrite Rplus_0_r, Rminus_0_r. + f_to_vec_simpl. + rewrite Ropp_div. + replace (PI / 2 / 2)%R with (PI / 4)%R by lra. + rewrite 2!f_to_vec_pad_u_generic by (lia + auto_wf). + cbn. + rewrite Ropp_div. + replace (PI / 4 / 2)%R with (PI / 8)%R by lra. + rewrite update_index_eq. + distribute_plus; distribute_scale. + rewrite sin_neg, RtoC_opp, Copp_involutive. + f_to_vec_simpl. + rewrite xorb_false_l, xorb_true_l. + rewrite !f_to_vec_pad_u_generic by (lia + auto_wf). + cbn. + rewrite Ropp_div. + replace (PI / 4 / 2)%R with (PI / 8)%R by lra. + rewrite !update_index_eq. + rewrite !update_twice_eq. + rewrite !cos_neg, !sin_neg, RtoC_opp, Copp_involutive. + rewrite !Cexp_bool_mul. + autorewrite with Cexp_db C_db. + prep_matrix_equivalence. + intros i j Hi Hj. + unfold Mplus, scale. + destruct (f a), (f b); cbn; + (* autorewrite with C_db; *) + cancel_terms 1; + rewrite <- ?Copp_mult_distr_r, <- ?Copp_mult_distr_l. + - rewrite !Cmult_assoc. + rewrite sin_sin_PI8. + rewrite <- sin_cos_PI4. + replace (PI / 4)%R with (2 * (PI / 8))%R by lra. + rewrite sin_2a. + lca. + - rewrite Ci2, <- Copp_mult_distr_l, Cmult_1_l, <- !Copp_mult_distr_r. + rewrite !Cmult_assoc. + rewrite sin_sin_PI8. + rewrite <- sin_cos_PI4. + replace (PI / 4)%R with (2 * (PI / 8))%R by lra. + rewrite sin_2a. + lca. + - rewrite <- Copp_mult_distr_r. + rewrite !(Cmult_assoc (sin (PI / 8))). + rewrite sin_sin_PI8. + rewrite <- sin_cos_PI4. + replace (PI / 4)%R with (2 * (PI / 8))%R by lra. + rewrite sin_2a. + lca. + - rewrite <- !Copp_mult_distr_r. + rewrite !(Cmult_assoc (sin (PI / 8))). + rewrite sin_sin_PI8. + rewrite <- sin_cos_PI4. + replace (PI / 4)%R with (2 * (PI / 8))%R by lra. + rewrite sin_2a. + lca. Qed. Local Opaque H U3. @@ -149,12 +168,10 @@ Proof. autorewrite with R_db. repeat rewrite phase_shift_rotation. rewrite phase_0. - bdestruct (b field_simplify r - end; try lra. + end. all: autorewrite with R_db C_db Cexp_db. all: group_Cexp. all: try match goal with | |- context [Cexp ?r] => field_simplify r end. all: replace (8 * PI / 8)%R with PI by lra. - all: autorewrite with R_db C_db Cexp_db. + all: autorewrite with R_db Cexp_db. all: rewrite Mscale_plus_distr_r. all: distribute_scale; group_radicals. all: lma. @@ -521,14 +541,11 @@ Proof. assert (H0 : uc_eval (@CNOT dim a b) = Zero \/ uc_eval (@CNOT dim a c) = Zero \/ uc_eval (@CNOT dim a d) = Zero). - { assert (H0 : a = b \/ a = c \/ a = d). - apply Classical_Prop.NNPP. - intro contra. contradict H. - apply fresh_CCX; repeat split; auto. - destruct H0 as [H0 | [H0 | H0]]; subst. - left. autorewrite with eval_db. gridify. - right. left. autorewrite with eval_db. gridify. - right. right. autorewrite with eval_db. gridify. } + { assert (H0 : a = b \/ a = c \/ a = d) by + (rewrite <- fresh_CCX in *; lia). + destruct H0 as [? | [? | ?]]; + [left|right; left|right; right]; + apply CNOT_ill_typed; lia. } destruct H0 as [H0 | [H0 | H0]]; rewrite H0; Msimpl_light; trivial. Qed. @@ -541,8 +558,12 @@ Proof. UnitarySem.uc_eval (@CNOT dim b d) = Zero \/ UnitarySem.uc_eval (@CNOT dim c d) = Zero). { rewrite <- uc_well_typed_CCX in H. - autorewrite with eval_db. - gridify; auto. } + bdestruct (b bool), @@ -879,9 +898,8 @@ Proof. rewrite Rz_Rz_add. reflexivity. unfold uc_equiv; simpl. - autorewrite with eval_db. - bdestruct_all. - Msimpl. reflexivity. + rewrite H_ill_typed by lia. + now Msimpl_light. Qed. Lemma RTX_PI : forall dim q, @RTX dim PI q ≡ X q. @@ -891,13 +909,11 @@ Proof. rewrite H_comm_Z. rewrite useq_assoc. rewrite H_H_id. - bdestruct (q b <-> is_fresh a (@RTX dim r b). @@ -1009,14 +1023,14 @@ Proof. - (* b1 = true, b2 = true, b3 = true *) repeat commute_proj2. rewrite Mmult_assoc. - apply f_equal2; auto. + apply f_equal2; [easy|]. rewrite <- (Mmult_assoc _ (proj b _ _)). repeat commute_proj2. rewrite Mmult_assoc. - apply f_equal2; auto. + apply f_equal2; [easy|]. rewrite <- (Mmult_assoc _ (proj c _ _)). repeat commute_proj2. - apply f_equal2; auto. + apply f_equal2; [easy|]. repeat rewrite <- (Mmult_assoc (uc_eval (X _))). repeat rewrite Hr1. repeat rewrite denote_ID. @@ -1252,7 +1266,7 @@ Proof. assumption. destruct H2 as [_ ?]. rewrite H2 in H1. - rewrite Mmult_1_l in H1; auto with wf_db. + rewrite Mmult_1_l in H1 by auto with wf_db. rewrite H1. distribute_scale. rewrite <- Cexp_add. @@ -1535,18 +1549,12 @@ Proof. uc_eval (@CNOT dim a c) = Zero \/ uc_eval (@CNOT dim a d) = Zero \/ uc_eval (@CNOT dim a e) = Zero). - { assert (H0 : a = b \/ a = c \/ a = d \/ a = e). - apply Classical_Prop.NNPP. - intro contra. contradict H. - apply Decidable.not_or in contra as [? contra]. - apply Decidable.not_or in contra as [? contra]. - apply Decidable.not_or in contra as [? contra]. - apply fresh_C3X; repeat split; auto. - destruct H0 as [H0 | [H0 | [H0 | H0]]]; subst. - left. autorewrite with eval_db. gridify. - right. left. autorewrite with eval_db. gridify. - right. right. left. autorewrite with eval_db. gridify. - right. right. right. autorewrite with eval_db. gridify. } + { assert (H0 : a = b \/ a = c \/ a = d \/ a = e) + by (rewrite <- fresh_C3X in H; lia). + destruct H0 as [H0 | [H0 | [H0 | H0]]]; + match type of H0 with ?a = ?a' => + rewrite (CNOT_ill_typed a a') by lia; auto + end. } destruct H0 as [H0 | [H0 | [H0 | H0]]]; rewrite H0; Msimpl_light; trivial. Qed. @@ -1562,37 +1570,19 @@ Proof. uc_eval (@CNOT dim c e) = Zero \/ uc_eval (@CNOT dim d e) = Zero). { rewrite <- uc_well_typed_C3X in H. - apply Classical_Prop.not_and_or in H as [H | H]. - left. - autorewrite with eval_db; gridify; auto. - apply Classical_Prop.not_and_or in H as [H | H]. - left. - autorewrite with eval_db; gridify; auto. - apply Classical_Prop.not_and_or in H as [H | H]. - right. left. - autorewrite with eval_db; gridify; auto. - apply Classical_Prop.not_and_or in H as [H | H]. - right. right. left. - autorewrite with eval_db; gridify; auto. - apply Classical_Prop.not_and_or in H as [H | H]. - left. - autorewrite with eval_db; gridify; auto. - apply Classical_Prop.not_and_or in H as [H | H]. - right. left. - autorewrite with eval_db; gridify; auto. - apply Classical_Prop.not_and_or in H as [H | H]. - right. right. left. - autorewrite with eval_db; gridify; auto. - apply Classical_Prop.not_and_or in H as [H | H]. - right. right. right. left. - autorewrite with eval_db; gridify; auto. - apply Classical_Prop.not_and_or in H as [H | H]. - right. right. right. right. left. - autorewrite with eval_db; gridify; auto. - right. right. right. right. right. - autorewrite with eval_db; gridify; auto. } - destruct H0 as [H0 | [H0 | [H0 | [H0 | [H0 | H0]]]]]; - rewrite H0; Msimpl_light; trivial. + bdestruct (b b2 -> proj q dim b1 × proj q dim b2 = Zero. -Proof. - intros dim q b1 b2 neq. - unfold proj, pad_u, pad. - gridify. - destruct b1; destruct b2; try contradiction; simpl; Qsimpl; reflexivity. -Qed. - -(* TODO: move to QuantumLib *) - -Lemma bra0_phase : forall ϕ, bra 0 × phase_shift ϕ = bra 0. -Proof. intros; solve_matrix. Qed. - -Lemma bra1_phase : forall ϕ, bra 1 × phase_shift ϕ = Cexp ϕ .* bra 1. -Proof. intros; solve_matrix. Qed. - -#[export] Hint Rewrite bra0_phase bra1_phase : bra_db. - -Lemma braketbra_same : forall x y, bra x × (ket x × bra y) = bra y. -Proof. intros. destruct x; destruct y; solve_matrix. Qed. - -Lemma braketbra_diff : forall x y z, (x + y = 1)%nat -> bra x × (ket y × bra z) = Zero. -Proof. intros. destruct x; destruct y; try lia; solve_matrix. Qed. - -#[export] Hint Rewrite braketbra_same braketbra_diff using lia : bra_db. - -(* Auxiliary proofs about the semantics of CU and TOFF *) -Lemma CU_correct : forall (dim : nat) θ ϕ λ c t, - (t < dim)%nat -> c <> t -> - uc_eval (CU θ ϕ λ c t) = proj c dim false .+ (proj c dim true) × (ueval_r dim t (U_R θ ϕ λ)). +Lemma f_to_vec_proj : forall f q n b, + (q < n)%nat -> + proj q n b × f_to_vec n f = + (if bool_dec (f q) b then C1 else C0) .* f_to_vec n f. Proof. intros. - unfold proj; simpl. - autorewrite with eval_db. - unfold pad_u, pad. - gridify. (* slow *) - all: clear. - all: autorewrite with M_db_light ket_db bra_db. - all: rewrite Mplus_comm; - repeat (apply f_equal2; try reflexivity). - (* A little messy because we need to apply trig identities; - goal #1 = goal #3 and goal #2 = goal #4 *) - - solve_matrix; autorewrite with R_db C_db RtoC_db Cexp_db trig_db; - try lca; field_simplify_eq; try nonzero; group_Cexp. - + simpl. try (rewrite Rplus_comm; setoid_rewrite sin2_cos2; easy). - try ( - rewrite Cplus_comm; unfold Cplus, Cmult; - autorewrite with R_db; simpl; - setoid_rewrite sin2_cos2; easy - ). - + try (simpl; rewrite Copp_mult_distr_l, Copp_mult_distr_r; - repeat rewrite <- Cmult_assoc; rewrite <- Cmult_plus_distr_l; - autorewrite with RtoC_db; rewrite Ropp_involutive; - setoid_rewrite sin2_cos2; rewrite Cmult_1_r; - apply f_equal; lra). - try ( - simpl; repeat rewrite <- Cmult_assoc; simpl; - rewrite <- Cmult_plus_distr_l; - unfold Cplus, Cmult; - autorewrite with R_db; simpl; - setoid_rewrite sin2_cos2; autorewrite with R_db; - unfold Cexp; apply f_equal2; [apply f_equal; lra|] - ). - apply f_equal; lra. - - rewrite <- Mscale_kron_dist_l. - repeat rewrite <- Mscale_kron_dist_r. - repeat (apply f_equal2; try reflexivity). - (* Note: These destructs shouldn't be necessary - weakness in destruct_m_eq'. - The mat_equiv branch takes a more principled approach (see lma there, port). *) - solve_matrix; destruct x0; try destruct x0; simpl; - autorewrite with R_db C_db Cexp_db trig_db; try lca; - rewrite RtoC_opp; field_simplify_eq; try nonzero; group_Cexp; - repeat rewrite <- Cmult_assoc. - + unfold Cminus. - rewrite Copp_mult_distr_r, <- Cmult_plus_distr_l. - apply f_equal2; [apply f_equal; lra|]. - try (autorewrite with RtoC_db). - try (rewrite <- Cminus_unfold; - unfold Cminus, Cmult; simpl; autorewrite with R_db; - apply c_proj_eq; simpl; autorewrite with R_db). - rewrite <- Rminus_unfold, <- cos_plus. - apply f_equal. try apply f_equal. try lra. lra. - + apply f_equal2; [apply f_equal; lra|]. - apply c_proj_eq; simpl; try lra. - R_field_simplify. - rewrite <- sin_2a. - apply f_equal; lra. - + rewrite Copp_mult_distr_r. - apply f_equal2; [apply f_equal; lra|]. - apply c_proj_eq; simpl; try lra. - R_field_simplify. - replace (-2)%R with (-(2))%R by lra. - repeat rewrite <- Ropp_mult_distr_l. - apply f_equal. - rewrite <- sin_2a. - apply f_equal; lra. - + rewrite Copp_mult_distr_r. - rewrite <- Cmult_plus_distr_l. - apply f_equal2; [apply f_equal; lra|]. - try (autorewrite with RtoC_db; - rewrite Rplus_comm; rewrite <- Rminus_unfold, <- cos_plus; - apply f_equal; apply f_equal; lra). - try ( - rewrite Cplus_comm; apply c_proj_eq; simpl; try lra; - autorewrite with R_db; rewrite <- Rminus_unfold; - rewrite <- cos_plus; apply f_equal; lra - ). - - solve_matrix; autorewrite with R_db C_db RtoC_db Cexp_db trig_db; try lca; - field_simplify_eq; try nonzero; group_Cexp. - + try (rewrite Rplus_comm; setoid_rewrite sin2_cos2; easy). - try ( - simpl; rewrite Cplus_comm; unfold Cplus, Cmult; - autorewrite with R_db; simpl; - setoid_rewrite sin2_cos2; easy - ). - + try (rewrite Copp_mult_distr_l, Copp_mult_distr_r; - repeat rewrite <- Cmult_assoc; rewrite <- Cmult_plus_distr_l; - autorewrite with RtoC_db; rewrite Ropp_involutive; - setoid_rewrite sin2_cos2; rewrite Cmult_1_r). - try ( - simpl; - repeat rewrite <- Cmult_assoc; simpl; - rewrite <- Cmult_plus_distr_l; - unfold Cplus, Cmult; - autorewrite with R_db; simpl; - setoid_rewrite sin2_cos2; autorewrite with R_db; - unfold Cexp; apply f_equal2; [apply f_equal; lra|] - ). - apply f_equal; lra. - - rewrite <- 3 Mscale_kron_dist_l. - repeat rewrite <- Mscale_kron_dist_r. - repeat (apply f_equal2; try reflexivity). - solve_matrix; destruct x; try destruct x; simpl; - autorewrite with R_db C_db Cexp_db trig_db; try lca; - try rewrite RtoC_opp; field_simplify_eq; try nonzero; group_Cexp; - repeat rewrite <- Cmult_assoc. - + unfold Cminus. - rewrite Copp_mult_distr_r, <- Cmult_plus_distr_l. - apply f_equal2; [apply f_equal; lra|]. - try (autorewrite with RtoC_db). - try ( - repeat rewrite <- Cmult_assoc; simpl; - unfold Cplus, Cmult; - autorewrite with R_db; simpl; - apply c_proj_eq; simpl; try lra - ). - rewrite <- Rminus_unfold, <- cos_plus. - apply f_equal. try apply f_equal. lra. - + apply f_equal2; [apply f_equal; lra|]. - apply c_proj_eq; simpl; try lra. - R_field_simplify. - rewrite <- sin_2a. - apply f_equal; lra. - + rewrite Copp_mult_distr_r. - apply f_equal2; [apply f_equal; lra|]. - apply c_proj_eq; simpl; try lra. - R_field_simplify. - replace (-2)%R with (-(2))%R by lra. - repeat rewrite <- Ropp_mult_distr_l. - apply f_equal. - rewrite <- sin_2a. - apply f_equal; lra. - + rewrite Copp_mult_distr_r. - rewrite <- Cmult_plus_distr_l. - apply f_equal2; [apply f_equal; lra|]. - try (autorewrite with RtoC_db; - rewrite Rplus_comm; rewrite <- Rminus_unfold, <- cos_plus; - apply f_equal; apply f_equal; lra); - try ( - apply c_proj_eq; simpl; try lra; - rewrite Rplus_comm; autorewrite with R_db; - rewrite <- Rminus_unfold, <- cos_plus; - apply f_equal; lra - ). -Qed. - -Lemma UR_not_WT : forall (dim a b : nat) r r0 r1, - ~ uc_well_typed (@uapp1 _ dim (U_R r r0 r1) b) -> - uc_eval (@CU dim r r0 r1 a b) = Zero. + destruct (bool_dec (f q) b). + - rewrite f_to_vec_proj_eq by easy. + now Msimpl. + - rewrite f_to_vec_proj_neq by easy. + now Msimpl. +Qed. + +Lemma f_to_vec_pad_u_generic : forall (n i : nat) A (f : nat -> bool), + (i < n)%nat -> WF_Matrix A -> + pad_u n i A × (f_to_vec n f) = + (if f i then A 0 1 else A 0 0)%nat .* f_to_vec n (update f i false) + .+ (if f i then A 1 1 else A 1 0)%nat .* f_to_vec n (update f i true). Proof. - intros dim a b r r0 r1 H. - simpl. unfold pad_u. - assert (@pad 1 b dim (rotation (r / 2) r0 0) = Zero). - { unfold pad. gridify. - assert (uc_well_typed (@uapp1 _ (b + 1 + d) (U_R r r0 r1) b)). - constructor; lia. - contradiction. } - rewrite H0. - Msimpl_light. - reflexivity. + intros n i A f Hi HA. + unfold pad_u, pad. + rewrite (f_to_vec_split 0 n i f Hi). + repad. + replace (i + 1 + x - 1 - i)%nat with x by lia. + Msimpl. + replace (A × ∣ Nat.b2n (f i) ⟩) with + ((if f i then A 0 1 else A 0 0)%nat .* ∣ Nat.b2n false ⟩ + .+ (if f i then A 1 1 else A 1 0)%nat .* ∣ Nat.b2n true ⟩) + by (destruct (f i); lma'). + restore_dims. + distribute_plus; distribute_scale. + f_equal; [unify_pows_two|..]; + (f_equal; [unify_pows_two|..]); + rewrite (f_to_vec_split 0 (i + 1 + x) i) by lia; + rewrite f_to_vec_update_oob by lia; + rewrite f_to_vec_shift_update_oob by lia; + rewrite update_index_eq; + do 2 f_equal; lia. Qed. -Lemma UR_not_fresh : forall (dim a b : nat) r r0 r1, - ~ is_fresh a (@uapp1 _ dim (U_R r r0 r1) b) -> - uc_eval (@CU dim r r0 r1 a b) = Zero. +Lemma f_to_vec_X : forall (n i : nat) (f : nat -> bool), (i < n)%nat -> + (uc_eval (X i)) × (f_to_vec n f) = f_to_vec n (update f i (¬ (f i))). Proof. - intros dim a b r r0 r1 H. - simpl. - assert (uc_eval (@CNOT dim a b) = Zero). - { assert (a = b). - apply Classical_Prop.NNPP. - intro contra. contradict H. - constructor; assumption. - autorewrite with eval_db. gridify. } - rewrite H0. - Msimpl_light. - reflexivity. + intros. rewrite denote_X. apply f_to_vec_σx. auto. Qed. -Lemma UR_a_geq_dim : forall (dim a b : nat) r r0 r1, - (dim <= a)%nat -> - uc_eval (@CU dim r r0 r1 a b) = Zero. +Lemma f_to_vec_Y : forall (n i : nat) (f : nat -> bool), (i < n)%nat -> + (uc_eval (SQIR.Y i)) × (f_to_vec n f) + = (-1) ^ Nat.b2n (f i) * Ci .* f_to_vec n (update f i (¬ f i)). Proof. - intros dim a b r r0 r1 H. - simpl. - assert (uc_eval (@Rz dim ((r1 + r0) / 2) a) = Zero). - { autorewrite with eval_db. gridify. } - rewrite H0. - Msimpl_light. - reflexivity. + intros. rewrite denote_Y. apply f_to_vec_σy. auto. Qed. -Local Opaque CU. -Lemma f_to_vec_X : forall (n i : nat) (f : nat -> bool), (i < n)%nat -> - (uc_eval (X i)) × (f_to_vec n f) = f_to_vec n (update f i (¬ (f i))). +Lemma f_to_vec_Z : forall (n i : nat) (f : nat -> bool), (i < n)%nat -> + (uc_eval (SQIR.Z i)) × (f_to_vec n f) = (-1) ^ Nat.b2n (f i) .* f_to_vec n f. Proof. - intros. rewrite denote_X. apply f_to_vec_σx. auto. + intros. rewrite denote_Z. apply f_to_vec_σz. auto. Qed. Lemma f_to_vec_CNOT : forall (n i j : nat) (f : nat -> bool), @@ -631,7 +450,8 @@ Proof. intros. rewrite denote_H. apply f_to_vec_hadamard; auto. Qed. -#[export] Hint Rewrite f_to_vec_CNOT f_to_vec_Rz f_to_vec_X using lia : f_to_vec_db. +#[export] Hint Rewrite f_to_vec_CNOT f_to_vec_SWAP f_to_vec_Rz + f_to_vec_X f_to_vec_Y f_to_vec_Z using lia : f_to_vec_db. Ltac f_to_vec_simpl_body := autorewrite with f_to_vec_db; @@ -648,6 +468,16 @@ Ltac f_to_vec_simpl_body := Ltac f_to_vec_simpl := repeat f_to_vec_simpl_body. +Lemma Cexp_bool_mul b a : + Cexp (b2R b * a) = Cexp a ^ (Nat.b2n b). +Proof. + destruct b. + - rewrite Rmult_1_l. lca. + - rewrite Rmult_0_l, Cexp_0. easy. +Qed. + +#[export] Hint Rewrite Cexp_bool_mul : Cexp_db. + Lemma f_to_vec_CCX : forall (dim a b c : nat) (f : nat -> bool), (a < dim)%nat -> (b < dim)%nat -> (c < dim)%nat -> a <> b -> a <> c -> b <> c -> (uc_eval (CCX a b c)) × (f_to_vec dim f) @@ -658,20 +488,354 @@ Proof. simpl uc_eval. repeat rewrite Mmult_assoc. f_to_vec_simpl. - rewrite (update_same _ b). - 2: destruct (f a); destruct (f b); reflexivity. - destruct (f a); destruct (f b); destruct (f c); simpl. - all: autorewrite with R_db C_db Cexp_db. - all: cancel_terms (Cexp (PI * / 4)). - all: group_Cexp. - all: repeat match goal with - | |- context [Cexp ?r] => field_simplify r - end. - all: autorewrite with R_db C_db Cexp_db. - all: rewrite Mscale_plus_distr_r. - all: distribute_scale; group_radicals. - all: lma. + rewrite xorb_false_l, xorb_true_l. + rewrite (xorb_comm _ (f b)), xorb_assoc, xorb_nilpotent, xorb_false_r. + replace ((((¬ f b) ⊕ f a) ⊕ f b) ⊕ f a) with true by + (now destruct (f b), (f a)). + replace (((¬ f b) ⊕ f a) ⊕ f b) with (¬ f a) by + (now destruct (f b), (f a)). + replace ((f b ⊕ (f b ⊕ f a)) ⊕ f a) with false by + (now destruct (f b), (f a)). + rewrite <- xorb_assoc, xorb_nilpotent, xorb_false_l. + rewrite (update_same _ b) by easy. + prep_matrix_equivalence. + intros i j Hi Hj. + unfold scale, Mplus. + C_field. + group_Cexp. + replace (b2R (f b) * - (PI / 4) + b2R (f b ⊕ f a) * (PI / 4) + + b2R (f a) * - (PI / 4) + b2R (f b ⊕ f a) * - (PI / 4) + b2R (f a)*(PI/4)+ + b2R (f b) * (PI / 4) + b2R false * (PI / 4))%R with + (0)%R by (simpl; lra). + rewrite (Cexp_add _ (1 * PI)), !Rmult_1_l, Cexp_PI. + rewrite Rmult_0_l, Rplus_0_l, Cexp_0. + rewrite <- !Cplus_assoc, <- Cmult_assoc, <- Cmult_plus_distr_l. + assert (aux : forall b, (b2R (¬ b) = 1 - b2R b)%R) by + (intros b'; destruct b'; simpl; lra). + rewrite <- negb_xorb_l. + rewrite !aux. + replace (b2R (f b ⊕ f a) * - (PI/4) + b2R (f a)*(PI/4) + b2R (f b)*(PI/4) + + b2R (f c) * PI + (1 - b2R (f b)) * - (PI / 4) + + (1 - b2R (f b ⊕ f a)) * (PI / 4) + (1 - b2R (f a)) * - (PI / 4) + + (PI / 4))%R with + ((2 * b2R (f c) + b2R (f b) + b2R (f a) - b2R (f b ⊕ f a)) * (PI / 2))%R + by (simpl; lra). + replace (2 * b2R (f c) + b2R (f b) + b2R (f a) - b2R (f b ⊕ f a))%R + with (2 * (b2R (f c) + b2R (f a && f b)))%R by + (destruct (f a), (f b), (f c); cbn; lra). + rewrite Rmult_comm, <- Rmult_assoc. + replace (PI / 2 * 2)%R with PI by lra. + rewrite Rmult_comm. + replace (Cexp ((b2R (f c) + b2R (f a && f b)) * PI)) with + (if f c ⊕ (f a && f b) then -1 : C else C1)%C by + (destruct (f c), (f a && f b); cbn; autorewrite with R_db Cexp_db; + try lca; symmetry; apply Cexp_2PI). + destruct (f c ⊕ (f a && f b)); lca. +Qed. + + +(* It is also helpful to have lemmas with the specific conditions for + a gate being Zero *) + +Section IllTyped. +Local Open Scope nat_scope. + +Lemma CNOT_ill_typed : forall {dim} n m, + (dim <= n \/ dim <= m \/ n = m) -> + @uc_eval dim (CNOT n m) = Zero. +Proof. + intros dim n m H. + rewrite denote_cnot. + unfold ueval_cnot, pad_ctrl, pad. + now bdestruct_all. +Qed. + +Lemma ID_ill_typed : forall dim q, dim <= q -> @uc_eval dim (SQIR.ID q) = Zero. +Proof. + intros. + rewrite denote_ID. + unfold pad_u, pad. + Modulus.bdestructΩ'. +Qed. + +Lemma H_ill_typed : forall dim q, dim <= q -> @uc_eval dim (H q) = Zero. +Proof. + intros. + autorewrite with eval_db. + Modulus.bdestructΩ'. +Qed. + +Lemma X_ill_typed : forall dim q : nat, dim <= q -> @uc_eval dim (X q) = Zero. +Proof. + intros. + autorewrite with eval_db. + Modulus.bdestructΩ'. +Qed. + +Lemma Y_ill_typed : forall dim q : nat, dim <= q -> @uc_eval dim (Y q) = Zero. +Proof. + intros. + autorewrite with eval_db. + Modulus.bdestructΩ'. +Qed. + +Lemma Z_ill_typed : forall dim q : nat, dim <= q -> + @uc_eval dim (SQIR.Z q) = Zero. +Proof. + intros. + autorewrite with eval_db. + Modulus.bdestructΩ'. +Qed. + +Local Transparent SWAP. +Lemma SWAP_ill_typed : forall dim a b, + (dim <= a \/ dim <= b \/ a = b) -> + @uc_eval dim (SWAP a b) = Zero. +Proof. + intros. + simpl. + rewrite CNOT_ill_typed by easy. + now Msimpl. Qed. +Local Opaque SWAP. + +Lemma Rz_ill_typed : forall dim a n, + dim <= n -> @uc_eval dim (Rz a n) = Zero. +Proof. + intros. + autorewrite with eval_db. + Modulus.bdestructΩ'. +Qed. + +Lemma proj_ill_typed : forall dim q b, + dim <= q -> proj q dim b = Zero. +Proof. + intros. + unfold proj, pad_u, pad. + Modulus.bdestructΩ'. +Qed. + +End IllTyped. + + +Lemma proj_commutes : forall dim q1 q2 b1 b2, + proj q1 dim b1 × proj q2 dim b2 = proj q2 dim b2 × proj q1 dim b1. +Proof. + intros dim q1 q2 b1 b2. + bdestruct (q1 b2 -> proj q dim b1 × proj q dim b2 = Zero. +Proof. + intros dim q b1 b2 neq. + unfold proj, pad_u, pad. + Modulus.bdestructΩ'; [|apply Mmult_0_l]. + restore_dims. + rewrite 2!kron_mixed_product by auto_wf. + replace (bool_to_matrix b1 × bool_to_matrix b2) with (@Zero 2 2) + by (destruct b1, b2; try easy; lma). + now Msimpl. +Qed. + +(* TODO: move to QuantumLib *) + +Lemma bra0_phase : forall ϕ, bra 0 × phase_shift ϕ = bra 0. +Proof. intros; lma'. Qed. + +Lemma bra1_phase : forall ϕ, bra 1 × phase_shift ϕ = Cexp ϕ .* bra 1. +Proof. intros; lma'. Qed. + +#[export] Hint Rewrite bra0_phase bra1_phase : bra_db. + +Lemma braketbra_same : forall x y, bra x × (ket x × bra y) = bra y. +Proof. intros. destruct x; destruct y; lma'. Qed. + +Lemma braketbra_diff : forall x y z, (x + y = 1)%nat -> bra x × (ket y × bra z) = Zero. +Proof. intros. destruct x; destruct y; try lia; lma'. Qed. + +#[export] Hint Rewrite braketbra_same braketbra_diff using lia : bra_db. + +(* Auxiliary proofs about the semantics of CU and TOFF *) +Lemma CU_correct : forall (dim : nat) θ ϕ λ c t, + (t < dim)%nat -> c <> t -> + uc_eval (CU θ ϕ λ c t) = proj c dim false .+ (proj c dim true) × (ueval_r dim t (U_R θ ϕ λ)). +Proof. + intros. + (* simpl. + bdestruct (c replace a with 0%R by lra; rewrite Cexp_0 + end. + lca. + } + assert (Hcase2 : Cexp ((λ + ϕ) / 2) .* rotation (θ / 2) ϕ 0 × (σx × + (rotation (- θ / 2) 0 (- (ϕ + λ) / 2) × (σx × phase_shift ((λ - ϕ) / 2)))) + = rotation θ ϕ λ). 1:{ + + prep_matrix_equivalence. + unfold rotation. + assert (Hrw : (forall c, c / 2 / 2 = c / 4)%R) by (intros; lra). + rewrite 2!Hrw. + rewrite Rplus_0_r. + autorewrite with R_db. + unfold Mmult; simpl. + autorewrite with trig_db. + unfold scale, Mmult. + autorewrite with Cexp_db. + + by_cell; cbn. + Csimpl. + C_field. + rewrite <- !Cmult_assoc. + rewrite (Rplus_comm λ). + pose proof (cos_plus (θ*/4) (θ*/4)) as Hcos. + replace ((θ*/4)+(θ*/4))%R with (θ*/2)%R in Hcos by lra. + rewrite Hcos. + lca. + + Csimpl. + C_field. + rewrite <- !Cmult_assoc. + rewrite (Rplus_comm λ). + pose proof (sin_plus (θ*/4) (θ*/4)) as Hcos. + replace ((θ*/4)+(θ*/4))%R with (θ*/2)%R in Hcos by lra. + rewrite Hcos. + replace (Cexp λ) with (Cexp ((ϕ + λ) * / 2 + (λ + - ϕ) * / 2)) + by (f_equal; lra). + rewrite Cexp_add. + lca. + + Csimpl. + C_field. + rewrite <- !Cmult_assoc. + rewrite (Rplus_comm λ). + pose proof (sin_plus (θ*/4) (θ*/4)) as Hcos. + replace ((θ*/4)+(θ*/4))%R with (θ*/2)%R in Hcos by lra. + rewrite Hcos. + lca. + + Csimpl. + C_field. + rewrite <- !Cexp_add. + rewrite Cmult_comm, <- !Cmult_assoc. + C_field. + rewrite <- Cexp_add. + repeat match goal with + |- context [Cexp ?a] => + progress replace a%R with (ϕ + λ)%R by lra + end. + autorewrite with RtoC_db. + pose proof (cos_plus (θ*/4) (θ*/4)) as Hcos. + replace ((θ*/4)+(θ*/4))%R with (θ*/2)%R in Hcos by lra. + rewrite Hcos. + lca. + } + gridify. (* slow *) + all: clear -Hcase1 Hcase2. + all: autorewrite with M_db_light ket_db bra_db. + all: rewrite Mplus_comm; + repeat (apply f_equal2; try reflexivity). + + (* A little messy because we need to apply trig identities; + goal #1 = goal #3 and goal #2 = goal #4 *) + - apply Hcase1. + - rewrite <- Mscale_kron_dist_l, <- Mscale_kron_dist_r, <- Mscale_mult_dist_l. + do 2 f_equal. + apply Hcase2. + - apply Hcase1. + - rewrite <- 3!Mscale_kron_dist_l, <- Mscale_kron_dist_r, <- Mscale_mult_dist_l. + do 4 f_equal. + apply Hcase2. +Qed. + +Lemma UR_not_WT : forall (dim a b : nat) r r0 r1, + ~ uc_well_typed (@uapp1 _ dim (U_R r r0 r1) b) -> + uc_eval (@CU dim r r0 r1 a b) = Zero. +Proof. + intros dim a b r r0 r1 H. + simpl. + bdestruct (b + uc_eval (@CU dim r r0 r1 a b) = Zero. +Proof. + intros dim a b r r0 r1 H. + simpl. + bdestruct (a =? b); [rewrite CNOT_ill_typed by lia; now Msimpl|]. + bdestruct (b + uc_eval (@CU dim r r0 r1 a b) = Zero. +Proof. + intros dim a b r r0 r1 H. + simpl. + rewrite CNOT_ill_typed by lia. + now Msimpl. +Qed. +Local Opaque CU. + + Lemma CCX_a_geq_dim : forall (dim a b c : nat), (dim <= a)%nat -> uc_eval (@CCX dim a b c) = Zero. @@ -679,29 +843,20 @@ Proof. intros dim a b c H. unfold CCX. simpl. - rewrite (denote_cnot dim a b). - unfold ueval_cnot, pad_ctrl, pad. - gridify. + rewrite CNOT_ill_typed by lia. + now Msimpl_light. Qed. Lemma CCX_not_WT : forall (dim a b c : nat), ~ uc_well_typed (@CNOT dim b c) -> uc_eval (@CCX dim a b c) = Zero. Proof. intros dim a b c H. - unfold CCX. - simpl. - assert (uc_eval (@CNOT dim b c) = Zero). - { autorewrite with eval_db. - gridify. - assert (uc_well_typed (@CNOT (b + (1 + d + 1) + d0) b (b + 1 + d))). - apply uc_well_typed_CNOT; repeat split; lia. - contradiction. - assert (uc_well_typed (@CNOT (c + (1 + d + 1) + d0) (c + 1 + d) c)). - apply uc_well_typed_CNOT; repeat split; lia. - contradiction. } - rewrite H0. - Msimpl_light. - reflexivity. + destruct (ltac:(lia) : ((b < dim /\ c < dim /\ b <> c) \/ + (dim <= b \/ dim <= c \/ b = c))%nat). + - exfalso; apply H, uc_well_typed_CNOT; easy. + - simpl. + rewrite (CNOT_ill_typed b c) by easy. + now Msimpl_light. Qed. Local Transparent CNOT. @@ -710,17 +865,17 @@ Lemma CCX_not_fresh : forall (dim a b c : nat), Proof. intros dim a b c H. unfold CCX. - simpl. - assert (ueval_cnot dim a b = Zero \/ ueval_cnot dim a c = Zero). - { assert (a = b \/ a = c). - apply Classical_Prop.NNPP. + cbn -[CNOT]. + enough (Hz : @uc_eval dim (CNOT a b) = Zero + \/ @uc_eval dim (CNOT a c) = Zero) by + (destruct Hz as [-> | ->]; now Msimpl_light). + assert (H0 : a = b \/ a = c). { + enough (~ (a <> b /\ a <> c)) by lia. intro contra. contradict H. - apply Classical_Prop.not_or_and in contra as [? ?]. - constructor; assumption. - destruct H0. - left. autorewrite with eval_db. gridify. - right. autorewrite with eval_db. gridify. } - destruct H0; rewrite H0; Msimpl_light; reflexivity. + constructor; lia. + } + destruct H0 as [-> | ->]; [left | right]; + apply CNOT_ill_typed; lia. Qed. Local Opaque CNOT. Local Opaque CCX. @@ -729,9 +884,9 @@ Lemma CCX_correct : forall (dim : nat) a b c, (b < dim)%nat -> (c < dim)%nat -> a <> b -> a <> c -> b <> c -> uc_eval (CCX a b c) = proj a dim false .+ (proj a dim true) × (ueval_cnot dim b c). intros dim a b c ? ? ? ? ?. - bdestruct (a is_fresh q (invert u). Proof. intros dim q u. - split; intro H. - induction u; try dependent destruction u; inversion H; subst; constructor; auto. - induction u; try dependent destruction u; inversion H; subst; constructor; auto. + split; intro H; + induction u; try dependent destruction u; + inversion H; subst; constructor; auto. Qed. Lemma proj_adjoint : forall dim q b, (proj q dim b) † = proj q dim b. Proof. intros. unfold proj, pad_u, pad. - gridify. - destruct b; simpl; Msimpl; reflexivity. + Modulus.bdestruct_one; + restore_dims; + distribute_adjoint; [destruct b|]; + simpl; Msimpl; reflexivity. Qed. Lemma invert_control : forall dim q (u : base_ucom dim), @@ -995,9 +1157,20 @@ Lemma proj_CNOT_ctl_true : forall dim m n, uc_eval (CNOT m n) × proj m dim true = proj m dim true × uc_eval (X n). Proof. intros dim m n H. - unfold proj. - autorewrite with eval_db. - gridify; Qsimpl; reflexivity. + bdestruct (n (uc_eval c) ⊗ I (2^k) = uc_eval (cast c (dim + k)). diff --git a/VOQC/ConnectivityGraph.v b/VOQC/ConnectivityGraph.v index 5edecbb..f55e9bf 100644 --- a/VOQC/ConnectivityGraph.v +++ b/VOQC/ConnectivityGraph.v @@ -998,7 +998,7 @@ Proof. left. unfold is_in_graph. bdestruct_all; reflexivity. - rewrite plus_assoc. + rewrite Nat.add_assoc. apply IHdist; lia. Qed. diff --git a/VOQC/Layouts.v b/VOQC/Layouts.v index 6ee7d30..fc4109b 100644 --- a/VOQC/Layouts.v +++ b/VOQC/Layouts.v @@ -209,6 +209,8 @@ Proof. rewrite H. auto. Qed. +#[export] Hint Resolve get_phys_perm get_log_perm : perm_db. + Lemma get_log_phys_inv : forall n lay l, layout_bijective n lay -> l < n -> get_log lay (get_phys lay l) = l. @@ -237,6 +239,8 @@ Proof. reflexivity. Qed. +#[export] Hint Resolve get_log_phys_inv get_phys_log_inv : perm_inv_db. + Lemma get_phys_lt : forall dim m x, layout_bijective dim m -> x < dim -> @@ -549,6 +553,8 @@ Proof. rewrite find_log_swap_log_3 with (n:=n); auto. Qed. +#[export] Hint Resolve swap_log_preserves_bij : perm_db. + (** * Trivial layout *) Fixpoint trivial_layout n : layout := diff --git a/VOQC/Main.v b/VOQC/Main.v index 514aedb..825b436 100644 --- a/VOQC/Main.v +++ b/VOQC/Main.v @@ -1,3 +1,4 @@ +Require Import QuantumLib.Permutations. Require Import GateCancellation. Require Import HadamardReduction. Require Import NotPropagation. @@ -1007,29 +1008,24 @@ Proof. intros dim c1 c2 lay1 lay2 WT1 WT2 WF1 WF2 H. unfold check_swap_equivalence in H. unfold is_swap_equivalent in H. - destruct (MappingValidation.check_swap_equivalence (full_to_map c1) - (full_to_map c2) lay1 lay2 - (fun n : nat => MappingGateSet.match_gate match_gate)) eqn:mv. + destruct (MappingValidation.check_swap_equivalence + (full_to_map c1) (full_to_map c2) lay1 lay2 + (fun n : nat => MappingGateSet.match_gate match_gate)) eqn:mv. assert (mvWF:=mv). destruct p. 2: inversion H. - apply MVP.check_swap_equivalence_implies_equivalence in mv; auto. - apply MVP.check_swap_equivalence_layouts_WF in mvWF as [? ?]; auto. + apply MVP.check_swap_equivalence_implies_equivalence in mv; + auto using full_to_map_WT. + apply MVP.check_swap_equivalence_layouts_WF in mvWF as [? ?]; + auto using full_to_map_WT. unfold MVP.SRP.uc_equiv_perm_ex in mv. exists (get_phys lay1 ∘ get_log lay2)%prg. exists (get_phys l0 ∘ get_log l)%prg. - repeat split. - apply Permutations.permutation_compose. - apply get_phys_perm; auto. - apply get_log_perm; auto. - apply Permutations.permutation_compose. - apply get_phys_perm; auto. - apply get_log_perm; auto. + repeat split; [auto with perm_db..|]. unfold eval. unfold MVP.SRP.MapList.eval in mv. rewrite <- 2 list_to_ucom_full_to_map in mv. apply mv. - all: apply full_to_map_WT; assumption. Qed. Lemma check_constraints_correct : forall dim (c : circ dim) (cg : c_graph), diff --git a/VOQC/MappingValidation.v b/VOQC/MappingValidation.v index 454fcda..db076fe 100644 --- a/VOQC/MappingValidation.v +++ b/VOQC/MappingValidation.v @@ -165,7 +165,7 @@ Proof. constructor. apply uc_well_typed_l_implies_dim_nonzero in WT. assumption. - * apply IHl in H; auto. + * apply IHl in H; auto with perm_db. destruct H as [l0 [? ?]]. exists l0. subst. split; auto. rewrite (cons_to_app _ l). @@ -175,26 +175,21 @@ Proof. unfold uc_equiv_perm_ex. unfold MapList.eval. simpl. - rewrite denote_SKIP. + rewrite denote_SKIP by lia. apply equal_on_basis_states_implies_equal; auto with wf_db. intro f. rewrite Mmult_assoc. - rewrite perm_to_matrix_permutes_qubits. + rewrite perm_to_matrix_permutes_qubits by auto with perm_db. repeat rewrite Mmult_assoc. rewrite f_to_vec_SWAP by assumption. Msimpl. - rewrite perm_to_matrix_permutes_qubits. + rewrite perm_to_matrix_permutes_qubits by auto with perm_db. apply f_to_vec_eq. intros x Hx. rewrite fswap_swap_log with (dim:=dim) by assumption. rewrite get_log_phys_inv with (n:=dim); auto. - apply get_phys_perm. - apply swap_log_preserves_bij; auto. - apply get_log_perm; auto. apply uc_well_typed_l_implies_dim_nonzero in WT. assumption. - apply H0. - apply swap_log_preserves_bij; auto. + dependent destruction m0. Qed. @@ -285,28 +280,16 @@ Proof. unfold uc_equiv_perm_ex, MapList.eval in *. unfold MapList.uc_equiv_l, uc_equiv in *. rewrite rs1, rs2, eq. - rewrite <- 2 perm_to_matrix_Mmult. + rewrite <- 2 perm_to_matrix_Mmult by auto with perm_db. repeat rewrite Mmult_assoc. apply f_equal2; try reflexivity. repeat rewrite <- Mmult_assoc. apply f_equal2; try reflexivity. - rewrite perm_to_matrix_Mmult. + rewrite perm_to_matrix_Mmult by auto with perm_db. repeat rewrite Mmult_assoc. - rewrite perm_to_matrix_Mmult. - rewrite 2 perm_to_matrix_I. + rewrite perm_to_matrix_Mmult by auto with perm_db. + rewrite 2 perm_to_matrix_I by eauto with perm_inv_db. Msimpl. reflexivity. - apply permutation_compose. - apply get_log_perm; auto. - apply get_phys_perm; auto. - intros x Hx. - apply get_log_phys_inv with (n:=dim); auto. - apply permutation_compose. - apply get_log_perm; auto. - apply get_phys_perm; auto. - intros x Hx. - apply get_log_phys_inv with (n:=dim); auto. - all: try apply get_log_perm; auto. - all: try apply get_phys_perm; auto. inversion H. Qed. diff --git a/VOQC/SwapRoute.v b/VOQC/SwapRoute.v index 0a8e678..df80603 100644 --- a/VOQC/SwapRoute.v +++ b/VOQC/SwapRoute.v @@ -141,11 +141,7 @@ Lemma perm_pair_get_log_phys: forall dim m, perm_pair dim (get_log m) (get_phys m). Proof. intros dim m. - repeat split. - apply get_log_perm. auto. - apply get_phys_perm. auto. - intros. apply get_log_phys_inv with (n:=dim); auto. - intros. apply get_phys_log_inv with (n:=dim); auto. + repeat split; eauto with perm_db perm_inv_db. Qed. Lemma perm_pair_get_phys_log: forall dim m, @@ -153,11 +149,7 @@ Lemma perm_pair_get_phys_log: forall dim m, perm_pair dim (get_phys m) (get_log m). Proof. intros dim m. - repeat split. - apply get_phys_perm. auto. - apply get_log_perm. auto. - intros. apply get_phys_log_inv with (n:=dim); auto. - intros. apply get_log_phys_inv with (n:=dim); auto. + repeat split; eauto with perm_db perm_inv_db. Qed. Lemma permute_commutes_with_map_qubits : forall dim (u : base_ucom dim) p1 p2, @@ -178,7 +170,7 @@ Proof. - apply equal_on_basis_states_implies_equal; auto with wf_db. intro f. repeat rewrite Mmult_assoc. - rewrite perm_to_matrix_permutes_qubits by assumption. + rewrite perm_to_matrix_permutes_qubits by auto with perm_bounded_db. assert (p2 n < dim). { destruct H1 as [? H1]. specialize (H1 n H5). @@ -204,7 +196,7 @@ Proof. with (f_to_vec (dim - (n + 1)) (shift f0 (n + 1))). replace (dim - (n + 1)) with (dim - 1 - n) by lia. rewrite <- f_to_vec_split by auto. - rewrite perm_to_matrix_permutes_qubits by assumption. + rewrite perm_to_matrix_permutes_qubits by auto with perm_bounded_db. remember (update (fun x : nat => f (p1 x)) (p2 n) false) as f0'. replace (f_to_vec dim (fun x : nat => f0 (p1 x))) with (f_to_vec dim f0'). rewrite (f_to_vec_split 0 dim (p2 n)) by auto. @@ -231,7 +223,7 @@ Proof. with (f_to_vec (dim - (n + 1)) (shift f1 (n + 1))). replace (dim - (n + 1)) with (dim - 1 - n) by lia. rewrite <- f_to_vec_split by auto. - rewrite perm_to_matrix_permutes_qubits by assumption. + rewrite perm_to_matrix_permutes_qubits by auto with perm_bounded_db. remember (update (fun x : nat => f (p1 x)) (p2 n) true) as f1'. replace (f_to_vec dim (fun x : nat => f1 (p1 x))) with (f_to_vec dim f1'). rewrite (f_to_vec_split 0 dim (p2 n)) by auto. @@ -254,10 +246,10 @@ Proof. - apply equal_on_basis_states_implies_equal; auto with wf_db. intro f. repeat rewrite Mmult_assoc. - rewrite perm_to_matrix_permutes_qubits by assumption. + rewrite perm_to_matrix_permutes_qubits by auto with perm_bounded_db. replace (ueval_cnot dim n n0) with (uc_eval (@SQIR.CNOT dim n n0)) by reflexivity. rewrite f_to_vec_CNOT by assumption. - rewrite perm_to_matrix_permutes_qubits by assumption. + rewrite perm_to_matrix_permutes_qubits by auto with perm_bounded_db. replace (ueval_cnot dim (p2 n) (p2 n0)) with (uc_eval (@SQIR.CNOT dim (p2 n) (p2 n0))) by reflexivity. assert (p2 n < dim). @@ -312,7 +304,6 @@ Proof. rewrite denote_SKIP by assumption. Msimpl. rewrite perm_to_matrix_Mmult, perm_to_matrix_I; auto. - apply permutation_compose; auto. Qed. Lemma SWAP_well_typed : forall dim a b, @@ -435,8 +426,8 @@ Proof. destruct (path_to_swaps (n :: n0 :: p) (swap_log m a n)) eqn:res'. inversion res; subst. assert (WFm':=res'). - eapply path_to_swaps_bijective in WFm'; auto. - eapply IHp in res'; auto. + eapply path_to_swaps_bijective in WFm'; auto using swap_log_preserves_bij. + eapply IHp in res'; auto using swap_log_preserves_bij. unfold uc_equiv_perm_ex in *. rewrite cons_to_app. rewrite eval_append, res'. @@ -444,21 +435,15 @@ Proof. apply f_equal2; try reflexivity. apply f_equal2; try reflexivity. apply equal_on_basis_states_implies_equal; auto with wf_db. - unfold eval; auto with wf_db. intro f. rewrite Mmult_assoc. rewrite SWAP_semantics by assumption. rewrite f_to_vec_SWAP by assumption. - rewrite perm_to_matrix_permutes_qubits. - rewrite perm_to_matrix_permutes_qubits. + rewrite 2!perm_to_matrix_permutes_qubits by + eauto using swap_log_preserves_bij with perm_db perm_bounded_db. apply f_to_vec_eq. intros x Hx. apply fswap_swap_log with (dim:=dim); auto. - apply get_phys_perm; assumption. - apply get_phys_perm. - apply swap_log_preserves_bij; assumption. - apply swap_log_preserves_bij; assumption. - apply swap_log_preserves_bij; assumption. Qed. (* These uc_eq_perm_* lemmas are specific to swap_route_sound -- they help @@ -520,10 +505,7 @@ Proof. rewrite (perm_to_matrix_I _ (p1inv ∘ p1)%prg). Msimpl. reflexivity. - unfold eval; auto with wf_db. - apply permutation_compose; auto. - intros x Hx. - apply H3. auto. + auto. Qed. Lemma uc_equiv_perm_ex_app2 : forall {dim} (l1 l2 : ucom_l dim) (g : gate_app (Map_Unitary (G.U 1)) dim) p1 p2 p3, diff --git a/VOQC/UnitaryListRepresentation.v b/VOQC/UnitaryListRepresentation.v index 7b7796e..99ee8e0 100644 --- a/VOQC/UnitaryListRepresentation.v +++ b/VOQC/UnitaryListRepresentation.v @@ -1,7 +1,7 @@ Require Export Coq.Classes.Equivalence. Require Export Coq.Classes.Morphisms. Require Export Setoid. -Require Import QuantumLib.Permutations. +Require Import QuantumLib.Permutations QuantumLib.PermutationAutomation. Require Export GateSet. Require Export SQIR.Equivalences. @@ -1107,6 +1107,14 @@ Qed. Definition eval {dim} (l : gate_list G.U dim) := uc_eval (list_to_ucom l). +Lemma WF_eval {dim} l : WF_Matrix (@eval dim l). +Proof. + unfold eval. + auto_wf. +Qed. + +#[export] Hint Resolve WF_eval : wf_db. + Lemma eval_append : forall {dim} (l1 l2 : gate_list G.U dim), eval (l1 ++ l2) = eval l2 × eval l1. Proof. @@ -1302,7 +1310,6 @@ Proof. apply permutation_id. rewrite perm_to_matrix_I; auto. unfold eval. Msimpl. reflexivity. - apply permutation_id. Qed. Lemma uc_equiv_perm_sym : forall {dim} (l1 l2 : gate_list G.U dim), l1 ≡x l2 -> l2 ≡x l1. @@ -1310,40 +1317,19 @@ Proof. intros dim l1 l2 H. destruct H as [p1 [p2 [Hp1 [Hp2 H]]]]. unfold uc_equiv_perm in *. - destruct Hp1 as [p1inv Hp1]. - destruct Hp2 as [p2inv Hp2]. - assert (permutation dim p1inv). - { exists p1. - intros x Hx. - destruct (Hp1 x Hx) as [? [? [? ?]]]. - repeat split; auto. } - assert (permutation dim p2inv). - { exists p2. - intros x Hx. - destruct (Hp2 x Hx) as [? [? [? ?]]]. - repeat split; auto. } - exists p1inv. - exists p2inv. - repeat split; auto. + exists (perm_inv dim p1). + exists (perm_inv dim p2). + repeat split; auto with perm_db. rewrite H. - repeat rewrite Mmult_assoc. - rewrite perm_to_matrix_Mmult; auto. - repeat rewrite <- Mmult_assoc. - rewrite perm_to_matrix_Mmult; auto. - rewrite 2 perm_to_matrix_I. - unfold eval. Msimpl. reflexivity. - apply permutation_compose; auto. - exists p1inv. assumption. - intros x Hx. - destruct (Hp1 x Hx) as [_ [_ [? _]]]. - assumption. - apply permutation_compose; auto. - exists p2inv. assumption. - intros x Hx. - destruct (Hp2 x Hx) as [_ [_ [_ ?]]]. - assumption. - exists p2inv. assumption. - exists p1inv. assumption. + rewrite <- !perm_to_matrix_transpose_eq by easy. + rewrite !Mmult_assoc. + restore_dims. + rewrite 2!perm_to_matrix_transpose_eq, + <- perm_to_matrix_compose by auto with perm_bounded_db. + rewrite <- !Mmult_assoc. + rewrite <- perm_to_matrix_compose by auto with perm_bounded_db. + rewrite 2!perm_to_matrix_I by (intros; cleanup_perm_inv). + now Msimpl. Qed. Lemma uc_equiv_perm_trans : forall {dim} (l1 l2 l3 : gate_list G.U dim), @@ -1359,7 +1345,7 @@ Proof. repeat split. apply permutation_compose; auto. apply permutation_compose; auto. - rewrite <- 2 perm_to_matrix_Mmult; auto. + rewrite <- 2 perm_to_matrix_Mmult; auto with perm_bounded_db. repeat rewrite Mmult_assoc. reflexivity. Qed. diff --git a/coq-sqir.opam b/coq-sqir.opam index 3efcd97..f96071b 100644 --- a/coq-sqir.opam +++ b/coq-sqir.opam @@ -1,6 +1,6 @@ # This file is generated by dune, edit dune-project instead opam-version: "2.0" -version: "1.3.1" +version: "1.3.2" synopsis: "Coq library for reasoning about quantum circuits" description: """ inQWIRE's SQIR is a Coq library for reasoning @@ -14,7 +14,7 @@ bug-reports: "https://github.com/inQWIRE/SQIR/issues" depends: [ "dune" {>= "3.8"} "coq-interval" {>= "4.9.0"} - "coq-quantumlib" {>= "1.5.0"} + "coq-quantumlib" {>= "1.6.0"} "coq" {>= "8.16"} "odoc" {with-doc} ] diff --git a/coq-voqc.opam b/coq-voqc.opam index b53d2f0..93e4749 100644 --- a/coq-voqc.opam +++ b/coq-voqc.opam @@ -1,6 +1,6 @@ # This file is generated by dune, edit dune-project instead opam-version: "2.0" -version: "1.3.1" +version: "1.3.2" synopsis: "A verified optimizer for quantum compilation" description: """ inQWIRE's VOQC is a Coq library for verified @@ -12,11 +12,11 @@ license: "MIT" homepage: "https://github.com/inQWIRE/SQIR" bug-reports: "https://github.com/inQWIRE/SQIR/issues" depends: [ - "dune" {>= "2.8"} - "coq-interval" {>= "4.6.1"} - "coq-quantumlib" {>= "1.1.0"} - "coq-sqir" {>= "1.3.0"} - "coq" {>= "8.12"} + "dune" {>= "3.8"} + "coq-interval" {>= "4.9.0"} + "coq-quantumlib" {>= "1.6.0"} + "coq-sqir" {>= "1.3.2"} + "coq" {>= "8.16"} "odoc" {with-doc} ] build: [