diff --git a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec index 18d32ece..8459d5f9 100644 --- a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec +++ b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec @@ -1960,24 +1960,51 @@ op lane_func_compress10(x : W16.t) : W10.t = truncate64_10 ( op lane_polyvec_redcomp10(w : W16.t) : W10.t = lane_func_compress10 (lane_func_reduce w). -op lane(w : W16.t) : W16.t = w. -op pcond (w: W16.t) = true. -op pcond2 (w: W16.t) = w \ule W16.of_int (2*3329). - +op pcond_all (w: W16.t) = true. +op pcond_reduced (w: W16.t) = w \ule W16.of_int (2*3329). + +lemma reduce_commutes x xr : xr = lane_func_reduce x => pcond_reduced xr. +rewrite /lane_func_reduce /pcond_reduced. print Fq.SignedReductions. +have := Fq.SignedReductions.BREDCp_corr (to_sint x) 26 _ _ _ _ _ _; rewrite ?qE /R //=. ++ admit. smt(). + rewrite /BREDC. + admit. +qed. -lemma ref_reduce (_bp : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.__poly_reduce : true ==> true]. +import BitEncoding.BitChunking. +lemma ref_polyvec_reduce (_r : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.__polyvec_reduce : r = _r ==> + map lane_func_reduce + (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _r))]))) = + map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list res))]))]. proc. inline *. -proc change ^while.5 : (sra_32 t0 (W32.of_int 26)). admit. -proc change ^while.9 : (W16_sub r (truncateu16 t0)). admit. -unroll for ^while. -wp 2816. -bdep 16 16 [_bp] [rp] [rp] lane_func_reduce pcond. -admit. admit. +proc change 1 : (init_256_16 (fun i => r.[i])). admit. +proc change ^while{1}.5 : (sra_32 t2 (W32.of_int 26)). admit. +proc change ^while{1}.9 : (W16_sub r0 (truncateu16 t2)). admit. +proc change 5 : (init_768_16 (fun i => if 0 <= i < 256 then aux.[i] else r.[i])). admit. +proc change 6 : (init_256_16 (fun i => r.[256+i])). admit. +proc change ^while{2}.5 : (sra_32 t3 (W32.of_int 26)). admit. +proc change ^while{2}.9 : (W16_sub r1 (truncateu16 t3)). admit. +proc change 10 : (init_768_16 (fun i => if 256 <= i < 512 then aux.[i-256] else r.[i])). admit. +proc change 11 : (init_256_16 (fun i => r.[512+i])). admit. +proc change ^while{3}.5 : (sra_32 t4 (W32.of_int 26)). admit. +proc change ^while{3}.9 : (W16_sub r2 (truncateu16 t4)). admit. +proc change 15 : (init_768_16 (fun i => if 512 <= i < 768 then aux.[i-512] else r.[i])). admit. +do 3!(unroll for ^while). +cfold 2818. +cfold 5637. +cfold 8456. +wp 8457. +bdep 16 16 [_r] [r] [r] lane_func_reduce pcond_all. ++ by move => *;rewrite /pcond_all -/predT;smt(all_predT). +by smt(). qed. -lemma ref_compress (_bp : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.__i_polyvec_compress_ref : true ==> true]. +lemma ref_polyvec_compress (_a : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.__i_polyvec_compress_ref : + a = _a /\ all pcond_reduced (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _a))]))) ==> +map lane_func_compress10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _a))]))) = + map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))]. proc. inline *. proc change 1 : (init_960_8 (fun i => W8.zero)). admit. @@ -2017,12 +2044,33 @@ cfold 3592. cfold 5387. cfold 5390. wp 14413. -bdep 16 10 [_bp] [a] [rp] lane_func_compress10 pcond2. -admit. admit. +bdep 16 10 [_a] [a] [rp] lane_func_compress10 pcond_reduced. +by smt(). by smt(). qed. -op lane0(w : W16.t) = W16.zero. -lemma avx_correctness (_bp : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.avx2 : true ==> true]. +lemma ref_correctness (_bp : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.ref : +_bp = bp ==> +map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) = + map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))])) +]. +proof. +proc. +ecall (ref_polyvec_compress bp). +ecall (ref_polyvec_reduce bp). +auto => />. +move => inter Hreduce;split. ++ rewrite -Hreduce allP => x. + rewrite mapP => exm;elim exm => xrep. + rewrite mapP => exmm;elim exmm => xxrep. + by smt(reduce_commutes). +move => preCompress fin. +rewrite -Hreduce => <-. +by rewrite map_comp /=. +qed. + +lemma avx_correctness (_bp : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.avx2 : _bp = bp ==> +map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) = + map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))]. proof. proc. inline *. @@ -2060,14 +2108,33 @@ cfold 183. cfold 365. cfold 547. wp 1807. -bdep 16 10 [_bp] [bp] [rp] lane_polyvec_redcomp10 pcond. -admit. admit. +bdep 16 10 [_bp] [bp] [rp] lane_polyvec_redcomp10 pcond_all. ++ by move => *;rewrite /pcond_all -/predT;smt(all_predT). ++ by smt(). qed. +lemma ref_correctness_p (_bp : W16.t Array768.t) : phoare [ AuxPolyVecCompress10.ref : +_bp = bp ==> +map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) = + map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))])) +] = 1%r. +admitted. + +lemma avx_correctness_p (_bp : W16.t Array768.t) : phoare [ AuxPolyVecCompress10.avx2 : _bp = bp ==> +map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) = + map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))] = 1%r. +admitted. + (* MAP REDUCE GOAL *) lemma compress10_mr : equiv [AuxPolyVecCompress10.avx2 ~ AuxPolyVecCompress10.ref : lift_array768 bp{1} = lift_array768 bp{2}==> ={res}]. -admitted. +proc*. +exlim bp{1}, bp{2} => _bp1 _bp2. +call{1} (avx_correctness_p _bp1). +call{2} (ref_correctness_p _bp2). +auto => /> Hpre r1 Hr1 r2 Hr2. +admit. +qed. (*****************************************************************)