Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User/dev/kupadhyayula/gh ntt shuffling #22

Merged
merged 26 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
270176e
Shuffling wip
Sep 17, 2024
dd428ed
NTT shuffling updates
Sep 19, 2024
a7d5e54
WIP INTT shuffling
Sep 23, 2024
c29c827
Merge branch 'main' into user/dev/kupadhyayula/ntt_shuffling_updated
Oct 2, 2024
2d10620
Add shuffle_en wip
Oct 3, 2024
ca825ba
shuffling updates for pwo, intt
Oct 3, 2024
e6423dc
Add shuffle_en (wip)
Oct 4, 2024
55e83ac
Merge branch 'user/dev/kupadhyayula/ntt_shuffling_updated' into user/…
Oct 7, 2024
8f11f47
Merge branch 'main' into user/dev/kupadhyayula/gh_ntt_shuffling
upadhyayulakiran Oct 25, 2024
cac0b8b
Add shuffle_en input to ntt
upadhyayulakiran Oct 28, 2024
6b2cf17
Remove random input and use lfsr bits
upadhyayulakiran Oct 28, 2024
8d56be4
Clean up
upadhyayulakiran Oct 29, 2024
9c3d3e8
Lint cleanup
upadhyayulakiran Oct 29, 2024
1be0663
Fix lint
upadhyayulakiran Oct 29, 2024
89881b4
MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/gh_ntt_shu…
upadhyayulakiran Oct 29, 2024
4f6ce4f
Disable shuffling for now
upadhyayulakiran Oct 29, 2024
4ee1e73
Merge branch 'user/dev/kupadhyayula/gh_ntt_shuffling' of ssh://github…
upadhyayulakiran Oct 29, 2024
f6347ec
MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/gh_ntt_shu…
upadhyayulakiran Oct 29, 2024
1c9e1b0
Remove old TODO
upadhyayulakiran Oct 30, 2024
9d9fd88
Merge branch 'user/dev/kupadhyayula/gh_ntt_shuffling' of ssh://github…
upadhyayulakiran Oct 30, 2024
27e8d87
Merge branch 'main' into user/dev/kupadhyayula/gh_ntt_shuffling
upadhyayulakiran Oct 31, 2024
1a4c362
Merge branch 'main' into user/dev/kupadhyayula/gh_ntt_shuffling
upadhyayulakiran Oct 31, 2024
4008aff
MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/gh_ntt_shu…
upadhyayulakiran Oct 31, 2024
badf093
Merge branch 'main' into user/dev/kupadhyayula/gh_ntt_shuffling
upadhyayulakiran Oct 31, 2024
9e0c797
Merge branch 'user/dev/kupadhyayula/gh_ntt_shuffling' of ssh://github…
upadhyayulakiran Oct 31, 2024
b2ca140
MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/gh_ntt_shu…
upadhyayulakiran Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflow_metadata/pr_hash
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ccfbc1a0345264c8c044b6e441c9757f47467e8478b89f57bf502ad6f3b1b29f3be1047b7ee043d5790c8a24977ceed9
6125c7fa394b48b54b85ad7453c0c374b7853fa525492ba552f8bef665069a6d4878c5906115ca22510e83bc2c9be4dd
2 changes: 1 addition & 1 deletion .github/workflow_metadata/pr_timestamp
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1730388545
1730398117
28 changes: 25 additions & 3 deletions src/mldsa_top/rtl/mldsa_top.sv
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ module mldsa_top
logic [MLDSA_MEM_DATA_WIDTH-1:0] sampler_mem_data;
logic [MLDSA_MEM_ADDR_WIDTH-1:0] sampler_mem_addr;

logic [1:0] sampler_ntt_dv;
logic [1:0] sampler_ntt_dv, sampler_ntt_dv_f;
logic [1:0] sampler_ntt_mode;
logic [1:0] sampler_valid;
logic [COEFF_PER_CLK-1:0][MLDSA_Q_WIDTH-1:0] sampler_ntt_data;
Expand All @@ -117,6 +117,7 @@ module mldsa_top
logic [1:0][MLDSA_MEM_DATA_WIDTH-1:0] pwm_b_rd_data;
logic [1:0] ntt_done;
logic [1:0] ntt_busy;
logic [1:0] shuffle_en;

mem_if_t w1_mem_wr_req;
logic [3:0] w1_mem_wr_data;
Expand Down Expand Up @@ -421,6 +422,18 @@ mldsa_sampler_top sampler_top_inst
.sampler_state_data_o(sampler_state_data)
);

always_ff @(posedge clk or negedge rst_b) begin
if (!rst_b) begin
sampler_ntt_dv_f <= 0;
end
else if (zeroize_reg) begin
sampler_ntt_dv_f <= 0;
end
else begin
sampler_ntt_dv_f <= sampler_ntt_dv;
end
end

assign sampler_ntt_dv[1] = 0; //no sampler interface to secondary ntt

generate
Expand All @@ -432,15 +445,18 @@ generate
accumulate[g_inst] = '0;
sampler_valid[g_inst] = 0;
sampler_ntt_mode[g_inst] = 0;
shuffle_en[g_inst] = 0; //TODO: temp change for testing, remove and add to opcodes

unique case (ntt_mode[g_inst]) inside
MLDSA_NTT_NONE: begin
end
MLDSA_NTT: begin
mode[g_inst] = ct;
// shuffle_en[g_inst] = 1;
end
MLDSA_INTT: begin
mode[g_inst] = gs;
// shuffle_en[g_inst] = 1;
end
MLDSA_PWM_SMPL: begin
mode[g_inst] = pwm;
Expand All @@ -456,19 +472,23 @@ generate
MLDSA_PWM: begin
mode[g_inst] = pwm;
sampler_valid[g_inst] = 1;
// shuffle_en[g_inst] = 1;
end
MLDSA_PWM_ACCUM: begin
mode[g_inst] = pwm;
accumulate[g_inst] = 1;
sampler_valid[g_inst] = 1;
// shuffle_en[g_inst] = 1;
end
MLDSA_PWA: begin
mode[g_inst] = pwa;
sampler_valid[g_inst] = 1;
// shuffle_en[g_inst] = 1;
end
MLDSA_PWS: begin
mode[g_inst] = pws;
sampler_valid[g_inst] = 1;
// shuffle_en[g_inst] = 1;
end
default: begin
end
Expand All @@ -492,6 +512,8 @@ generate
.pwo_mem_base_addr(pwo_mem_base_addr[g_inst]),
.accumulate(accumulate[g_inst]),
.sampler_valid(sampler_valid[g_inst]),
.shuffle_en(shuffle_en[g_inst]),
.random(rand_bits[g_inst*6+5:g_inst*6]),
//NTT mem IF
.mem_wr_req(ntt_mem_wr_req[g_inst]),
.mem_rd_req(ntt_mem_rd_req[g_inst]),
Expand Down Expand Up @@ -930,7 +952,7 @@ always_comb begin
for (int bank = 0; bank < 2; bank++) begin
ntt_mem_re0_bank[0][bank] = (ntt_mem_rd_req[0].rd_wr_en == RW_READ) & (ntt_mem_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (ntt_mem_rd_req[0].addr[0] == bank);
pwo_a_mem_re0_bank[0][bank] = (pwm_a_rd_req[0].rd_wr_en == RW_READ) & (pwm_a_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_a_rd_req[0].addr[0] == bank);
pwo_b_mem_re0_bank[0][bank] = ~sampler_ntt_dv & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_b_rd_req[0].addr[0] == bank);
pwo_b_mem_re0_bank[0][bank] = (shuffle_en[0] ? ~sampler_ntt_dv_f : ~sampler_ntt_dv) & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_b_rd_req[0].addr[0] == bank);

ntt_mem_re0_bank[1][bank] = (ntt_mem_rd_req[1].rd_wr_en == RW_READ) & (ntt_mem_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (ntt_mem_rd_req[1].addr[0] == bank);
pwo_a_mem_re0_bank[1][bank] = (pwm_a_rd_req[1].rd_wr_en == RW_READ) & (pwm_a_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_a_rd_req[1].addr[0] == bank);
Expand Down Expand Up @@ -964,7 +986,7 @@ always_comb begin
end else begin
ntt_mem_re[0][i] = (ntt_mem_rd_req[0].rd_wr_en == RW_READ) & (ntt_mem_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]);
pwo_a_mem_re[0][i] = (pwm_a_rd_req[0].rd_wr_en == RW_READ) & (pwm_a_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]);
pwo_b_mem_re[0][i] = ~sampler_ntt_dv & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]);
pwo_b_mem_re[0][i] = (shuffle_en[0] ? ~sampler_ntt_dv_f : ~sampler_ntt_dv) & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]);

ntt_mem_re[1][i] = (ntt_mem_rd_req[1].rd_wr_en == RW_READ) & (ntt_mem_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]);
pwo_a_mem_re[1][i] = (pwm_a_rd_req[1].rd_wr_en == RW_READ) & (pwm_a_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]);
Expand Down
25 changes: 23 additions & 2 deletions src/ntt_top/Model/maksed_gadgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,35 @@ def one_share_mult(a0, a1, b):
randomness = CustomUnsignedInteger(0, 0, MultMod-1)
#get a random number ranging [0, MultMod-1]
randomness.generate_random()
r1 = int(randomness.value)
r1 = int(randomness.value) #optional
#refresh the shares
a00 = int(a0+r1) % MultMod
a10 = int(a1-r1) % MultMod
c0 = int(a00*b) % MultMod
c1 = int(a10*b) % MultMod
return c0, c1

def two_share_mult(a0, a1, b0, b1):
# Construct randomness class
randomness = CustomUnsignedInteger(0, 0, MultMod-1)
#get a random number ranging [0, MultMod-1]
randomness.generate_random()
r1 = int(randomness.value)
# #refresh the shares
# a00 = int(a0+r1) % MultMod
# a10 = int(a1-r1) % MultMod
# c0 = int(a00*b) % MultMod
# c1 = int(a10*b) % MultMod
c0 = int(a0*b0) % MultMod
d0 = int(a1*b0) % MultMod
c1 = int(a0*b1) % MultMod
d1 = int(a1*b1) % MultMod
e0 = int(c1+r1) % MultMod
e1 = int(d0-r1) % MultMod
final_res0 = int(c0+e0) % MultMod
final_res1 = int(d1+e1) % MultMod
return final_res0, final_res1

def maskedAdder(a0, a1, b0, b1):
# Construct randomness class
randomness = CustomUnsignedInteger(0, 0, MultMod-1)
Expand Down Expand Up @@ -205,7 +226,7 @@ def unMaskedModAddition(x0, x1, y0, y1):
y = y0 ^ y1
z = (x+y) % DILITHIUM_Q
randomness = CustomUnsignedInteger(0, 0, (2**23)-1)
# Refresh the shares
# Refresh the shares -- not optional
randomness.generate_random()
r1 = int(randomness.value)
z0 = z ^ r1
Expand Down
31 changes: 31 additions & 0 deletions src/ntt_top/Model/testForMasking.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,40 @@ def test_maskedBFU_CT(numTest = 10):
if vNew != exp_v:
print(f"CT Lower branch gives an Error; gotten = {vNew}, while exp = {exp_v}")

def test_twoshare_mult(numTest = 10):
randomness = CustomUnsignedInteger(0, 0, MultMod-1)
operands = CustomUnsignedInteger(0, 0, DILITHIUM_Q-1)
for i in range(0, numTest):
#get a random number ranging [0, DILITHIUM_Q-1]
#generate inputs
operands.generate_random()
u = int(operands.value)
operands.generate_random()
v = int(operands.value)
#calculate expected result
exp_uv = u*v
#Split inputs to shares
randomness.generate_random()
r0 = int(randomness.value)
u0 = int(u-r0) % MultMod
u1 = r0
randomness.generate_random()
r1 = int(randomness.value)
v0 = int(v-r1) % MultMod
v1 = r1
#Test two share mult
uv0, uv1 = two_share_mult(u0, u1, v0, v1)
uv = int(uv0 + uv1) % MultMod
#Check result
if uv != exp_uv:
print(f"Incorrect mult op. Operands: {u, v}, Shares: {u0, u1, v0, v1} Exp = {exp_uv}, actual = {uv}")




test_maskedBFU_CT(numTest = 100000)
test_maskedBFU_GS(numTest = 100000)
test_twoshare_mult(numTest = 100000)

def test_masked_inv_NTT2x2_div2(numTest = 10):
for test_i in range(numTest):
Expand Down
1 change: 1 addition & 0 deletions src/ntt_top/config/compile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ targets:
- $COMPILE_ROOT/rtl/ntt_special_adder.sv
- $COMPILE_ROOT/rtl/ntt_div2.sv
- $COMPILE_ROOT/rtl/ntt_buffer.sv
- $COMPILE_ROOT/rtl/ntt_shuffle_buffer.sv
- $COMPILE_ROOT/rtl/ntt_twiddle_lookup.sv
- $COMPILE_ROOT/rtl/ntt_ctrl.sv
- $COMPILE_ROOT/rtl/ntt_top.sv
Expand Down
Loading