diff --git a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec index cb1241f8..b2e3c7da 100644 --- a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec +++ b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec @@ -1819,7 +1819,7 @@ op sliceget256_16_256 (arr: W16.t Array256.t) (offset: int) : W256.t = W256.bit lemma flatten_take_drop_16 (l : W16.t list) (csize offset bit : int) : 0 <= offset => - offset + csize < 16 * size l => + offset + csize <= 16 * size l => 0 <= bit < csize => nth false (take csize (drop offset (flatten (map W16.w2bits l)))) bit = (nth witness l ((offset + bit) %/ 16)).[(offset + bit) %% 16]. @@ -1833,17 +1833,21 @@ rewrite -get_w2bits;congr. by rewrite (nth_map witness) 1:/#. qed. +lemma size_flatten_W16_w2bits (a : W16.t list) : + (size (flatten (map W16.w2bits (a)))) = 16 * size a. +proof. + rewrite size_flatten -map_comp /(\o) /=. + rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + rewrite StdBigop.Bigint.big_constz count_predT /#. +qed. + lemma aligned_get256_16_256 arr offset : -0 <= offset < 16*256 - 256 => +0 <= offset <= 16*256 - 256 => 256 %| offset => sliceget256_16_256 arr offset = WArray512.get256 (WArray512.init16 (fun (i_0 : int) => arr.[i_0])) (offset %/ 256). move => Ho1 Ho2; rewrite /sliceget256_16_256. -have sz : size (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))) = 256. -+ rewrite size_take 1:/# size_drop 1:/# /max /=. - rewrite size_flatten -map_comp /(\o) /=. - rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. - rewrite StdBigop.Bigint.big_constz count_predT;smt(Array256.size_to_list). +have sz : size (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))) = 256 by rewrite size_take 1:/# size_drop 1:/# /max /=;smt(Array256.size_to_list size_flatten_W16_w2bits). rewrite wordP => i ib; rewrite get_bits2w //. rewrite flatten_take_drop_16;1..3:smt(Array256.size_to_list). rewrite nth_mkseq 1:/# /=. @@ -1857,9 +1861,7 @@ bind op [W16.t & W256.t & Array256.t] sliceget256_16_256 "asliceget". realize bvaslicegetP. move => *; rewrite /sliceget256_16_256 bits2wK // size_take //= size_drop //=. admit. (* bounds are incomplete! 0 <= offset <= 16 * 256 - 256 *) -rewrite size_flatten -map_comp /(\o) /=. -rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. -by rewrite StdBigop.Bigint.big_constz count_predT;smt(Array256.size_to_list). +by smt(Array256.size_to_list size_flatten_W16_w2bits). qed. import BitEncoding BS2Int BitChunking. @@ -1869,29 +1871,83 @@ op sliceset256_16_256 (arr: W16.t Array256.t) (offset: int) (bv: W256.t) : W16.t lemma aligned_set256_16_256 arr offset bv : -0 <= offset < 16*256 - 256 => +0 <= offset <= 16*256 - 256 => 256 %| offset => sliceset256_16_256 arr offset bv = Array256.init (fun (i3 : int) => get16 (set256 ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 256) bv) i3). rewrite /sliceset256_16_256 tP /= => ?? i ib. rewrite !initiE 1,2:/# /=. -rewrite get16_set256E 1,2:/# /= (nth_map witness). -+ admit. -admitted. +rewrite get16_set256E 1,2:/# /= (nth_map []). ++ rewrite size_chunk // !size_cat !size_take 1:/# !size_drop 1:/# /max /=. + by smt(Array256.size_to_list size_flatten_W16_w2bits). +rewrite JWordList.nth_chunk //= 1:/#. +rewrite !size_cat !size_take 1:/# !size_drop 1:/# /max /=. + by smt(Array256.size_to_list size_flatten_W16_w2bits). +case (32 * (offset %/ 256) <= 2 * i);last first. ++ move => ? /=. have ? : 16*i < offset. smt(). + rewrite get16_init16 1:/# -catA drop_cat ifT;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite take_cat_le ifT;1: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + have -> : offset = 16 * (offset %/ 16) by smt(). + rewrite take_flatten_ctt; 1: by smt(mapP W16.size_w2bits). + rewrite -map_take. + rewrite -(W16.w2bitsK arr.[i]);congr. + apply (eq_from_nth false). + + rewrite size_w2bits size_take // size_drop 1:/# /= /max /=;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + move => k kb; rewrite flatten_take_drop_16 1:/#. + + rewrite size_take 1:/# size_to_list //= 1:/#. + by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite nth_take 1:/#. smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite get_w2bits;congr; rewrite ?get_to_list;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). +case (2 * i < 32 * (offset %/ 256 + 1));last first. ++ move => ? /=. have ? : offset + 256 <= 16*i . smt(). + rewrite get16_init16 1:/# -catA drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite size_take 1:/# size_flatten_W16_w2bits size_to_list /= ifT 1:/#. + have -> : offset + 256 = 16 * ((offset + 256) %/ 16) by smt(). + rewrite drop_flatten_ctt; 1: by smt(mapP W16.size_w2bits). + have -> : 16 * i - offset - 256 = 16 * (i - offset %/ 16 - 16) by smt(). + rewrite drop_flatten_ctt; 1: by smt(mapP W16.size_w2bits mem_drop). + rewrite drop_drop 1,2:/# /= => ?. + rewrite -(W16.w2bitsK arr.[i]);congr. + apply (eq_from_nth false). + + rewrite -map_drop size_take // size_flatten_W16_w2bits size_drop 1:/#; smt(Array256.size_to_list W16.size_w2bits). + move => k kb. + have -> : i - offset %/ 16 - 16 + (offset + 256) %/ 16 = i by smt(). + rewrite -(drop_flatten_ctt 16); 1: smt(mapP W16.size_w2bits). + rewrite flatten_take_drop_16; 1..3: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite get_w2bits;congr; rewrite ?get_to_list;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + ++ move => ?? /=. have ? : offset <= 16*i < offset + 256. smt(). + rewrite -!catA drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite !drop_cat ifT;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite size_take 1:/# size_flatten_W16_w2bits size_to_list /= ifT 1:/#. + rewrite take_cat_le ifT;1: by rewrite size_drop 1:/# size_w2bits /= /max ifT /#. + rewrite -(W16.w2bitsK ((bv \bits16 i - 16 * (offset %/ 256))));congr. + apply (eq_from_nth false). + + rewrite size_take // size_drop 1:/#; smt(Array256.size_to_list W16.size_w2bits). + move => k kb. + rewrite nth_take; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite nth_drop; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite !get_w2bits get_bits16;by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). +qed. + + bind op [W16.t & W256.t & Array256.t] sliceset256_16_256 "asliceset". -realize bvaslicesetP. -move => arr offset bv *. +realize bvaslicesetP. (* bounds are incomplete! 0 <= offset <= 16 * 256 - 256 *) +move => arr offset bv *. have ? : 0 <= offset by admit. rewrite /sliceset256_16_256 of_listK. -+ admit. ++ rewrite size_map size_chunk // !size_cat size_take 1:/#. + by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). rewrite -(map_comp W16.w2bits W16.bits2w) /(\o). have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W16)) idfun (chunk 16 (take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++ drop (offset + 256) (flatten (map W16.w2bits (to_list arr))))). -rewrite iffE => [#] -> *. -+ admit. +rewrite iffE => [#] -> * /=. ++ by smt(in_chunk_size W16.bits2wK). rewrite map_id /= chunkK //. -+ admit. ++ rewrite !size_cat size_take 1:/#. + by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). qed. op sliceget32_8_256 (arr: W8.t Array32.t) (i: int) : W256.t = get256 (WArray32.init8 (fun (i_0 : int) => pvc_shufbidx_s.[i_0])) (i%/256).