diff --git a/compiler/src/compiler/DeadCodeElim.v b/compiler/src/compiler/DeadCodeElim.v index 10099f03..61608a25 100644 --- a/compiler/src/compiler/DeadCodeElim.v +++ b/compiler/src/compiler/DeadCodeElim.v @@ -15,7 +15,7 @@ Require Import bedrock2.MetricCosts. (* below only for of_list_list_diff *) Require Import compiler.DeadCodeElimDef. -Local Notation exec pick_sp := (exec (pick_sp := pick_sp) PreSpill isRegStr). +Local Notation exec e pick_sp := (@exec _ _ _ _ _ _ _ _ PreSpill isRegStr pick_sp e). Section WithArguments1. Context {width: Z}. diff --git a/compiler/src/compiler/FlatImp.v b/compiler/src/compiler/FlatImp.v index eebc5990..d474edf6 100644 --- a/compiler/src/compiler/FlatImp.v +++ b/compiler/src/compiler/FlatImp.v @@ -21,6 +21,7 @@ Require Import coqutil.Datatypes.ListSet. Require Import coqutil.Map.OfListWord. Require Import coqutil.Word.Bitwidth. Require Import coqutil.Word.Interface. +Require Import coqutil.Tactics.fwd. Local Hint Mode Word.Interface.word - : typeclass_instances. Inductive bbinop: Set := @@ -300,7 +301,7 @@ Module exec. Context {width: Z} {BW: Bitwidth width} {word: word.word width}. Context {mem: map.map word byte} {locals: map.map varname word} {env: map.map String.string (list varname * list varname * stmt varname)}. - Context {ext_spec: ExtSpec} {pick_sp: PickSp}. + Context {ext_spec: ExtSpec} . Context {varname_eq_spec: EqDecider varname_eqb} {word_ok: word.ok word} {mem_ok: map.ok mem} @@ -349,7 +350,7 @@ Module exec. end mc. (* alternative semantics which allow non-determinism *) - Inductive exec: + Inductive exec {pick_sp: PickSp} : stmt varname -> leakage -> trace -> mem -> locals -> metrics -> (leakage -> trace -> mem -> locals -> metrics -> Prop) @@ -457,7 +458,7 @@ Module exec. post k t m l mc -> exec SSkip k t m l mc post. - Lemma det_step: forall k0 t0 m0 l0 mc0 s1 s2 k1 t1 m1 l1 mc1 post, + Lemma det_step {pick_sp: PickSp} : forall k0 t0 m0 l0 mc0 s1 s2 k1 t1 m1 l1 mc1 post, exec s1 k0 t0 m0 l0 mc0 (fun k1' t1' m1' l1' mc1' => k1' = k1 /\ t1' = t1 /\ m1' = m1 /\ l1' = l1 /\ mc1 = mc1') -> exec s2 k1 t1 m1 l1 mc1 post -> exec (SSeq s1 s2) k0 t0 m0 l0 mc0 post. @@ -468,14 +469,14 @@ Module exec. assumption. Qed. - Lemma seq_cps: forall s1 s2 k t m (l: locals) mc post, + Lemma seq_cps {pick_sp: PickSp} : forall s1 s2 k t m (l: locals) mc post, exec s1 k t m l mc (fun k' t' m' l' mc' => exec s2 k' t' m' l' mc' post) -> exec (SSeq s1 s2) k t m l mc post. Proof. intros. eapply seq. 1: eassumption. simpl. clear. auto. Qed. - Lemma call_cps: forall fname params rets binds args fbody argvs k t (l: locals) m mc st post, + Lemma call_cps {pick_sp: PickSp} : forall fname params rets binds args fbody argvs k t (l: locals) m mc st post, map.get e fname = Some (params, rets, fbody) -> map.getmany_of_list l args = Some argvs -> map.putmany_of_list_zip params argvs map.empty = Some st -> @@ -491,7 +492,7 @@ Module exec. cbv beta. intros *. exact id. Qed. - Lemma loop_cps: forall body1 cond body2 k t m l mc post, + Lemma loop_cps {pick_sp: PickSp} : forall body1 cond body2 k t m l mc post, exec body1 k t m l mc (fun k t m l mc => exists b, eval_bcond l cond = Some b /\ (b = false -> post (leak_bool false :: k) t m l (cost_SLoop_false cond mc)) /\ @@ -507,7 +508,7 @@ Module exec. - assumption. Qed. - Lemma weaken: forall k t l m mc s post1, + Lemma weaken {pick_sp: PickSp} : forall k t l m mc s post1, exec s k t m l mc post1 -> forall post2, (forall k' t' m' l' mc', post1 k' t' m' l' mc' -> post2 k' t' m' l' mc') -> @@ -530,7 +531,7 @@ Module exec. intros. simp. eauto 10. Qed. - Lemma seq_assoc: forall s1 s2 s3 k t m l mc post, + Lemma seq_assoc {pick_sp: PickSp} : forall s1 s2 s3 k t m l mc post, exec (SSeq s1 (SSeq s2 s3)) k t m l mc post -> exec (SSeq (SSeq s1 s2) s3) k t m l mc post. Proof. @@ -543,7 +544,7 @@ Module exec. eauto. Qed. - Lemma seq_assoc_bw: forall s1 s2 s3 k t m l mc post, + Lemma seq_assoc_bw {pick_sp: PickSp} : forall s1 s2 s3 k t m l mc post, exec (SSeq (SSeq s1 s2) s3) k t m l mc post -> exec (SSeq s1 (SSeq s2 s3)) k t m l mc post. Proof. intros. simp. eauto 10 using seq. Qed. @@ -556,7 +557,7 @@ Module exec. end; simp. - Lemma intersect: forall k t l m mc s post1, + Lemma intersect {pick_sp: PickSp} : forall k t l m mc s post1, exec s k t m l mc post1 -> forall post2, exec s k t m l mc post2 -> @@ -634,6 +635,140 @@ Module exec. eauto. Qed. + Lemma exec_extends_trace {pick_sp: PickSp} s k t m l mc post : + exec s k t m l mc post -> + exec s k t m l mc (fun k' t' m' l' mc' => post k' t' m' l' mc' /\ exists k'', k' = k'' ++ k). + Proof. + intros H. induction H; try (econstructor; intuition eauto; eexists; align_trace; fail). + - econstructor; intuition eauto. specialize H2 with (1 := H3). fwd. + eexists. intuition eauto. eexists. align_trace. + - econstructor; intuition eauto. fwd. specialize H3 with (1 := H4p0). fwd. + eexists. intuition eauto. eexists. intuition eauto. + eexists. align_trace. + - econstructor; intuition eauto. intros. eapply weaken. 1: eapply H1; eauto. + simpl. intros. fwd. eexists. eexists. intuition eauto. eexists. align_trace. + - eapply if_true; intuition eauto. eapply weaken. 1: eapply IHexec. + simpl. intros. fwd. intuition eauto. eexists. align_trace. + - eapply if_false; intuition eauto. eapply weaken. 1: eapply IHexec. + simpl. intros. fwd. intuition eauto. eexists. align_trace. + - clear H2 H4. econstructor; intuition eauto; fwd; eauto. + { eexists. align_trace. } + { eapply weaken. 1: eapply H3; eauto. simpl. intros. fwd. + instantiate (1 := fun k'0 t'0 m'0 l'0 mc'0 => + mid2 k'0 t'0 m'0 l'0 mc'0 /\ exists k'', k'0 = k'' ++ k). + simpl. intuition. eexists. align_trace. } + simpl in *. fwd. eapply weaken. 1: eapply H5; eauto. + simpl. intros. fwd. intuition. eexists. align_trace. + - econstructor; intuition eauto. fwd. eapply weaken. 1: eapply H1; eauto. + simpl. intros. fwd. intuition eauto. eexists. align_trace. + Qed. + + Lemma exec_ext (pick_sp1: PickSp) s k t m l mc post : + exec (pick_sp := pick_sp1) s k t m l mc post -> + forall pick_sp2, + (forall k', pick_sp1 (k' ++ k) = pick_sp2 (k' ++ k)) -> + exec (pick_sp := pick_sp2) s k t m l mc post. + Proof. + Set Printing Implicit. + intros H1 pick_sp2. induction H1; intros; try solve [econstructor; eauto]. + - econstructor. 4: eapply exec_extends_trace. all: intuition eauto. + { eapply IHexec. intros. rewrite associate_one_left. + repeat rewrite app_assoc. auto. } + fwd. specialize H3 with (1 := H5p0). fwd. intuition eauto. + - econstructor; eauto. intros. replace (pick_sp1 k) with (pick_sp2 k) in *. + { subst a. eapply weaken. + { eapply H1; eauto. + intros. eassert (H2' := H2 (_ ++ _ :: nil)). rewrite <- app_assoc in H2'. eapply H2'. } + eauto. } + symmetry. apply H2 with (k' := nil). + - eapply if_true; eauto. eapply IHexec. + intros. rewrite associate_one_left. repeat rewrite app_assoc. auto. + - eapply if_false; intuition eauto. eapply IHexec. + intros. rewrite associate_one_left. repeat rewrite app_assoc. auto. + - clear H2 H4. eapply loop. 1: eapply exec_extends_trace. all: intuition eauto; fwd; eauto. + { eapply weaken. 1: eapply exec_extends_trace. + { eapply H3; eauto. + intros. rewrite associate_one_left. repeat rewrite app_assoc. auto. } + simpl. intros. fwd. + instantiate (1 := fun k'0 t'0 m'0 l'0 mc'0 => + mid2 k'0 t'0 m'0 l'0 mc'0 /\ exists k'', k'0 = k'' ++ k). + simpl. intuition eauto. eexists. align_trace. } + simpl in *. fwd. eapply H5; eauto. intros. + repeat (rewrite app_assoc || rewrite (app_one_l _ (_ ++ k))). auto. + - econstructor. 1: eapply exec_extends_trace; eauto. simpl. intros. fwd. + eapply H0; eauto. intros. repeat rewrite app_assoc. apply H2. + Qed. + + Local Ltac solve_picksps_equal := + intros; cbv beta; f_equal; + repeat (rewrite rev_app_distr || cbn [rev app]); rewrite List.skipn_app_r; + [|repeat (rewrite app_length || rewrite rev_length || simpl); blia]; + repeat rewrite <- app_assoc; rewrite List.skipn_app_r; + [|rewrite rev_length; reflexivity]; + repeat (rewrite rev_app_distr || cbn [rev app] || rewrite rev_involutive); + repeat rewrite <- app_assoc; reflexivity. + + Lemma exec_to_other_trace (pick_sp: PickSp) s k1 k2 t m l mc post : + exec s k1 t m l mc post -> + exec (pick_sp := fun k => pick_sp (rev (skipn (length k2) (rev k)) ++ k1)) + s k2 t m l mc (fun k2' t' m' l' mc' => + exists k'', + k2' = k'' ++ k2 /\ + post (k'' ++ k1) t' m' l' mc'). + Proof. + intros H. generalize dependent k2. induction H; intros. + - econstructor; intuition eauto. apply H2 in H3. fwd. + eexists. intuition eauto. eexists. intuition eauto. 1: align_trace. + auto. + - econstructor; intuition eauto. + { eapply exec_ext with (pick_sp1 := _). 1: eapply IHexec; eauto. solve_picksps_equal. } + cbv beta in *. fwd. apply H3 in H4p1. + fwd. eexists. intuition eauto. eexists. intuition eauto. eexists. + split; [align_trace|]. repeat rewrite <- app_assoc. auto. + - econstructor; intuition eauto. eexists. split; [align_trace|]. auto. + - econstructor; intuition eauto. eexists. split; [align_trace|]. auto. + - econstructor; intuition eauto. eexists. split; [align_trace|]. auto. + - econstructor; intuition eauto. intros. + replace (rev k2) with (rev k2 ++ nil) in * by apply app_nil_r. + rewrite List.skipn_app_r in * by (rewrite rev_length; reflexivity). + simpl in *. eapply weaken. + { eapply exec_ext with (pick_sp1 := _). 1: eapply H1; eauto. solve_picksps_equal. } + simpl. intros. fwd. eexists _, _. intuition eauto. eexists (_ ++ _ :: nil). + rewrite <- app_assoc. simpl. rewrite <- (app_assoc _ _ k). simpl. eauto. + - econstructor; intuition eauto. eexists. split; [align_trace|]. auto. + - econstructor; intuition eauto. + - econstructor; intuition eauto. eexists. split; [align_trace|]. auto. + - eapply if_true; intuition eauto. eapply weaken. + { eapply exec_ext with (pick_sp1 := _). 1: eapply IHexec. solve_picksps_equal. } + simpl. intros. fwd. eexists. split; [align_trace|]. + repeat rewrite <- app_assoc. auto. + - eapply if_false; intuition eauto. eapply weaken. + { eapply exec_ext with (pick_sp1 := _). 1: eapply IHexec. solve_picksps_equal. } + simpl. intros. fwd. eexists. split; [align_trace|]. + repeat rewrite <- app_assoc. auto. + - eapply loop. 1: eapply IHexec. all: intuition eauto; fwd; eauto. + { eexists. split; [align_trace|]. simpl. auto. } + { eapply exec_ext with (pick_sp1 := _). + { eapply weaken. 1: eapply H3; eauto. simpl. intros. + instantiate (1 := fun k'0 t'0 m'0 l'0 mc'0 => + exists k''0, + k'0 = k''0 ++ k2 /\ + mid2 (k''0 ++ k) t'0 m'0 l'0 mc'0). + fwd. eexists. split; [align_trace|]. + repeat rewrite <- app_assoc. simpl. auto. } + solve_picksps_equal. } + simpl in *. fwd. eapply exec_ext with (pick_sp1 := _). + { eapply weaken. 1: eapply H5; eauto. simpl. intros. fwd. + eexists. split; [align_trace|]. + repeat rewrite <- app_assoc. auto. } + solve_picksps_equal. + - econstructor; intuition. fwd. eapply weaken. + { eapply exec_ext with (pick_sp1 := _). 1: eapply H1; eauto. solve_picksps_equal. } + simpl. intros. fwd. eexists. split; [align_trace|]. + repeat rewrite <- app_assoc. auto. + - econstructor. eexists. split; [align_trace|]. assumption. + Qed. + End FlatImpExec. End exec. Notation exec := exec.exec. diff --git a/compiler/src/compiler/FlatToRiscvCommon.v b/compiler/src/compiler/FlatToRiscvCommon.v index 579964be..fede6d63 100644 --- a/compiler/src/compiler/FlatToRiscvCommon.v +++ b/compiler/src/compiler/FlatToRiscvCommon.v @@ -291,8 +291,8 @@ Section WithParameters. let '(argnames, retnames, fbody) := fun_impl in exists pos, map.get finfo f = Some pos /\ pos mod 4 = 0. - Local Notation stmt := (stmt Z). Check exec. - Local Notation exec pick_sp := (exec (pick_sp := pick_sp) PostSpill isRegZ). + Local Notation stmt := (stmt Z). Check @exec. + Local Notation exec pick_sp e := (@exec _ _ _ _ _ _ _ _ PostSpill isRegZ e pick_sp). (* note: [e_impl_reduced] and [funnames] will shrink one function at a time each time we enter a new function body, to make sure functions cannot call themselves, while diff --git a/compiler/src/compiler/FlatToRiscvFunctions.v b/compiler/src/compiler/FlatToRiscvFunctions.v index e596539b..f3dd3455 100644 --- a/compiler/src/compiler/FlatToRiscvFunctions.v +++ b/compiler/src/compiler/FlatToRiscvFunctions.v @@ -476,9 +476,8 @@ Section Proofs. destruct cond; [destruct op | ]; simpl in *; Simp.simp; repeat (simulate'; simpl_bools; simpl); rewrite option_map_option_map'; intuition. Qed. - - - Local Notation exec pick_sp := (exec (pick_sp := pick_sp) PostSpill isRegZ). + + Local Notation exec e pick_sp := (@exec _ _ _ _ _ _ _ _ PostSpill isRegZ pick_sp e). Definition cost_compile_spec mc := Platform.MetricLogging.addMetricInstructions 95 diff --git a/compiler/src/compiler/Pipeline.v b/compiler/src/compiler/Pipeline.v index 60865d3c..219244cc 100644 --- a/compiler/src/compiler/Pipeline.v +++ b/compiler/src/compiler/Pipeline.v @@ -279,7 +279,7 @@ Section WithWordAndMem. refine ({| Program := string_keyed_map (list string * list string * FlatImp.stmt string); Valid := map.forall_values ParamsNoDup; - Call := locals_based_call_spec (fun pick_sp => FlatImp.exec (pick_sp := pick_sp) PreSpill isRegStr); + Call := locals_based_call_spec (fun pick_sp e => @FlatImp.exec _ _ _ _ _ _ _ _ PreSpill isRegStr e pick_sp); |}). 1: exact tt. intros. cbv [locals_based_call_spec] in *. fwd. do 4 eexists. intuition eauto. @@ -304,7 +304,7 @@ Section WithWordAndMem. refine ({| Program := string_keyed_map (list Z * list Z * FlatImp.stmt Z); Valid := map.forall_values ParamsNoDup; - Call := locals_based_call_spec (fun pick_sp => FlatImp.exec (pick_sp := pick_sp) PreSpill isRegZ); + Call := locals_based_call_spec (fun pick_sp e => @FlatImp.exec _ _ _ _ _ _ _ _ PreSpill isRegZ e pick_sp); |}). 1: exact tt. intros. cbv [locals_based_call_spec] in *. fwd. do 4 eexists. intuition eauto. @@ -319,7 +319,7 @@ Section WithWordAndMem. refine ({| Program := string_keyed_map (list Z * list Z * FlatImp.stmt Z); Valid := map.forall_values FlatToRiscvDef.valid_FlatImp_fun; - Call := locals_based_call_spec_spilled (fun pick_sp => FlatImp.exec (pick_sp := pick_sp) PostSpill isRegZ); + Call := locals_based_call_spec_spilled (fun e pick_sp => @FlatImp.exec _ _ _ _ _ _ _ _ PostSpill isRegZ pick_sp e); |}). 1: exact tt. intros. cbv [locals_based_call_spec_spilled] in *. fwd. do 4 eexists. intuition eauto. @@ -461,7 +461,9 @@ Section WithWordAndMem. simpl in H0. assumption. } - unfold locals_based_call_spec. intros. fwd. + unfold locals_based_call_spec. intros. + exists (fun pick_spL kH kL k => pick_spL ((rev (skipn (length kH) (rev k)) ++ kL))). + exists (fun k => k). intros. fwd. pose proof H0 as GI. unfold useimmediate_functions in GI. eapply map.try_map_values_fw in GI. 2: eassumption. @@ -470,8 +472,16 @@ Section WithWordAndMem. intros. eapply exec.weaken. - eapply useImmediate_correct_aux; eauto. - - simpl. destruct 1 as (?&?&?&?&?). - repeat (eexists; split; try eassumption). + eapply FlatImp.exec.exec_ext. + 1: eapply FlatImp.exec.exec_to_other_trace. + 1: eassumption. + intros. simpl. simpl_rev. rewrite List.skipn_app_r. + 2: rewrite length_rev; reflexivity. + simpl_rev. rewrite List.skipn_app_r. + 2: rewrite length_rev; reflexivity. + rewrite rev_involutive. reflexivity. + - simpl. intros. fwd. eexists. intuition eauto. + do 3 eexists. intuition eauto. unfold cost_spill_spec in *; solve_MetricLog. Qed. @@ -500,7 +510,13 @@ Section WithWordAndMem. simpl in H0. assumption. } - unfold locals_based_call_spec. intros. fwd. + unfold locals_based_call_spec. intros. + exists (fun pick_spL kH kL k => let '(argnames, retnames, fbody) := + match (map.get p1 fname) with Some finfo => finfo | None => (nil, nil, SSkip) end in + fun kk => let k := rev (skipn (length kH) (rev kk)) in + pick_spL (rev kL ++ stmt_leakage eH (rev k, sH, used_ + + eexists. eexists. intros. fwd. pose proof H0 as GI. unfold dce_functions in GI. eapply map.try_map_values_fw in GI. 2: eassumption. @@ -508,8 +524,23 @@ Section WithWordAndMem. eexists _, _, _, _. split. 1: eassumption. split. 1: eassumption. intros. eapply @exec.weaken. - - eapply dce_correct_aux; eauto. - eapply MapEauto.agree_on_refl. + - eapply exec.exec_ext. 1: eapply dce_correct_aux; eauto. + { eapply MapEauto.agree_on_refl. } + 2: { intros. simpl. instantiate (1 := fun x => pick_spL (rev x)). + simpl. rewrite rev_involutive. reflexivity. } + intros. remember (k ++ kH) as kk eqn:Hkk. + replace k with (rev (skipn (length kH) (rev (k ++ kH)))). + { forget (k ++ kH) as kk0. subst. + set (finfo := + match (map.get p1 fname) with | Some finfo => finfo | None => (nil, nil, SSkip) end). + replace fbody with (snd finfo). 1: replace argnames with (fst (fst finfo)). + 1: replace retnames with (snd (fst finfo)). + { + instantiate (2 := fun _ _ _ _ => _). simpl. reflexivity. + 1: reflexivity. + 1: forget (k ++ kH) as kk. 1: instantiate (2 := fun _ _ _ _ => _). 1: reflexivity. + insta + intros. simpl. instantiate (2 := fun _ _ _ _ => _). simpl. reflexivity. - unfold compile_post. intros. fwd. exists retvals. split. diff --git a/compiler/src/compiler/Spilling.v b/compiler/src/compiler/Spilling.v index 834872e7..9558e0ba 100644 --- a/compiler/src/compiler/Spilling.v +++ b/compiler/src/compiler/Spilling.v @@ -24,8 +24,8 @@ Open Scope Z_scope. Section Spilling. Notation stmt := (stmt Z). - Notation execpre pick_sp := (exec (pick_sp := pick_sp) PreSpill isRegZ). - Notation execpost pick_sp := (exec (pick_sp := pick_sp) PostSpill isRegZ). + Notation execpre pick_sp e := (@exec _ _ _ _ _ _ _ _ PreSpill isRegZ e pick_sp). + Notation execpost pick_sp e := (@exec _ _ _ _ _ _ _ _ PostSpill isRegZ e pick_sp). Definition zero := 0. Definition ra := 1.