diff --git a/v2/include/kyber.h b/v2/include/kyber.h index 73c6bee..8e43e86 100644 --- a/v2/include/kyber.h +++ b/v2/include/kyber.h @@ -165,7 +165,7 @@ bool ntt_module_vector_dot_product_first_transposed(module_vector& in1, void print_module_vector(module_vector& mv); bool ntt_module_apply_array(int g, module_array& A, module_vector& v, module_vector* out); -bool ntt_module_apply_transpose_array(int g, module_array& A, module_vector& v, module_vector* out); +bool ntt_module_apply_transposed_array(int g, module_array& A, module_vector& v, module_vector* out); void print_kyber_parameters(kyber_parameters& p); diff --git a/v2/kyber/kyber.cc b/v2/kyber/kyber.cc index 776a2a8..6e16853 100644 --- a/v2/kyber/kyber.cc +++ b/v2/kyber/kyber.cc @@ -360,7 +360,7 @@ bool ntt_module_apply_array(int g, module_array& A, module_vector& v, module_vec return true; } -bool ntt_module_apply_array_transpose(int g, module_array& A, module_vector& v, module_vector* out) { +bool ntt_module_apply_transposed_array(int g, module_array& A, module_vector& v, module_vector* out) { if ((A.nc_ != v.dim_) || A.nr_ != out->dim_) { printf("mismatch, nc: %d, v: %d, nr: %d, out: %d\n", A.nc_, v.dim_, A.nr_, out->dim_); return false; @@ -1325,10 +1325,9 @@ bool kyber_encrypt(int g, kyber_parameters& p, int ek_len, byte* ek, printf("\ne2:\n"); print_coefficient_vector(e2); printf("\n"); - return true; #endif - if (!ntt_module_apply_transpose_array(g, A_ntt, r_ntt, &s)) { + if (!ntt_module_apply_transposed_array(g, A_ntt, r_ntt, &s)) { printf("kyber_encrypt: ntt_module_apply_transpose_array) failed\n"); return false; } @@ -1352,13 +1351,29 @@ bool kyber_encrypt(int g, kyber_parameters& p, int ek_len, byte* ek, coefficient_vector mu(p.q_, p.n_); coefficient_vector nu(p.q_, p.n_); coefficient_vector nu_ntt(p.q_, p.n_); - if (!coefficient_vector_zero(&nu)) { + if (!coefficient_vector_zero(&mu)) { return false; } - if (!coefficient_vector_zero(&nu_ntt)) { + // mu = decompress(1, byte_decode(m)) + if (!byte_decode_to_vector(1, p.n_, m_len, m, mu.c_)) { + return false; + } + for (int i = 0; i < p.n_; i++) { + mu.c_[i] = decompress(p.q_, mu.c_[i], 1); + } + +#if 1 + printf("m: "); + print_bytes(m_len, m); + printf("mu:\n"); + print_coefficient_vector(mu); + return true; +#endif + + if (!coefficient_vector_zero(&nu)) { return false; } - if (!coefficient_vector_zero(&mu)) { + if (!coefficient_vector_zero(&nu_ntt)) { return false; } if (!ntt_module_vector_dot_product(t_ntt, r_ntt, &nu_ntt)) { @@ -1374,9 +1389,6 @@ bool kyber_encrypt(int g, kyber_parameters& p, int ek_len, byte* ek, return false; } - if (!byte_decode_to_vector(1, p.n_, m_len, m, mu.c_)) { - return false; - } module_vector compressed_u(p.q_, p.n_, p.k_); for (int i = 0; i < p.k_; i++) { diff --git a/v2/kyber/test_kyber.cc b/v2/kyber/test_kyber.cc index 102ce98..e0be130 100644 --- a/v2/kyber/test_kyber.cc +++ b/v2/kyber/test_kyber.cc @@ -73,6 +73,10 @@ bool test_kyber1() { printf("wrong return from crypto_get_random_bytes\n"); return false; } + m[1] = 0xff; + m[3] = 0x0f; + m[5] = 0x10; + m[7] = 0x11; if (!kyber_encrypt(g, p, ek_len, ek, m_len, m, b_r_len, b_r, &c_len, c)) {