Skip to content

Commit

Permalink
Cleaning up
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Oct 31, 2024
1 parent 060b8dd commit 2f6db1f
Showing 1 changed file with 86 additions and 19 deletions.
105 changes: 86 additions & 19 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1960,24 +1960,51 @@ op lane_func_compress10(x : W16.t) : W10.t = truncate64_10 (

op lane_polyvec_redcomp10(w : W16.t) : W10.t = lane_func_compress10 (lane_func_reduce w).

op lane(w : W16.t) : W16.t = w.
op pcond (w: W16.t) = true.
op pcond2 (w: W16.t) = w \ule W16.of_int (2*3329).

op pcond_all (w: W16.t) = true.
op pcond_reduced (w: W16.t) = w \ule W16.of_int (2*3329).

lemma reduce_commutes x xr : xr = lane_func_reduce x => pcond_reduced xr.
rewrite /lane_func_reduce /pcond_reduced. print Fq.SignedReductions.
have := Fq.SignedReductions.BREDCp_corr (to_sint x) 26 _ _ _ _ _ _; rewrite ?qE /R //=.
+ admit. smt().
rewrite /BREDC.
admit.
qed.

lemma ref_reduce (_bp : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.__poly_reduce : true ==> true].
import BitEncoding.BitChunking.
lemma ref_polyvec_reduce (_r : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.__polyvec_reduce : r = _r ==>
map lane_func_reduce
(map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _r))]))) =
map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list res))]))].
proc.
inline *.
proc change ^while.5 : (sra_32 t0 (W32.of_int 26)). admit.
proc change ^while.9 : (W16_sub r (truncateu16 t0)). admit.
unroll for ^while.
wp 2816.
bdep 16 16 [_bp] [rp] [rp] lane_func_reduce pcond.
admit. admit.
proc change 1 : (init_256_16 (fun i => r.[i])). admit.
proc change ^while{1}.5 : (sra_32 t2 (W32.of_int 26)). admit.
proc change ^while{1}.9 : (W16_sub r0 (truncateu16 t2)). admit.
proc change 5 : (init_768_16 (fun i => if 0 <= i < 256 then aux.[i] else r.[i])). admit.
proc change 6 : (init_256_16 (fun i => r.[256+i])). admit.
proc change ^while{2}.5 : (sra_32 t3 (W32.of_int 26)). admit.
proc change ^while{2}.9 : (W16_sub r1 (truncateu16 t3)). admit.
proc change 10 : (init_768_16 (fun i => if 256 <= i < 512 then aux.[i-256] else r.[i])). admit.
proc change 11 : (init_256_16 (fun i => r.[512+i])). admit.
proc change ^while{3}.5 : (sra_32 t4 (W32.of_int 26)). admit.
proc change ^while{3}.9 : (W16_sub r2 (truncateu16 t4)). admit.
proc change 15 : (init_768_16 (fun i => if 512 <= i < 768 then aux.[i-512] else r.[i])). admit.
do 3!(unroll for ^while).
cfold 2818.
cfold 5637.
cfold 8456.
wp 8457.
bdep 16 16 [_r] [r] [r] lane_func_reduce pcond_all.
+ by move => *;rewrite /pcond_all -/predT;smt(all_predT).
by smt().
qed.


lemma ref_compress (_bp : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.__i_polyvec_compress_ref : true ==> true].
lemma ref_polyvec_compress (_a : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.__i_polyvec_compress_ref :
a = _a /\ all pcond_reduced (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _a))]))) ==>
map lane_func_compress10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _a))]))) =
map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))].
proc.
inline *.
proc change 1 : (init_960_8 (fun i => W8.zero)). admit.
Expand Down Expand Up @@ -2017,12 +2044,33 @@ cfold 3592.
cfold 5387.
cfold 5390.
wp 14413.
bdep 16 10 [_bp] [a] [rp] lane_func_compress10 pcond2.
admit. admit.
bdep 16 10 [_a] [a] [rp] lane_func_compress10 pcond_reduced.
by smt(). by smt().
qed.

op lane0(w : W16.t) = W16.zero.
lemma avx_correctness (_bp : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.avx2 : true ==> true].
lemma ref_correctness (_bp : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.ref :
_bp = bp ==>
map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) =
map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))
].
proof.
proc.
ecall (ref_polyvec_compress bp).
ecall (ref_polyvec_reduce bp).
auto => />.
move => inter Hreduce;split.
+ rewrite -Hreduce allP => x.
rewrite mapP => exm;elim exm => xrep.
rewrite mapP => exmm;elim exmm => xxrep.
by smt(reduce_commutes).
move => preCompress fin.
rewrite -Hreduce => <-.
by rewrite map_comp /=.
qed.

lemma avx_correctness (_bp : W16.t Array768.t) : hoare [ AuxPolyVecCompress10.avx2 : _bp = bp ==>
map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) =
map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))].
proof.
proc.
inline *.
Expand Down Expand Up @@ -2060,14 +2108,33 @@ cfold 183.
cfold 365.
cfold 547.
wp 1807.
bdep 16 10 [_bp] [bp] [rp] lane_polyvec_redcomp10 pcond.
admit. admit.
bdep 16 10 [_bp] [bp] [rp] lane_polyvec_redcomp10 pcond_all.
+ by move => *;rewrite /pcond_all -/predT;smt(all_predT).
+ by smt().
qed.

lemma ref_correctness_p (_bp : W16.t Array768.t) : phoare [ AuxPolyVecCompress10.ref :
_bp = bp ==>
map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) =
map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))
] = 1%r.
admitted.

lemma avx_correctness_p (_bp : W16.t Array768.t) : phoare [ AuxPolyVecCompress10.avx2 : _bp = bp ==>
map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) =
map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))] = 1%r.
admitted.

(* MAP REDUCE GOAL *)
lemma compress10_mr :
equiv [AuxPolyVecCompress10.avx2 ~ AuxPolyVecCompress10.ref : lift_array768 bp{1} = lift_array768 bp{2}==> ={res}].
admitted.
proc*.
exlim bp{1}, bp{2} => _bp1 _bp2.
call{1} (avx_correctness_p _bp1).
call{2} (ref_correctness_p _bp2).
auto => /> Hpre r1 Hr1 r2 Hr2.
admit.
qed.

(*****************************************************************)

Expand Down

0 comments on commit 2f6db1f

Please sign in to comment.