Skip to content

Commit

Permalink
ntt works
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmucb committed Apr 16, 2024
1 parent aad2c9a commit 5b01bce
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 35 deletions.
48 changes: 23 additions & 25 deletions v2/kyber/kyber.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,10 @@ byte bit_reverse(byte b) {

bool ntt_base_mult(int q, int g, int& in1a, int& in1b,
int& in2a, int& in2b, int* outa, int* outb) {
int u0 = ((in1a * in2a) % q + (((in1b * in2b) %q) * g) % q) % q;
int u1 = (((in1a * in2b) % q) + ((in2a * in1a) %q)) % q;
*outa = u0;
*outb = u1;
int u1 = ((in1a * in2a) % q + (((in1b * in2b) %q) * g) % q) % q;
int u2 = (((in1a * in2b) % q) + ((in2a * in1a) %q)) % q;
*outa = u1;
*outb = u2;
return true;
}

Expand Down Expand Up @@ -407,7 +407,7 @@ bool sample_ntt(int q, int l, int b_len, byte* b, int* out_len, short int* out)
}

bool sample_poly_cbd(int q, int eta, int l, int b_len, byte* b,
int* out_len, short int* out) {
int* out_len, short int* out) {
if (b_len * NBITSINBYTE < l)
return false;

Expand Down Expand Up @@ -435,19 +435,18 @@ bool ntt(int g, coefficient_vector& in, coefficient_vector* out) {
int k = 1;
coefficient_set_vector(in, out);

for (int l = 128; l >= 2; l /=2) {
for (int l = 128; l >= 2; l /= 2) {
for (int s = 0; s < in.len_; s+= 2 * l) {
byte bb = bit_reverse((byte)k);
bb >>= 1;
int z = exp_in_ntt(in.q_, (int) bb, g);
k++;
for (int j = s; j < (s + l); j++) {
// FIX
//int t = (z * read_ntt(out->c_, j + l)) % in.q_;
//int s1 = (read_ntt(out->c_, j) + (in.q_ - t)) % in.q_;
//write_ntt(j + l, s1, &out->c_);
//short int s2 = (read_ntt(out->c_, j) + t) % in.q_;
//write_ntt(j, s2, &out->c_);
int t = (z * out->c_[j + l]) % in.q_;
int s1 = (out->c_[j] + (in.q_ - t)) % in.q_;
out->c_[j + l]= s1;
int s2 = (out->c_[j] + t) % in.q_;
out->c_[j] = s2;
}
}
}
Expand All @@ -465,15 +464,14 @@ bool ntt_inv(int g, coefficient_vector& in, coefficient_vector* out) {
for (int s = 0; s < 256; s += 2 * l) {
byte bb = bit_reverse((byte)k);
bb >>= 1;
int z = exp_in_ntt(in.q_, bb, g);
int z = exp_in_ntt(in.q_, (int)bb, g);
k--;
for (int j = s; j < s + l; j += 2 * l) {
// FIX
// short int t = read_ntt(out->c_, j);
// short int s1 = (in.q_ + (z * read_ntt(out->c_, j + l)) - t ) % in.q_;
// write_ntt(j, s1, &out->c_);
// short int s2 = (in.q_ + ((z * read_ntt(out->c_, j + l) % in.q_)) - t ) % in.q_;
// write_ntt(j + l, s2, &out->c_);
for (int j = s; j < s + l; j++) {
int t = out->c_[j];
int s1 = (t + out->c_[j+l]) % in.q_;
out->c_[j] = s1;
int s2 = (z * (out->c_[j + l] + in.q_ - t)) % in.q_;
out->c_[j + l] = s2;
}
}
}
Expand All @@ -483,21 +481,21 @@ bool ntt_inv(int g, coefficient_vector& in, coefficient_vector* out) {
return true;
}

bool ntt_mult(short int g, coefficient_vector& in1, coefficient_vector& in2, coefficient_vector* out) {
bool ntt_mult(int g, coefficient_vector& in1, coefficient_vector& in2, coefficient_vector* out) {
if (in1.len_ != in2.len_ || in1.len_ != out->len_)
return false;
if (in1.q_ != in2.q_ || in1.q_ != out->q_)
return false;
int t = 0;
for (int i = 0; i < in1.len_; i++) {
for (int i = 0; i < in1.len_; i += 2) {
int e = (short int)(bit_reverse(i/2)>>1);
e *= 2;
e += 1;
t = exp_in_ntt(in1.q_, e, g);
#if 0
if (!ntt_base_mult(in1.q_, g, in1.c_[i], in2.c_[i], &(out->c_[i])))
if (!ntt_base_mult(in1.q_, g, in1.c_[i], in1.c_[i + 1],
in2.c_[i], in2.c_[i + 1],
&(out->c_[i]), &(out->c_[i + 1])))
return false;
#endif
}
return true;
}
Expand Down
14 changes: 4 additions & 10 deletions v2/kyber/test_kyber.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,10 @@ bool test_kyber_support() {
print_kyber_parameters(p);

int g = 17;
int i1a = 0;
int i2a = 0;
int i1b = 0;
int i2b = 0;
int i1a = 2;
int i2a = 5;
int i1b = 1;
int i2b = 3;
int oa = 0;
int ob = 0;
if (!ntt_base_mult(p.q_, g, i1a, i1b, i2a, i2b, &oa, &ob)) {
Expand Down Expand Up @@ -386,20 +386,14 @@ bool test_kyber_support() {
printf("\n");
}

#if 0
for (int i = 0; i < 256; i++) {
if (ntt_in.c_[i] != ntt_inv_out.c_[i]) {
printf("input and ntt_inv(ntt(input)) do not match at %d\n", i);
return false;
}
}
#endif

/*
if (!ntt_add(coefficient_vector& in1, coefficient_vector& in2, coefficient_vector* out)) {
printf("Could not inverse ntt_add\n");
return false;
}
if (!ntt_mult(short int g, coefficient_vector& in1, coefficient_vector& in2, coefficient_vector* out)) {
printf("Could not inverse ntt_mult\n");
return false;
Expand Down

0 comments on commit 5b01bce

Please sign in to comment.