Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmucb committed Apr 6, 2024
1 parent 057f674 commit 1821d07
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 25 deletions.
47 changes: 29 additions & 18 deletions v2/dilithium/dilithium.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,11 @@ bool module_vector_subtract(module_vector& in1, module_vector& in2, module_vecto
module_vector neg_in2(in2.q_, in2.n_, in2.dim_);
for (int i = 0; i < in2.dim_; i++) {
for (int j = 0; j < in2.n_; j++) {
neg_in2.c_[i]->c_[j] = (in2.q_ - in2.c_[i]->c_[j]) % in2.q_;
int t = in2.c_[i]->c_[j];
if (t < 0)
neg_in2.c_[i]->c_[j] = (-t) % in2.q_;
else
neg_in2.c_[i]->c_[j] = in2.q_ - (t % in2.q_);
}
}
return module_vector_add(in1, neg_in2, out);
Expand Down Expand Up @@ -336,18 +340,16 @@ int module_inf_norm(module_vector& mv) {
}

int high_bits(int x, int a) {
// x = x_high*2*a + x_low
// x = x_high*a + x_low
x = abs(x); // check
return x / (2 * a);
return x / a;
}

int low_bits(int x, int a) {
// x = x_high*2*a + x_low
x = abs(x); // check
int k = x / (2 * a);
int y = x - (k * 2 * a);
if (y >= 2*a)
printf("Huh?\n");
int k = x / a;
int y = x - (k * a);
return y;
}

Expand Down Expand Up @@ -546,7 +548,7 @@ bool dilithium_keygen(dilithium_parameters& params, module_array* A,
// }
bool dilithium_sign(dilithium_parameters& params, module_array& A, module_vector& t,
module_vector& s1, module_vector& s2, int m_len, byte* M,
module_vector* z, int len_c, byte* c) {
module_vector* z, int len_c, byte* c, int len_cc, int* cc) {

// y: dim l_
// tv1 = Ay, dim k
Expand All @@ -561,6 +563,10 @@ bool dilithium_sign(dilithium_parameters& params, module_array& A, module_vect
t.dim_, s1.dim_, s2.dim_, z->dim_);
return false;
}
if (len_cc != 256) {
printf("sign: cc wrong size %d\n", len_cc);
return false;
}

bool done = false;
memset(c, 0, len_c);
Expand Down Expand Up @@ -631,7 +637,6 @@ bool dilithium_sign(dilithium_parameters& params, module_array& A, module_vect
return false;
}

int cc[256];
memset((byte*)cc, 0, 256 * sizeof(int));
if (!c_from_h(32, c, cc)) {
printf("sign: c_from_h\n");
Expand Down Expand Up @@ -668,7 +673,7 @@ bool dilithium_sign(dilithium_parameters& params, module_array& A, module_vect
print_module_vector(y);
printf("\ntu1:\n");
print_module_vector(tu1);
printf("z:\n");
printf("\nz:\n");
print_module_vector(*z);
#endif
int inf = module_inf_norm(*z);
Expand All @@ -680,6 +685,7 @@ bool dilithium_sign(dilithium_parameters& params, module_array& A, module_vect
#if 1
printf("sign: compare 1 failed\n");
#else
printf("sign: compare 1 failed\n");
continue;
#endif
}
Expand All @@ -701,10 +707,11 @@ bool dilithium_sign(dilithium_parameters& params, module_array& A, module_vect
return false;
}
#if 1
printf("tv2:\n");
printf("\ntu2:\n");
print_module_vector(tu2);
printf("\ntv2:\n");
print_module_vector(tv2);
printf("\n");
printf("w3, 2 * params.gamma_2_: %d\n", 2 * params.gamma_2_);
printf("\nw3, 2 * params.gamma_2_: %d\n", 2 * params.gamma_2_);
print_module_vector(w3);
printf("\n");
#endif
Expand All @@ -714,7 +721,10 @@ bool dilithium_sign(dilithium_parameters& params, module_array& A, module_vect
low, params.gamma_2_ - params.beta_);
#endif
if (low >= (params.gamma_2_ - params.beta_)) {
#if 0
#if 1
printf("compare 2 fail\n");
#else
printf("compare 2 fail\n");
continue;
#endif
}
Expand All @@ -729,8 +739,12 @@ bool dilithium_sign(dilithium_parameters& params, module_array& A, module_vect
// return ||z||_inf < g1-beta and c == H(M||w1)
bool dilithium_verify(dilithium_parameters& params, module_array& A,
module_vector& t, int m_len, byte* M,
module_vector& z, int len_c, byte* c) {
module_vector& z, int len_c, byte* c, int len_cc, int* cc) {

if (len_cc != 256) {
printf("verify: cc len wrong %d\n", len_cc);
return false;
}

// tv1 = Az, dim k
// tu = ct, dim k
Expand All @@ -757,7 +771,6 @@ bool dilithium_verify(dilithium_parameters& params, module_array& A,
module_vector tu(params.q_, params.n_, params.k_);
module_vector w1(params.q_, params.n_, params.k_);
coefficient_vector c_poly(params.q_, params.n_);
int cc[256];

H.add_to_hash(m_len, M);
// this is not quite right
Expand All @@ -775,8 +788,6 @@ bool dilithium_verify(dilithium_parameters& params, module_array& A,
return false;
}

if (!c_from_h(32, c, cc))
return false;
for (int i = 0; i < c_poly.len_; i++) {
c_poly.c_[i] = cc[i];
}
Expand Down
14 changes: 9 additions & 5 deletions v2/dilithium/test_dilithium.cc
Original file line number Diff line number Diff line change
Expand Up @@ -418,16 +418,20 @@ bool test_dilithium1() {
printf("\n");
}

module_vector z(params.q_, params.n_, params.l_);
int len_c = 32;
byte c[len_c];
if (!dilithium_sign(params, A, t, s1, s2, m_len, M, &z, len_c, c)) {
int len_cc = 256;
int cc[256];
memset(c, 0, len_c);
memset(cc, 0, len_cc * sizeof(int));

module_vector z(params.q_, params.n_, params.l_);
if (!dilithium_sign(params, A, t, s1, s2, m_len, M, &z, len_c, c, len_cc, cc)) {
printf("dilithium_sign failed\n");
return false;
}

//if (FLAGS_print_all) {
if (0) {
if (FLAGS_print_all) {
printf("\nz:\n");
print_module_vector(z);
printf("\n");
Expand All @@ -437,7 +441,7 @@ bool test_dilithium1() {
}
return true;

if (dilithium_verify(params, A, t, m_len, M, z, len_c, c)) {
if (dilithium_verify(params, A, t, m_len, M, z, len_c, c, len_cc, cc)) {
printf("dilithium_verify succeeded\n");
} else {
printf("dilithium_verify failed\n");
Expand Down
5 changes: 3 additions & 2 deletions v2/include/dilithium.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ bool dilithium_keygen(dilithium_parameters& params, module_array* A,
module_vector* t, module_vector* s1, module_vector* s2);
bool dilithium_sign(dilithium_parameters& params, module_array& A,
module_vector& t, module_vector& s1, module_vector& s2,
int m_len, byte* M, module_vector* z, int len_c, byte* c);
int m_len, byte* M, module_vector* z, int len_c, byte* c,
int len_cc, int* cc);
bool dilithium_verify(dilithium_parameters& params, module_array& A,
module_vector& t, int m_len, byte* M,
module_vector& z, int len_c, byte* c);
module_vector& z, int len_c, byte* c, int len_cc, int* cc);
#endif

0 comments on commit 1821d07

Please sign in to comment.