# Patch summary (explanatory preamble placed before the "diff --git" header).
#
# Hunk 1: adds a top-level `print Jkem.M.` command just before the
#   declaration of module AuxPolyVecCompress10.
#   NOTE(review): this looks like a leftover inspection command -- confirm
#   that it is meant to be committed.
#
# Hunk 2: adds two procedures next to `__polyvec_compress_ref` and `ref`:
#   - `__poly_reduce`: walks j from 0 to 255 and replaces each rp.[j] with
#     the result of `M(Syscall).__barrett_reduce` applied to it.
#     NOTE(review): `M` and `Syscall` are unqualified here, whereas the call
#     replaced in `ref` was `Jkem.M(Jkem_avx2.Syscall)` -- verify that they
#     resolve to the intended modules.
#   - `__polyvec_reduce`: runs `__poly_reduce` on the three 256-element
#     slices of r starting at offsets 0, 256 and 2 * 256, writing each
#     result back into the same slice.
#   `ref` is then changed to call this local `__polyvec_reduce` instead of
#   `Jkem.M(Jkem_avx2.Syscall).__polyvec_reduce`.
#
# NOTE(review): the patch text was recovered from a whitespace-collapsed
#   copy; line breaks are consistent with the hunk line counts, but the
#   indentation below is reconstructed and may not match the target file
#   byte-for-byte (apply with whitespace tolerance or re-export the patch).
diff --git a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
index e71271e9..3e365655 100644
--- a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
+++ b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
@@ -1269,6 +1269,8 @@ qed.
 
 require import WArray1088 WArray1536 Array4.
 
+print Jkem.M.
+
 module AuxPolyVecCompress10 = {
   proc avx2_orig(ctp : W64.t, bp : W16.t Array768.t) : WArray1088.t = {
     bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp);
@@ -1405,9 +1407,37 @@ proc __polyvec_compress_ref(a : W16.t Array768.t) : WArray1088.t = {
     return rr;
   }
 
+proc __poly_reduce(rp : W16.t Array256.t) : W16.t Array256.t = {
+    var j : int;
+    var t : W16.t;
+
+    j <- 0;
+    while (j < 256){
+      t <- rp.[j];
+      t <@ M(Syscall).__barrett_reduce(t);
+      rp.[j] <- t;
+      j <- j + 1;
+    }
+
+    return rp;
+  }
+
+  proc __polyvec_reduce(r : W16.t Array768.t) : W16.t Array768.t = {
+    var aux : W16.t Array256.t;
+
+    aux <@ __poly_reduce((init (fun (i : int) => r.[0 + i]))%Array256);
+    r <- (init (fun (i : int) => if 0 <= i && i < 0 + 256 then aux.[i - 0] else r.[i]))%Array768;
+    aux <@ __poly_reduce((init (fun (i : int) => r.[256 + i]))%Array256);
+    r <- (init (fun (i : int) => if 256 <= i && i < 256 + 256 then aux.[i - 256] else r.[i]))%Array768;
+    aux <@ __poly_reduce((init (fun (i : int) => r.[2 * 256 + i]))%Array256);
+    r <- (init (fun (i : int) => if 2 * 256 <= i && i < 2 * 256 + 256 then aux.[i - 2 * 256] else r.[i]))%Array768;
+
+    return r;
+  }
+
   proc ref(bp : W16.t Array768.t) : WArray1088.t = {
     var rr : WArray1088.t;
-    bp <@ Jkem.M(Jkem_avx2.Syscall).__polyvec_reduce(bp);
+    bp <@ __polyvec_reduce(bp);
     rr <@ __polyvec_compress_ref(bp);
     return rr;
   }