Skip to content

Commit

Permalink
ntt transform
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmucb committed Apr 12, 2024
1 parent a554e9f commit f7c2a73
Showing 1 changed file with 74 additions and 4 deletions.
78 changes: 74 additions & 4 deletions v2/kyber/kyber.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,24 @@ byte bit_reverse(byte b) {
}

bool ntt_base_mult(short int q, short int g, int& in1, int& in2, int* out) {
// out[1] = (in1[0] * in2[1] + in1[1] * in2[0]) % q;
// out[0] = (((in1[0] * in2[0]) % q) + ((g * in1[1]) % q) * in2[1]) % q;
short int s1 = (short int) (in1 & 0xffff);
short int s2 = (short int) ((in1>>16) & 0xffff);
short int t1 = (short int) (in2 & 0xffff);
short int t2 = (short int) ((in2>>16) & 0xffff);
short int u1 = ((t1 * s1) % q + (g * t2 * s2) % q) % q;
short int u2 = (s1 * t2 + t1 * s2) % q;
*out = ((int) u2) << 16 | ((int) u1);
return true;
}

bool ntt_base_add(short int q, short int g, int& in1, int& in2, int* out) {
short int s1 = (short int) (in1 & 0xffff);
short int s2 = (short int) ((in1>>16) & 0xffff);
short int t1 = (short int) (in2 & 0xffff);
short int t2 = (short int) ((in2>>16) & 0xffff);
short int u1 = (t1 + s1) % q;
short int u2 = (t2 + s2) % q;
*out = ((int) u2) << 16 | ((int) u1);
return true;
}

Expand All @@ -495,14 +511,68 @@ bool sample_poly_cbd(int q, int eta, int l, short int* out) {
return false;
}

short int read_ntt(vector<int> x, int m) {
int t = m / 2;
if ((m&1)==0)
return (short int)x[t];
else
return (short int)(x[t]>>16);
}

void write_ntt(int m, short int y, vector<int>& x) {
int t = m / 2;
short int t1 = (short int) (x[m] & 0xffff);
short int t2 = (short int) ((x[m]>>16) & 0xffff);
if ((m&1)==0)
t1 = y;
else
t2 = y;
x[m] = ((int) t2) << 16 | (int) t1;
}

// ntt representation of f= f0 + f_1x + ... is
// [ f mod (x^2-g^2Rev(0)+1, f mod (x^2-g^2Rev(1)+1,..., f mod (x^2-g^2Rev(127)+1) ]
bool ntt(short int g, coefficient_vector& in, coefficient_vector* out) {
return false;
int k = 1;
coefficient_set_vector(in, out);

for (int l = 128; l >= 2; l /=2) {
for (int s = 0; s < 256; s+= 2*l) {
byte bb = bit_reverse((byte)k);
bb >>= 1;
short int z = exp_in_ntt((short int) in.q_, (short int) bb, g);
k++;
for (int j = 0; j < s + l; j++) {
short int t = (z * read_ntt(out->c_, j+l)) % in.q_;
write_ntt(j + l, (read_ntt(out->c_, j) - t) % in.q_, out->c_);
write_ntt(j, (read_ntt(out->c_, j) + t) % in.q_, out->c_);
}
}
}
return true;
}

bool ntt_inv(short int g, coefficient_vector& in, coefficient_vector* out) {
return false;
int k = 127;
coefficient_set_vector(in, out);

for (int l = 2; l <= 128; l *= 2) {
for (int s = 0; s < 256; s += 2 * l) {
byte bb = bit_reverse((byte)k);
bb >>= 1;
short int z = exp_in_ntt((short int) in.q_, (short int) bb, g);
k--;
for (int j = s; j < s + l; s += 2 * l) {
short int t = read_ntt(out->c_, j);
write_ntt(j, (t + read_ntt(out->c_, j + l)) % in.q_, out->c_);
write_ntt(j + l, (in.q_ + (z * read_ntt(out->c_, j + l)) - t )% in.q_, out->c_);
}
}
}
for (int i = 0; i < 256; i++) {
out->c_[i] = (out->c_[i] * 3303) % in.q_;
}
return true;
}

bool ntt_add(coefficient_vector& in1, coefficient_vector& in2, coefficient_vector* out) {
Expand Down

0 comments on commit f7c2a73

Please sign in to comment.