From 270176eac5ef480fb938833ff8563307f7b0e0cc Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Tue, 17 Sep 2024 12:17:23 -0700 Subject: [PATCH 01/17] Shuffling wip --- src/ntt_top/config/compile.yml | 1 + src/ntt_top/rtl/ntt_ctrl.sv | 142 +++++++++++++++++++++++--- src/ntt_top/rtl/ntt_shuffle_buffer.sv | 111 ++++++++++++++++++++ src/ntt_top/rtl/ntt_top.sv | 41 +++++++- src/ntt_top/tb/ntt_top_tb.sv | 47 ++++++++- src/ntt_top/tb/ntt_wrapper.sv | 4 +- 6 files changed, 323 insertions(+), 23 deletions(-) create mode 100644 src/ntt_top/rtl/ntt_shuffle_buffer.sv diff --git a/src/ntt_top/config/compile.yml b/src/ntt_top/config/compile.yml index da09037..3e22fb1 100755 --- a/src/ntt_top/config/compile.yml +++ b/src/ntt_top/config/compile.yml @@ -51,6 +51,7 @@ targets: - $COMPILE_ROOT/rtl/ntt_mult_reduction.sv - $COMPILE_ROOT/rtl/ntt_div2.sv - $COMPILE_ROOT/rtl/ntt_buffer.sv + - $COMPILE_ROOT/rtl/ntt_shuffle_buffer.sv - $COMPILE_ROOT/rtl/ntt_twiddle_lookup.sv - $COMPILE_ROOT/rtl/ntt_ctrl.sv - $COMPILE_ROOT/rtl/ntt_top.sv diff --git a/src/ntt_top/rtl/ntt_ctrl.sv b/src/ntt_top/rtl/ntt_ctrl.sv index 7dc475a..14a5be5 100644 --- a/src/ntt_top/rtl/ntt_ctrl.sv +++ b/src/ntt_top/rtl/ntt_ctrl.sv @@ -31,7 +31,9 @@ module ntt_ctrl parameter MLDSA_Q_DIV2_ODD = (MLDSA_Q+1)/2, parameter MLDSA_N = 256, parameter MLDSA_LOGN = 8, - parameter MEM_ADDR_WIDTH = 15 + parameter MEM_ADDR_WIDTH = 15, + parameter BF_LATENCY = 10, //5 cycles per butterfly * 2 instances in serial = 10 clks + parameter NTT_BUF_LATENCY = 4 ) ( input wire clk, @@ -53,10 +55,13 @@ module ntt_ctrl // input wire [MEM_ADDR_WIDTH-1:0] pw_base_addr_b, // input wire [MEM_ADDR_WIDTH-1:0] pw_base_addr_c, //result input pwo_mem_addr_t pwo_mem_base_addr, + input wire [5:0] random, //4+2 bits output logic bf_enable, output logic buf_wren, output logic buf_rden, + output logic [1:0] buf_wrptr, + output logic [1:0] buf_rdptr, output logic [6:0] twiddle_addr, output logic [MEM_ADDR_WIDTH-1:0] mem_rd_addr, @@ -100,6 +105,16 @@ logic buf_wr_rst_count_ntt, buf_rd_rst_count_ntt; logic buf_wr_rst_count_intt, buf_rd_rst_count_intt; logic [1:0] buf_count; +//Shuffle buffer signals +logic [3:0] chunk_rand_offset; +logic [3:0] chunk_count; //goes from 1 to 16 +logic [1:0] index_rand_offset; +logic [BF_LATENCY-1:0][1:0] index_rand_offset_reg, buf_rdptr_reg; +logic [BF_LATENCY-1:0][3:0] chunk_count_reg; +logic latch_chunk_rand_offset, latch_index_rand_offset; +logic last_rd_addr, last_wr_addr; +logic mem_wr_en_fsm, mem_wr_en_reg; + //Mode flags logic ct_mode, gs_mode, pwo_mode; //point-wise operations mode logic pwm_mode, pwa_mode, pws_mode; @@ -248,8 +263,9 @@ end always_comb begin mem_rd_base_addr = (rounds_count == 'h0) ? src_base_addr : rounds_count[0] ? interim_base_addr : dest_base_addr; mem_wr_base_addr = rounds_count[0] ? dest_base_addr : interim_base_addr; - mem_rd_addr_nxt = mem_rd_addr + rd_addr_step; - mem_wr_addr_nxt = mem_wr_addr + wr_addr_step; + mem_rd_addr_nxt = mem_rd_addr + rd_addr_step; //TODO gs, pwo modes + // mem_wr_addr_nxt = mem_wr_addr + wr_addr_step; + mem_wr_addr_nxt = ct_mode ? (4*(chunk_count_reg[0])) + (wr_addr_step*buf_rdptr_reg[0]) + mem_wr_base_addr : mem_wr_addr + wr_addr_step; //TODO: gs, pwo modes rd_addr_wraparound = mem_rd_addr_nxt > {1'b0,mem_rd_base_addr} + MEM_LAST_ADDR; wr_addr_wraparound = mem_wr_addr_nxt > {1'b0,mem_wr_base_addr} + MEM_LAST_ADDR; end @@ -263,10 +279,10 @@ always_ff @(posedge clk or negedge reset_n) begin mem_rd_addr <= 'h0; end else if (rst_rd_addr) begin - mem_rd_addr <= mem_rd_base_addr; + mem_rd_addr <= ct_mode ? mem_rd_base_addr + chunk_rand_offset : mem_rd_base_addr; //TODO: gs, pwo end else if (incr_mem_rd_addr) begin - mem_rd_addr <= rd_addr_wraparound ? MEM_ADDR_WIDTH'(mem_rd_addr_nxt - MEM_LAST_ADDR) : mem_rd_addr_nxt[MEM_ADDR_WIDTH-1:0]; + mem_rd_addr <= last_rd_addr ? mem_rd_base_addr : rd_addr_wraparound ? MEM_ADDR_WIDTH'(mem_rd_addr_nxt - MEM_LAST_ADDR) : mem_rd_addr_nxt[MEM_ADDR_WIDTH-1:0]; end end @@ -279,7 +295,7 @@ always_ff @(posedge clk or negedge reset_n) begin mem_wr_addr <= 'h0; end else if (rst_wr_addr) begin - mem_wr_addr <= mem_wr_base_addr; + mem_wr_addr <= ct_mode ? mem_wr_base_addr + (4*chunk_rand_offset) : mem_wr_base_addr; //TODO: gs, pwo end else if (incr_mem_wr_addr) begin mem_wr_addr <= wr_addr_wraparound ? MEM_ADDR_WIDTH'(mem_wr_addr_nxt - MEM_LAST_ADDR) : mem_wr_addr_nxt[MEM_ADDR_WIDTH-1:0]; @@ -323,27 +339,33 @@ end //------------------------------------------ //Twiddle addr logic //------------------------------------------ +logic [6:0] twiddle_rand_offset; always_comb begin unique case(rounds_count) 'h0: begin twiddle_end_addr = ct_mode ? 'd0 : 'd63; twiddle_offset = 'h0; + twiddle_rand_offset = 'h0; end 'h1: begin twiddle_end_addr = ct_mode ? 'd3 : 'd15; twiddle_offset = ct_mode ? 'd1 : 'd64; + twiddle_rand_offset = ct_mode ? /*(chunk_count_reg[BF_LATENCY-1] % 'd4)*/ 'd0 : 'h0; //TODO gs mode end 'h2: begin twiddle_end_addr = ct_mode ? 'd15 : 'd3; twiddle_offset = ct_mode ? 'd5 : 'd80; + twiddle_rand_offset = ct_mode ? (chunk_count_reg[BF_LATENCY-1] % 'd4)*'d4 : 'h0; end 'h3: begin twiddle_end_addr = ct_mode ? 'd63 : 'd0; twiddle_offset = ct_mode ? 'd21 : 'd84; + twiddle_rand_offset = ct_mode ? (chunk_count_reg[BF_LATENCY-1] % 'd16)*4 : 'h0; end default: begin twiddle_end_addr = 'h0; twiddle_offset = 'h0; + twiddle_rand_offset = 'h0; end endcase end @@ -367,7 +389,7 @@ always_ff @(posedge clk or negedge reset_n) begin else if (incr_twiddle_addr) twiddle_addr_reg <= (twiddle_addr_reg == twiddle_end_addr) ? 'h0 : twiddle_addr_reg + 'd1; else if (rst_twiddle_addr) - twiddle_addr_reg <= 'h0; + twiddle_addr_reg <= ct_mode ? twiddle_rand_offset : 'h0; //TODO: gs mode end assign twiddle_addr = twiddle_addr_reg + twiddle_offset; @@ -428,6 +450,90 @@ always_ff @(posedge clk or negedge reset_n) begin buf_count <= 'h0; end +//------------------------------------------ +//Shuffle buffer +//------------------------------------------ +always_ff @(posedge clk or negedge reset_n) begin + if (!reset_n) begin + chunk_rand_offset <= 'h0; + chunk_count <= 'h0; + end + else if (zeroize) begin + chunk_rand_offset <= 'h0; + chunk_count <= 'h0; + end + else if (latch_chunk_rand_offset) begin + chunk_rand_offset <= random[5:2]; + chunk_count <= {1'b0, random[5:2]}; + end + else if (buf_count == 'h3) begin //update chunk after every 4 cycles + chunk_count <= (chunk_count == 'hf) ? 'h0 : chunk_count + 'h1; + end +end + +always_ff @(posedge clk or negedge reset_n) begin + if (!reset_n) begin + index_rand_offset <= 'h0; + end + else if (zeroize) begin + index_rand_offset <= 'h0; + + end + else if (latch_index_rand_offset) begin + index_rand_offset <= random[1:0]; + + end +end + +always_ff @(posedge clk or negedge reset_n) begin + if (!reset_n) begin + buf_rdptr_reg <= 'h0; + end + else if (zeroize) begin + buf_rdptr_reg <= 'h0; + end + else if (buf_rden | butterfly_ready) begin + buf_rdptr_reg <= {buf_rdptr, buf_rdptr_reg[BF_LATENCY-1:1]}; + end +end + +always_ff @(posedge clk or negedge reset_n) begin + if (!reset_n) begin + chunk_count_reg <= 'h0; + end + else if (zeroize) begin + chunk_count_reg <= 'h0; + end + else if (buf_rden | butterfly_ready) begin + chunk_count_reg <= {chunk_count, chunk_count_reg[BF_LATENCY-1:1]}; + end +end + +always_ff @(posedge clk or negedge reset_n) begin + if (!reset_n) begin + buf_wrptr <= 'h0; + end + else if (zeroize) begin + buf_wrptr <= 'h0; + end + else if (buf_wren & ct_mode) begin //ct mode - buf writes are in order + buf_wrptr <= (buf_wrptr == 'h3) ? 'h0 : buf_wrptr + 'h1; + end + else if (buf_wren & gs_mode) begin // gs mode - TODO: shuffling + buf_wrptr <= (buf_wrptr == 'h3) ? 'h0 : buf_wrptr + 'h1; + end + //TODO: gs_mode +end + +always_comb begin + last_rd_addr = ct_mode & (mem_rd_addr == mem_rd_base_addr + MEM_LAST_ADDR); //TODO: other modes + buf_rdptr = ct_mode ? index_rand_offset + buf_count : buf_count; //TODO: flop + // buf_wrptr = gs_mode ? index_rand_offset + buf_count : buf_count; + latch_chunk_rand_offset = arc_IDLE_RD_STAGE; + latch_index_rand_offset = buf0_valid; +end + + //------------------------------------------ //NTT/INTT Read FSM //------------------------------------------ @@ -607,13 +713,13 @@ always_comb begin //Move to WR_WAIT state when the last outputs from bf2x2 have been captured in the buffers. They still need to be shifted out of the buffers and into memory, so keep buf_wren 1 here //Assumption - no bubbles in NTT or INTT. If bubbles, need to consider sampler_valid //TODO: can WR_WAIT state be removed? fsm can finish all 64 addr in WR_MEM state? - arc_WR_MEM_WR_WAIT = (write_fsm_state_ps == WR_MEM) && ((gs_mode && (buf0_valid && (wr_valid_count == 'h3c))) || (pwo_mode && !butterfly_ready && (wr_valid_count < 'h3f))); + arc_WR_MEM_WR_WAIT = (write_fsm_state_ps == WR_MEM) && ((gs_mode && (buf0_valid && (wr_valid_count == 'h3c))) || (pwo_mode && !butterfly_ready && (wr_valid_count < 'h3f))); // || (ct_mode && (wr_valid_count == 'h3f))); //This arc is only for pwo mode. Move back from wait to write state when there's a valid BFU output arc_WR_WAIT_WR_MEM = (write_fsm_state_ps == WR_WAIT) && (pwo_mode && butterfly_ready); //When valid_count is 64 and buf_count is 3 (meaning all 4 buffers have been used), move to WR_STAGE indicating that round is done - arc_WR_WAIT_WR_STAGE = (write_fsm_state_ps == WR_WAIT) && (!pwo_mode && (buf_count == 'h3)); + arc_WR_WAIT_WR_STAGE = (write_fsm_state_ps == WR_WAIT) && (((!pwo_mode && gs_mode) && (buf_count == 'h3)) || ct_mode); end always_comb begin @@ -621,7 +727,7 @@ always_comb begin buf_wren_intt = 1'b0; buf_rden_intt = 1'b0; incr_mem_wr_addr = 1'b0; - mem_wr_en = 1'b0; + mem_wr_en_fsm = 1'b0; wr_addr_step = 'h0; rst_wr_addr = 1'b0; rst_wr_valid_count = 1'b0; @@ -653,7 +759,7 @@ always_comb begin buf_wren_intt = butterfly_ready; buf_rden_intt = buf0_valid; incr_mem_wr_addr = buf0_valid; - mem_wr_en = buf0_valid; + mem_wr_en_fsm = buf0_valid; wr_addr_step = INTT_WRITE_ADDR_STEP; end WR_MEM: begin @@ -663,18 +769,18 @@ always_comb begin buf_wren_intt = gs_mode ; buf_rden_intt = gs_mode ; incr_mem_wr_addr = ct_mode ? butterfly_ready : gs_mode ? 1'b1 : 1'b0; - mem_wr_en = ct_mode ? butterfly_ready : gs_mode ? 1'b1 : 1'b0; + mem_wr_en_fsm = ct_mode ? butterfly_ready : gs_mode ? 1'b1 : 1'b0; wr_addr_step = ct_mode ? NTT_WRITE_ADDR_STEP : INTT_WRITE_ADDR_STEP; incr_pw_wr_addr = pwo_mode & butterfly_ready; pw_wren = pwo_mode & butterfly_ready; end WR_WAIT: begin write_fsm_state_ns = arc_WR_WAIT_WR_STAGE ? WR_STAGE : arc_WR_WAIT_WR_MEM ? WR_MEM : WR_WAIT; - buf_wren_intt = (buf_count <= 'h3); //1'b0; - buf_rden_intt = 1'b1; + buf_wren_intt = gs_mode & (buf_count <= 'h3); //1'b0; + buf_rden_intt = gs_mode; incr_mem_wr_addr = (ct_mode | gs_mode); //1'b1; - mem_wr_en = (ct_mode | gs_mode); //1'b1; - wr_addr_step = INTT_WRITE_ADDR_STEP; + mem_wr_en_fsm = gs_mode; //1'b1; + wr_addr_step = gs_mode ? INTT_WRITE_ADDR_STEP : NTT_WRITE_ADDR_STEP; incr_pw_wr_addr = arc_WR_WAIT_WR_MEM; pw_wren = arc_WR_WAIT_WR_MEM; end @@ -691,19 +797,23 @@ assign buf_rden = pwo_mode ? 1'b0 : buf_rden_ntt | buf_rden_intt; assign bf_enable = (gs_mode || pwo_mode) ? bf_enable_reg : bf_enable_fsm; //In gs mode, memory is directly feeding bf2x2, so we need to enable it one cycle later assign buf_wr_rst_count = pwo_mode ? 1'b1 : buf_wr_rst_count_ntt | buf_wr_rst_count_intt; assign buf_rd_rst_count = pwo_mode ? 1'b1 : buf_rd_rst_count_ntt | buf_rd_rst_count_intt; +assign mem_wr_en = gs_mode ? mem_wr_en_fsm : mem_wr_en_reg; //TODO pwo mode, GS mode + shuffling always_ff @(posedge clk or negedge reset_n) begin if (!reset_n) begin buf_wren_ntt_reg <= 'b0; bf_enable_reg <= 'b0; + mem_wr_en_reg <= 'b0; end else if (zeroize) begin buf_wren_ntt_reg <= 'b0; bf_enable_reg <= 'b0; + mem_wr_en_reg <= 'b0; end else begin buf_wren_ntt_reg <= buf_wren_ntt; bf_enable_reg <= bf_enable_fsm; + mem_wr_en_reg <= mem_wr_en_fsm; end end diff --git a/src/ntt_top/rtl/ntt_shuffle_buffer.sv b/src/ntt_top/rtl/ntt_shuffle_buffer.sv new file mode 100644 index 0000000..0a3eeb6 --- /dev/null +++ b/src/ntt_top/rtl/ntt_shuffle_buffer.sv @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//====================================================================== +// +// ntt_shuffle_buffer.sv +// -------- + +// This buffer temporarily holds data from NTT memory OR NTT BF and collects +// 4 coeffs for each address being read/written to. This buffer will be confisgured based on BF +// mode at ntt_top level. +// NTT mode --> buffer contains data read from memory before it's consumed by BF2x2 +// INTT mode --> buffer contains data from BF2x2 to be written to memory + +// This buffer is customized to support shuffling countermeasure for NTT. It's an +// addressable buffer with the following attributes: +// NTT mode --> writes to buffer are in order. Starting read addr is randomized +// INTT mode --> Starting write addr is randomized. Reads from buffer are in order + +module ntt_shuffle_buffer + import mldsa_params_pkg::*; + import ntt_defines_pkg::*; + #( + parameter REG_SIZE = 24 + ) + ( + input wire clk, + input wire reset_n, + input wire zeroize, + input wire wren, + input wire rden, + input wire [1:0] wrptr, + input wire [1:0] rdptr, + input wire wr_rst_count, + // input wire rd_rst_count, + input wire [(4*REG_SIZE)-1:0] data_i, + output logic buf_valid, + output logic [(4*REG_SIZE)-1:0] data_o + ); + + //buffer*[0] is lo, buffer*[1] is hi + logic [1:0][3:0][3:0][REG_SIZE-1:0] buffer; //2x4x4 buffer [lo_hi][] + + logic [1:0] data_i_count, data_i_count_reg; + logic lo_hi, lo_hi_reg; //0 - lo, 1 - hi + + //Write + always_ff @(posedge clk or negedge reset_n) begin + if (!reset_n) begin + buffer <= 'h0; + end + else if (zeroize) begin + buffer <= 'h0; + end + else if (wren) begin + buffer[lo_hi][wrptr] <= data_i; //4 coeff into lo/hi buf0/1/2/3 based on wrptr + end + end + + //Buffer valid + always_ff @(posedge clk or negedge reset_n) begin + if (!reset_n) begin + data_i_count <= 'h0; + data_i_count_reg <= 'h0; + end + else if (zeroize) begin + data_i_count <= 'h0; + data_i_count_reg <= 'h0; + end + else if (wr_rst_count) begin + data_i_count <= 'h0; + data_i_count_reg <= 'h0; + end + else if (wren) begin + data_i_count <= data_i_count + 'h1; + data_i_count_reg <= data_i_count; + end + end + + always_comb begin + buf_valid = (data_i_count_reg == 'd3); + lo_hi = buf_valid ^ lo_hi_reg; + end + + //lo hi + always_ff @(posedge clk or negedge reset_n) begin + if (!reset_n) + lo_hi_reg <= 'b0; + else if (zeroize) + lo_hi_reg <= 'b0; + else if (buf_valid) + lo_hi_reg <= ~lo_hi_reg; + end + + //assign output + always_comb data_o = rden ? {buffer[~lo_hi][3][rdptr], buffer[~lo_hi][2][rdptr], buffer[~lo_hi][1][rdptr], buffer[~lo_hi][0][rdptr]} : 'h0; //lo_hi points to buffer currently being written. So, read from the other section that's full + + + +endmodule \ No newline at end of file diff --git a/src/ntt_top/rtl/ntt_top.sv b/src/ntt_top/rtl/ntt_top.sv index 6604e73..6343d08 100644 --- a/src/ntt_top/rtl/ntt_top.sv +++ b/src/ntt_top/rtl/ntt_top.sv @@ -73,6 +73,8 @@ module ntt_top //Sampler IF input wire sampler_valid, + input wire [5:0] random, + //Memory if //Reuse between pwm c, ntt output mem_if_t mem_wr_req, @@ -96,6 +98,7 @@ module ntt_top logic mem_wren, mem_wren_reg, mem_wren_mux; logic [MLDSA_MEM_ADDR_WIDTH-1:0] mem_wr_addr, mem_wr_addr_reg, mem_wr_addr_mux; // logic [(4*REG_SIZE)-1:0] mem_wr_data; + logic [MEM_DATA_WIDTH-1:0] mem_wr_data_int, mem_wr_data_reg; //Read IF logic mem_rden; @@ -111,7 +114,7 @@ module ntt_top logic buf0_valid; //Internal - logic [6:0] twiddle_addr; + logic [6:0] twiddle_addr, twiddle_addr_reg; logic buf_wren; logic buf_rden; logic buf_wr_rst_count, buf_rd_rst_count; @@ -119,6 +122,7 @@ module ntt_top //buffer IF logic [(4*REG_SIZE)-1:0] buf_data_i, buf_data_o; logic [(3*NTT_REG_SIZE)-1:0] twiddle_factor, twiddle_factor_reg; + logic [1:0] buf_wrptr, buf_rdptr; //PWM mem IF pwo_uvwi_t pw_uvw_i; @@ -159,8 +163,9 @@ module ntt_top assign mem_wr_req.rd_wr_en = !pwo_mode ? (mem_wren_mux ? RW_WRITE : RW_IDLE) //TODO convert mem_wren_mux to rw enum : (pw_wren_reg ? RW_WRITE : RW_IDLE); assign mem_wr_req.addr = !pwo_mode ? mem_wr_addr_mux : pwm_wr_addr_c_reg; - assign mem_wr_data = !pwo_mode ? (ct_mode ? {1'b0, uv_o_reg.v21_o, 1'b0, uv_o_reg.u21_o, 1'b0, uv_o_reg.v20_o, 1'b0, uv_o_reg.u20_o} : buf_data_o) + assign mem_wr_data_int = !pwo_mode ? (ct_mode ? {1'b0, uv_o_reg.v21_o, 1'b0, uv_o_reg.u21_o, 1'b0, uv_o_reg.v20_o, 1'b0, uv_o_reg.u20_o} : buf_data_o) : pwm_wr_data_reg; + assign mem_wr_data = mem_wr_data_int; //ct_mode ? mem_wr_data_reg : mem_wr_data_int; //TODO: gs, pwo modes //mem rd - NTT/INTT mode, read ntt data. PWM mode, read accumulate data from c mem. PWA/S mode, unused assign mem_rd_req.rd_wr_en = (ct_mode || gs_mode) ? (mem_rden ? RW_READ : RW_IDLE) : pwm_mode ? (pw_rden_dest_mem ? RW_READ : RW_IDLE) : RW_IDLE; @@ -190,6 +195,7 @@ module ntt_top .butterfly_ready(bf_ready), .buf0_valid(buf0_valid), .sampler_valid(sampler_valid), + .random(random), .ntt_mem_base_addr(ntt_mem_base_addr), .pwo_mem_base_addr(pwo_mem_base_addr), @@ -198,6 +204,8 @@ module ntt_top .bf_enable(bf_enable), .buf_wren(buf_wren), .buf_rden(buf_rden), + .buf_wrptr(buf_wrptr), + .buf_rdptr(buf_rdptr), .twiddle_addr(twiddle_addr), .mem_rd_addr(mem_rd_addr), @@ -230,6 +238,7 @@ module ntt_top always_comb begin unique case(mode) ct: begin + //with shuffling, twiddle factor needs to be delayed uvw_i.w00_i = twiddle_factor[NTT_REG_SIZE-1:0]; uvw_i.w01_i = twiddle_factor[NTT_REG_SIZE-1:0]; uvw_i.w10_i = twiddle_factor[(2*NTT_REG_SIZE)-1:NTT_REG_SIZE]; @@ -274,6 +283,7 @@ module ntt_top if (!reset_n) begin mem_rd_data_reg <= 'h0; bf_enable_reg <= 'b0; + twiddle_addr_reg <= 'h0; twiddle_factor_reg <= 'h0; uv_o_reg <= 'h0; @@ -289,11 +299,13 @@ module ntt_top pwm_wr_addr_c_reg <= 'h0; pw_wren_reg <= 'b0; + mem_wr_data_reg <= 'h0; end else if (zeroize) begin mem_rd_data_reg <= 'h0; bf_enable_reg <= 'b0; + twiddle_addr_reg <= 'h0; twiddle_factor_reg <= 'h0; uv_o_reg <= 'h0; @@ -308,10 +320,12 @@ module ntt_top pwm_wr_addr_c_reg <= 'h0; pw_wren_reg <= 'b0; + mem_wr_data_reg <= 'h0; end else begin mem_rd_data_reg <= mem_rd_data; bf_enable_reg <= bf_enable; + twiddle_addr_reg <= twiddle_addr; twiddle_factor_reg <= twiddle_factor; uv_o_reg <= uv_o; @@ -327,6 +341,7 @@ module ntt_top pwm_wr_data_reg <= {1'b0, pwo_uv_o.uv3, 1'b0, pwo_uv_o.uv2, 1'b0, pwo_uv_o.uv1, 1'b0, pwo_uv_o.uv0}; pw_wren_reg <= pw_wren; + mem_wr_data_reg <= mem_wr_data_int; end end @@ -417,9 +432,10 @@ module ntt_top endcase end assign bf_enable_mux = ct_mode ? bf_enable : bf_enable_reg; - assign mem_wren_mux = ct_mode ? mem_wren_reg : mem_wren; - assign mem_wr_addr_mux = ct_mode ? mem_wr_addr_reg : mem_wr_addr; + assign mem_wren_mux = mem_wren; //ct_mode ? mem_wren_reg : mem_wren; + assign mem_wr_addr_mux = mem_wr_addr; //ct_mode ? mem_wr_addr_reg : mem_wr_addr; + /* ntt_buffer #( .REG_SIZE(REG_SIZE) ) buffer_inst0 ( @@ -435,5 +451,22 @@ module ntt_top .buf0_valid(buf0_valid), .data_o(buf_data_o) ); + */ + + ntt_shuffle_buffer #( + .REG_SIZE(REG_SIZE) + ) buffer_inst0 ( + .clk(clk), + .reset_n(reset_n), + .zeroize(zeroize), + .wren(buf_wren), + .rden(buf_rden), + .wrptr(buf_wrptr), + .rdptr(buf_rdptr), + .wr_rst_count(buf_wr_rst_count), + .data_i(buf_data_i), + .buf_valid(buf0_valid), + .data_o(buf_data_o) + ); endmodule diff --git a/src/ntt_top/tb/ntt_top_tb.sv b/src/ntt_top/tb/ntt_top_tb.sv index 4f3f95f..14630f0 100644 --- a/src/ntt_top/tb/ntt_top_tb.sv +++ b/src/ntt_top/tb/ntt_top_tb.sv @@ -74,6 +74,10 @@ pwo_mem_addr_t pwo_mem_base_addr_tb; string operation; +logic wren_tb, rden_tb; +logic [1:0] wrptr_tb, rdptr_tb; +logic [5:0] random_tb; + //---------------------------------------------------------------- // Device Under Test. //---------------------------------------------------------------- @@ -131,6 +135,20 @@ string operation; // .sampler_valid(svalid_tb) // ); +// ntt_shuffle_buffer dut ( +// .clk(clk_tb), +// .reset_n(reset_n_tb), +// .zeroize(zeroize_tb), +// .wren(wren_tb), +// .rden(rden_tb), +// .wrptr(wrptr_tb), +// .rdptr(rdptr_tb), +// .wr_rst_count(), +// .data_i(data_i_tb), +// .buf_valid(), +// .data_o() +// ); + ntt_wrapper dut ( .clk(clk_tb), .reset_n(reset_n_tb), @@ -139,6 +157,7 @@ ntt_wrapper dut ( .ntt_enable(enable_tb), .load_tb_values(load_tb_values), .load_tb_addr(load_tb_addr), + .random(random_tb), // .src_base_addr(src_base_addr), // .interim_base_addr(interim_base_addr), // .dest_base_addr(dest_base_addr), @@ -220,6 +239,8 @@ task init_sim; data_i_tb = 'h0; zeroize_tb = 'b0; enable_tb = 'b0; + wren_tb = 'b0; rden_tb = 'b0; + wrptr_tb = 'h0; rdptr_tb = 'h0; mode_tb = ct; addr0 = 'h0; addr1 = 'h0; addr2 = 'h0; addr3 = 'h0; @@ -237,6 +258,7 @@ task init_sim; acc_tb = 1'b0; svalid_tb = 1'b0; sampler_mode_tb = 1'b0; + random_tb = 'h0; $display("End of init\n"); end @@ -244,13 +266,19 @@ endtask task buffer_test(); reg [REG_SIZE-1:0] i; - enable_tb <= 1'b1; + reg [1:0] j; + wren_tb <= 1'b1; + rden_tb <= 'b1; for (i = 0; i < 64; i++) begin // data_i_tb <= {(i*23'd64)+23'd192, (i*23'd64)+23'd128, (i*23'd64)+23'd64, (i*23'd64)}; + wrptr_tb <= i%4; + j = $urandom_range(3); + rdptr_tb <= (i%4)+j; data_i_tb <= {(i*23'd64)+23'd3, (i*23'd64)+23'd2, (i*23'd64)+23'd1, i*23'd64}; //{23'd3, 23'd2, 23'd1, 23'd0}; @(posedge clk_tb); end - enable_tb <= 1'b0; + wren_tb <= 1'b0; + rden_tb <= 'b0; endtask task twiddle_rom_test(); @@ -307,8 +335,21 @@ task ntt_top_test(); ntt_mem_base_addr_tb.dest_base_addr = 8'd128; acc_tb = 1'b0; svalid_tb = 1'b1; + random_tb = {4'h5, 2'h0}; @(posedge clk_tb); enable_tb = 1'b0; + while(dut.ntt_top_inst0.ntt_ctrl_inst0.rounds_count == 'h0) + @(posedge clk_tb); + random_tb = {4'h9, 2'h0}; + + while(dut.ntt_top_inst0.ntt_ctrl_inst0.rounds_count == 'h1) + @(posedge clk_tb); + random_tb = {4'h0, 2'h0}; + + while(dut.ntt_top_inst0.ntt_ctrl_inst0.rounds_count == 'h2) + @(posedge clk_tb); + random_tb = {4'hf, 2'h0}; + $display("Waiting for ntt_done\n"); while(ntt_done_tb == 1'b0) @(posedge clk_tb); @@ -335,6 +376,7 @@ task ntt_top_test(); @(posedge clk_tb); $display("Received intt_done\n"); + /* $display("PWM operation 1\n"); operation = "PWM 1 no acc"; // $readmemh("pwm_iter1.hex", ntt_mem_tb); @@ -485,6 +527,7 @@ task ntt_top_test(); while(ntt_done_tb == 1'b0) @(posedge clk_tb); $display("Received pwo_done\n"); + */ $display("End of test\n"); endtask diff --git a/src/ntt_top/tb/ntt_wrapper.sv b/src/ntt_top/tb/ntt_wrapper.sv index b43ed99..0e15f61 100644 --- a/src/ntt_top/tb/ntt_wrapper.sv +++ b/src/ntt_top/tb/ntt_wrapper.sv @@ -27,7 +27,7 @@ module ntt_wrapper parameter RADIX = 23, parameter MLDSA_Q = 23'd8380417, parameter MLDSA_N = 256, - parameter MEM_ADDR_WIDTH = 15, + parameter MEM_ADDR_WIDTH = 14, parameter MEM_DATA_WIDTH = 96 ) ( @@ -37,6 +37,7 @@ module ntt_wrapper input mode_t mode, input wire ntt_enable, + input wire [5:0] random, //TB purpose - remove later TODO input wire load_tb_values, @@ -178,6 +179,7 @@ module ntt_wrapper .pwo_mem_base_addr(pwo_mem_base_addr), .accumulate(accumulate), .sampler_valid(sampler_valid), + .random(random), //NTT mem IF .mem_wr_req(mem_wr_req), .mem_rd_req(mem_rd_req), From dd428edd409fdee248c5455cb35db2147ee8c7b6 Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Thu, 19 Sep 2024 16:14:15 -0700 Subject: [PATCH 02/17] NTT shuffling updates --- src/ntt_top/rtl/ntt_ctrl.sv | 51 ++++++++++++++-------- src/ntt_top/rtl/ntt_shuffle_buffer.sv | 6 ++- src/ntt_top/rtl/ntt_top.sv | 1 + src/ntt_top/tb/ntt_top_tb.sv | 43 ++++++++++++------ src/ntt_top/utb/interfaces/mem_if.sv | 12 ++--- src/ntt_top/utb/mem_agent/mem_txn.sv | 10 ++--- src/ntt_top/utb/ntt_agent/ntt_txn.sv | 4 +- src/ntt_top/utb/ntt_utb_top/ntt_utb_top.sv | 16 ++++--- src/ntt_top/utb/scoreboard/ntt_sb.sv | 22 +++++----- 9 files changed, 104 insertions(+), 61 deletions(-) diff --git a/src/ntt_top/rtl/ntt_ctrl.sv b/src/ntt_top/rtl/ntt_ctrl.sv index 14a5be5..d58de71 100644 --- a/src/ntt_top/rtl/ntt_ctrl.sv +++ b/src/ntt_top/rtl/ntt_ctrl.sv @@ -32,7 +32,7 @@ module ntt_ctrl parameter MLDSA_N = 256, parameter MLDSA_LOGN = 8, parameter MEM_ADDR_WIDTH = 15, - parameter BF_LATENCY = 10, //5 cycles per butterfly * 2 instances in serial = 10 clks + parameter BF_LATENCY = 11, //5 cycles per butterfly * 2 instances in serial = 10 clks parameter NTT_BUF_LATENCY = 4 ) ( @@ -109,7 +109,8 @@ logic [1:0] buf_count; logic [3:0] chunk_rand_offset; logic [3:0] chunk_count; //goes from 1 to 16 logic [1:0] index_rand_offset; -logic [BF_LATENCY-1:0][1:0] index_rand_offset_reg, buf_rdptr_reg; +logic [1:0] buf_rdptr_int; +logic [BF_LATENCY-1:0][1:0] buf_rdptr_reg; logic [BF_LATENCY-1:0][3:0] chunk_count_reg; logic latch_chunk_rand_offset, latch_index_rand_offset; logic last_rd_addr, last_wr_addr; @@ -184,7 +185,7 @@ logic arc_EXEC_WAIT_RD_EXEC; //Other signals logic buf_wren_ntt, buf_wren_ntt_reg; logic buf_wren_intt; -logic buf_rden_ntt; +logic buf_rden_ntt, buf_rden_ntt_reg; logic buf_rden_intt; //Write FSM @@ -345,22 +346,22 @@ always_comb begin 'h0: begin twiddle_end_addr = ct_mode ? 'd0 : 'd63; twiddle_offset = 'h0; - twiddle_rand_offset = 'h0; + twiddle_rand_offset = 'h0; //gs mode: (chunk_rand_offset)*4 + index_rand_offset end 'h1: begin twiddle_end_addr = ct_mode ? 'd3 : 'd15; twiddle_offset = ct_mode ? 'd1 : 'd64; - twiddle_rand_offset = ct_mode ? /*(chunk_count_reg[BF_LATENCY-1] % 'd4)*/ 'd0 : 'h0; //TODO gs mode + twiddle_rand_offset = ct_mode ? /*(chunk_count_reg[BF_LATENCY-1] % 'd4)*/ /*index_rand_offset*/buf_rdptr_int : 'h0; //gs mode: (chunk_rand_offset % 4)*4 + index_rand_offset end 'h2: begin twiddle_end_addr = ct_mode ? 'd15 : 'd3; twiddle_offset = ct_mode ? 'd5 : 'd80; - twiddle_rand_offset = ct_mode ? (chunk_count_reg[BF_LATENCY-1] % 'd4)*'d4 : 'h0; + twiddle_rand_offset = ct_mode ? (chunk_count % 'd4)*'d4 + /*index_rand_offset*/buf_rdptr_int : 'h0; //gs mode: index_rand_offset end 'h3: begin twiddle_end_addr = ct_mode ? 'd63 : 'd0; twiddle_offset = ct_mode ? 'd21 : 'd84; - twiddle_rand_offset = ct_mode ? (chunk_count_reg[BF_LATENCY-1] % 'd16)*4 : 'h0; + twiddle_rand_offset = ct_mode ? (chunk_count % 'd16)*4 + /*index_rand_offset*/buf_rdptr_int : 'h0; //gs mode: 0 end default: begin twiddle_end_addr = 'h0; @@ -387,9 +388,9 @@ always_ff @(posedge clk or negedge reset_n) begin else if (zeroize) twiddle_addr_reg <= 'h0; else if (incr_twiddle_addr) - twiddle_addr_reg <= (twiddle_addr_reg == twiddle_end_addr) ? 'h0 : twiddle_addr_reg + 'd1; + twiddle_addr_reg <= /*buf0_valid ? twiddle_rand_offset :*/ gs_mode & (twiddle_addr_reg == twiddle_end_addr) ? 'h0 : ct_mode ? twiddle_rand_offset : twiddle_addr_reg + 'd1; else if (rst_twiddle_addr) - twiddle_addr_reg <= ct_mode ? twiddle_rand_offset : 'h0; //TODO: gs mode + twiddle_addr_reg <= /*ct_mode ? twiddle_rand_offset :*/ 'h0; //TODO: gs mode end assign twiddle_addr = twiddle_addr_reg + twiddle_offset; @@ -492,8 +493,20 @@ always_ff @(posedge clk or negedge reset_n) begin else if (zeroize) begin buf_rdptr_reg <= 'h0; end - else if (buf_rden | butterfly_ready) begin - buf_rdptr_reg <= {buf_rdptr, buf_rdptr_reg[BF_LATENCY-1:1]}; + else if (buf_rden_ntt | butterfly_ready) begin //TODO gs + buf_rdptr_reg <= {buf_rdptr_int, buf_rdptr_reg[BF_LATENCY-1:1]}; + end +end +logic [1:0] buf_rdptr_f; +always_ff @(posedge clk or negedge reset_n) begin + if (!reset_n) begin + buf_rdptr_f <= 'h0; + end + else if (zeroize) begin + buf_rdptr_f <= 'h0; + end + else begin + buf_rdptr_f <= buf_rdptr_int; end end @@ -504,7 +517,7 @@ always_ff @(posedge clk or negedge reset_n) begin else if (zeroize) begin chunk_count_reg <= 'h0; end - else if (buf_rden | butterfly_ready) begin + else if (buf_rden_ntt | butterfly_ready) begin chunk_count_reg <= {chunk_count, chunk_count_reg[BF_LATENCY-1:1]}; end end @@ -527,10 +540,11 @@ end always_comb begin last_rd_addr = ct_mode & (mem_rd_addr == mem_rd_base_addr + MEM_LAST_ADDR); //TODO: other modes - buf_rdptr = ct_mode ? index_rand_offset + buf_count : buf_count; //TODO: flop + buf_rdptr_int = ct_mode ? index_rand_offset + buf_count : buf_count; //TODO: flop + buf_rdptr = ct_mode ? buf_rdptr_f : buf_count; // buf_wrptr = gs_mode ? index_rand_offset + buf_count : buf_count; - latch_chunk_rand_offset = arc_IDLE_RD_STAGE; - latch_index_rand_offset = buf0_valid; + latch_chunk_rand_offset = arc_IDLE_WR_STAGE | arc_WR_MEM_WR_STAGE | arc_WR_WAIT_WR_STAGE; + latch_index_rand_offset = ct_mode & (buf_wrptr == 'h3); //buf0_valid; //TODO: INTT end @@ -793,8 +807,8 @@ end assign rst_rounds = (read_fsm_state_ps == RD_IDLE) && (write_fsm_state_ps == WR_IDLE); assign incr_rounds = arc_WR_MEM_WR_STAGE | arc_WR_WAIT_WR_STAGE; //TODO: revisit for high-perf mode (if we go with above opt) assign buf_wren = pwo_mode ? 1'b0 : buf_wren_ntt_reg | buf_wren_intt; -assign buf_rden = pwo_mode ? 1'b0 : buf_rden_ntt | buf_rden_intt; -assign bf_enable = (gs_mode || pwo_mode) ? bf_enable_reg : bf_enable_fsm; //In gs mode, memory is directly feeding bf2x2, so we need to enable it one cycle later +assign buf_rden = pwo_mode ? 1'b0 : ct_mode ? buf_rden_ntt_reg : /*buf_rden_ntt |*/ buf_rden_intt; +assign bf_enable = (gs_mode || pwo_mode) ? bf_enable_reg : bf_enable_reg; //bf_enable_fsm; //In gs mode, memory is directly feeding bf2x2, so we need to enable it one cycle later assign buf_wr_rst_count = pwo_mode ? 1'b1 : buf_wr_rst_count_ntt | buf_wr_rst_count_intt; assign buf_rd_rst_count = pwo_mode ? 1'b1 : buf_rd_rst_count_ntt | buf_rd_rst_count_intt; assign mem_wr_en = gs_mode ? mem_wr_en_fsm : mem_wr_en_reg; //TODO pwo mode, GS mode + shuffling @@ -802,16 +816,19 @@ assign mem_wr_en = gs_mode ? mem_wr_en_fsm : mem_wr_en_reg; //TODO pwo mo always_ff @(posedge clk or negedge reset_n) begin if (!reset_n) begin buf_wren_ntt_reg <= 'b0; + buf_rden_ntt_reg <= 'b0; bf_enable_reg <= 'b0; mem_wr_en_reg <= 'b0; end else if (zeroize) begin buf_wren_ntt_reg <= 'b0; + buf_rden_ntt_reg <= 'b0; bf_enable_reg <= 'b0; mem_wr_en_reg <= 'b0; end else begin buf_wren_ntt_reg <= buf_wren_ntt; + buf_rden_ntt_reg <= buf_rden_ntt; bf_enable_reg <= bf_enable_fsm; mem_wr_en_reg <= mem_wr_en_fsm; end diff --git a/src/ntt_top/rtl/ntt_shuffle_buffer.sv b/src/ntt_top/rtl/ntt_shuffle_buffer.sv index 0a3eeb6..cf970c7 100644 --- a/src/ntt_top/rtl/ntt_shuffle_buffer.sv +++ b/src/ntt_top/rtl/ntt_shuffle_buffer.sv @@ -38,6 +38,7 @@ module ntt_shuffle_buffer input wire clk, input wire reset_n, input wire zeroize, + input mode_t mode, input wire wren, input wire rden, input wire [1:0] wrptr, @@ -53,7 +54,7 @@ module ntt_shuffle_buffer logic [1:0][3:0][3:0][REG_SIZE-1:0] buffer; //2x4x4 buffer [lo_hi][] logic [1:0] data_i_count, data_i_count_reg; - logic lo_hi, lo_hi_reg; //0 - lo, 1 - hi + logic lo_hi, lo_hi_reg, lo_hi_rd; //0 - lo, 1 - hi //Write always_ff @(posedge clk or negedge reset_n) begin @@ -91,6 +92,7 @@ module ntt_shuffle_buffer always_comb begin buf_valid = (data_i_count_reg == 'd3); lo_hi = buf_valid ^ lo_hi_reg; + lo_hi_rd = (mode == 0) ? lo_hi_reg : lo_hi; end //lo hi @@ -104,7 +106,7 @@ module ntt_shuffle_buffer end //assign output - always_comb data_o = rden ? {buffer[~lo_hi][3][rdptr], buffer[~lo_hi][2][rdptr], buffer[~lo_hi][1][rdptr], buffer[~lo_hi][0][rdptr]} : 'h0; //lo_hi points to buffer currently being written. So, read from the other section that's full + always_comb data_o = rden ? {buffer[~lo_hi_rd][3][rdptr], buffer[~lo_hi_rd][2][rdptr], buffer[~lo_hi_rd][1][rdptr], buffer[~lo_hi_rd][0][rdptr]} : 'h0; //lo_hi points to buffer currently being written. So, read from the other section that's full diff --git a/src/ntt_top/rtl/ntt_top.sv b/src/ntt_top/rtl/ntt_top.sv index 6343d08..4b7a866 100644 --- a/src/ntt_top/rtl/ntt_top.sv +++ b/src/ntt_top/rtl/ntt_top.sv @@ -459,6 +459,7 @@ module ntt_top .clk(clk), .reset_n(reset_n), .zeroize(zeroize), + .mode(mode), .wren(buf_wren), .rden(buf_rden), .wrptr(buf_wrptr), diff --git a/src/ntt_top/tb/ntt_top_tb.sv b/src/ntt_top/tb/ntt_top_tb.sv index 14630f0..eb26036 100644 --- a/src/ntt_top/tb/ntt_top_tb.sv +++ b/src/ntt_top/tb/ntt_top_tb.sv @@ -326,6 +326,14 @@ task ntt_ctrl_test(); endtask task ntt_top_test(); + fork + begin + while(ntt_done_tb == 1'b0) begin + random_tb = $urandom(); + @(posedge clk_tb); + end + end + begin $display("NTT operation\n"); operation = "NTT"; mode_tb = ct; @@ -335,20 +343,20 @@ task ntt_top_test(); ntt_mem_base_addr_tb.dest_base_addr = 8'd128; acc_tb = 1'b0; svalid_tb = 1'b1; - random_tb = {4'h5, 2'h0}; @(posedge clk_tb); enable_tb = 1'b0; - while(dut.ntt_top_inst0.ntt_ctrl_inst0.rounds_count == 'h0) - @(posedge clk_tb); - random_tb = {4'h9, 2'h0}; - while(dut.ntt_top_inst0.ntt_ctrl_inst0.rounds_count == 'h1) - @(posedge clk_tb); - random_tb = {4'h0, 2'h0}; + // while(dut.ntt_top_inst0.ntt_ctrl_inst0.rounds_count == 'h0) + // @(posedge clk_tb); + // random_tb = {4'h9, 2'h3}; - while(dut.ntt_top_inst0.ntt_ctrl_inst0.rounds_count == 'h2) - @(posedge clk_tb); - random_tb = {4'hf, 2'h0}; + // while(dut.ntt_top_inst0.ntt_ctrl_inst0.rounds_count == 'h1) + // @(posedge clk_tb); + // random_tb = {4'h0, 2'h2}; + + // while(dut.ntt_top_inst0.ntt_ctrl_inst0.rounds_count == 'h2) + // @(posedge clk_tb); + // random_tb = {4'hf, 2'h0}; $display("Waiting for ntt_done\n"); while(ntt_done_tb == 1'b0) @@ -360,7 +368,16 @@ task ntt_top_test(); // $display("Error: NTT data mismatch at index %0d (dest_base addr = %0d). Actual data = %h, expected data = %h", i, dest_base_addr, dut.ntt_mem.mem[i+dest_base_addr], ntt_mem_tb[i]); // @(posedge clk_tb); // end - + end + join + fork + begin + while(ntt_done_tb == 1'b0) begin + random_tb = $urandom(); + @(posedge clk_tb); + end + end + begin $display("INTT operation\n"); operation = "INTT"; mode_tb = gs; @@ -375,7 +392,6 @@ task ntt_top_test(); while(ntt_done_tb == 1'b0) @(posedge clk_tb); $display("Received intt_done\n"); - /* $display("PWM operation 1\n"); operation = "PWM 1 no acc"; @@ -528,7 +544,8 @@ task ntt_top_test(); @(posedge clk_tb); $display("Received pwo_done\n"); */ - + end + join $display("End of test\n"); endtask diff --git a/src/ntt_top/utb/interfaces/mem_if.sv b/src/ntt_top/utb/interfaces/mem_if.sv index 12e77d2..b08949f 100644 --- a/src/ntt_top/utb/interfaces/mem_if.sv +++ b/src/ntt_top/utb/interfaces/mem_if.sv @@ -4,10 +4,10 @@ interface mem_if(input bit clk); logic reset_n; mem_if_t mem_port0_req; mem_if_t mem_port1_req; - logic [MEM_DATA_WIDTH-1:0] p0_read_data; - logic [MEM_DATA_WIDTH-1:0] p0_write_data; - logic [MEM_DATA_WIDTH-1:0] p1_read_data; - logic [MEM_DATA_WIDTH-1:0] p1_write_data; + logic [MLDSA_MEM_DATA_WIDTH-1:0] p0_read_data; + logic [MLDSA_MEM_DATA_WIDTH-1:0] p0_write_data; + logic [MLDSA_MEM_DATA_WIDTH-1:0] p1_read_data; + logic [MLDSA_MEM_DATA_WIDTH-1:0] p1_write_data; logic update_mem; string mem_path; @@ -28,7 +28,7 @@ interface mem_if(input bit clk); // modport mem_m_sync_mp(clocking mem_m_cb); modport mem_s_sync_mp(clocking mem_s_cb); - task update_mem_task(input logic [MLDSA_MEM_ADDR_WIDTH-1:0] addr, input logic [MEM_DATA_WIDTH-1:0] data); + task update_mem_task(input logic [MLDSA_MEM_ADDR_WIDTH-1:0] addr, input logic [MLDSA_MEM_DATA_WIDTH-1:0] data); // Time zero assignment to update memory content using hierarchical reference case (mem_path) "mem_ntt": ntt_utb_top.ntt_mem.mem[addr] = data; @@ -38,7 +38,7 @@ interface mem_if(input bit clk); endcase endtask: update_mem_task - task read_mem(input logic [MLDSA_MEM_ADDR_WIDTH-1:0] addr, output logic [MEM_DATA_WIDTH-1:0] data); + task read_mem(input logic [MLDSA_MEM_ADDR_WIDTH-1:0] addr, output logic [MLDSA_MEM_DATA_WIDTH-1:0] data); // Read the memory content using hierarchical reference case (mem_path) "mem_ntt": data = ntt_utb_top.ntt_mem.mem[addr]; diff --git a/src/ntt_top/utb/mem_agent/mem_txn.sv b/src/ntt_top/utb/mem_agent/mem_txn.sv index 316b636..15209bd 100644 --- a/src/ntt_top/utb/mem_agent/mem_txn.sv +++ b/src/ntt_top/utb/mem_agent/mem_txn.sv @@ -9,12 +9,12 @@ class mem_txn extends uvm_sequence_item; rand bit reset_indicator; rand mem_if_t mem_port0_req; rand mem_if_t mem_port1_req; - rand bit [MEM_DATA_WIDTH-1:0] p0_read_data; - rand bit [MEM_DATA_WIDTH-1:0] p0_write_data; - rand bit [MEM_DATA_WIDTH-1:0] p1_read_data; - rand bit [MEM_DATA_WIDTH-1:0] p1_write_data; + rand bit [MLDSA_MEM_DATA_WIDTH-1:0] p0_read_data; + rand bit [MLDSA_MEM_DATA_WIDTH-1:0] p0_write_data; + rand bit [MLDSA_MEM_DATA_WIDTH-1:0] p1_read_data; + rand bit [MLDSA_MEM_DATA_WIDTH-1:0] p1_write_data; rand bit update_mem; - rand bit [MEM_DATA_WIDTH-1:0] artificialMemory []; + rand bit [MLDSA_MEM_DATA_WIDTH-1:0] artificialMemory []; // Define constants localparam int MLDSA_Q = 23'd8380417; diff --git a/src/ntt_top/utb/ntt_agent/ntt_txn.sv b/src/ntt_top/utb/ntt_agent/ntt_txn.sv index 4cf93aa..5f733b5 100644 --- a/src/ntt_top/utb/ntt_agent/ntt_txn.sv +++ b/src/ntt_top/utb/ntt_agent/ntt_txn.sv @@ -14,11 +14,13 @@ class ntt_txn extends uvm_sequence_item; rand bit accumulate; rand bit sampler_valid; rand bit sampler_mode; - rand bit [MEM_DATA_WIDTH-1:0] sampler_data; + rand bit [MLDSA_MEM_DATA_WIDTH-1:0] sampler_data; rand bit ntt_done; rand bit stage_done; rand int stage_idx; + localparam MEM_DEPTH = MLDSA_MEM_MAX_DEPTH; + constraint ntt_c { ntt_enable == 1; } diff --git a/src/ntt_top/utb/ntt_utb_top/ntt_utb_top.sv b/src/ntt_top/utb/ntt_utb_top/ntt_utb_top.sv index c2e0843..b36b313 100644 --- a/src/ntt_top/utb/ntt_utb_top/ntt_utb_top.sv +++ b/src/ntt_top/utb/ntt_utb_top/ntt_utb_top.sv @@ -45,22 +45,22 @@ module ntt_utb_top //NTT, PWM C memory IF mem_if_t mem_port0_req; mem_if_t mem_rd_req; - logic [MEM_DATA_WIDTH-1:0] mem_wr_data; - logic [MEM_DATA_WIDTH-1:0] mem_rd_data; + logic [MLDSA_MEM_DATA_WIDTH-1:0] mem_wr_data; + logic [MLDSA_MEM_DATA_WIDTH-1:0] mem_rd_data; //PWM A/B, PWA/S memory IF mem_if_t pwm_a_rd_req; mem_if_t pwm_b_rd_req; - logic [MEM_DATA_WIDTH-1:0] pwm_a_rd_data; - logic [MEM_DATA_WIDTH-1:0] pwm_b_rd_data; + logic [MLDSA_MEM_DATA_WIDTH-1:0] pwm_a_rd_data; + logic [MLDSA_MEM_DATA_WIDTH-1:0] pwm_b_rd_data; //NTT/PWM muxes logic ntt_mem_wren, ntt_mem_rden; logic [MLDSA_MEM_ADDR_WIDTH-1:0] ntt_mem_wr_addr; logic [MLDSA_MEM_ADDR_WIDTH-1:0] ntt_mem_rd_addr; - logic [MEM_DATA_WIDTH-1:0] ntt_mem_wr_data; - logic [MEM_DATA_WIDTH-1:0] ntt_mem_rd_data; + logic [MLDSA_MEM_DATA_WIDTH-1:0] ntt_mem_wr_data; + logic [MLDSA_MEM_DATA_WIDTH-1:0] ntt_mem_rd_data; logic pwm_mem_a_rden, pwm_mem_b_rden; @@ -70,6 +70,8 @@ module ntt_utb_top logic pwo_mode; logic pwm_mode, pwa_mode, pws_mode; + logic [5:0] random_utb; + assign ct_mode = (ntt_if_i.mode == ct); assign gs_mode = (ntt_if_i.mode == gs); assign pwo_mode = (ntt_if_i.mode inside {pwm, pwa, pws}); @@ -166,6 +168,7 @@ module ntt_utb_top .accumulate(ntt_if_i.accumulate), .sampler_valid(ntt_if_i.sampler_valid), + .random(random_utb), //NTT mem IF .mem_wr_req(ntt_mem_if_i.mem_port0_req), .mem_rd_req(ntt_mem_if_i.mem_port1_req), @@ -188,6 +191,7 @@ module ntt_utb_top always begin #1 clk = ~clk; + random_utb = {4'h5, $urandom_range(0,3)}; end initial begin diff --git a/src/ntt_top/utb/scoreboard/ntt_sb.sv b/src/ntt_top/utb/scoreboard/ntt_sb.sv index 99deb98..3d85307 100644 --- a/src/ntt_top/utb/scoreboard/ntt_sb.sv +++ b/src/ntt_top/utb/scoreboard/ntt_sb.sv @@ -1,5 +1,5 @@ class ntt_sb extends uvm_scoreboard; - import mldsa_params_pkg::*; + // import mldsa_params_pkg::*; `uvm_component_utils(ntt_sb) uvm_analysis_imp_ntt_txn#(ntt_txn, ntt_sb) ntt_ap; @@ -14,15 +14,15 @@ class ntt_sb extends uvm_scoreboard; localparam MLDSA_N = 256; localparam MLDSA_LOGN = $clog2(MLDSA_N); localparam f= 8347681; // 256^-1 mod MLDSA_Q - + localparam MEM_DEPTH = MLDSA_MEM_MAX_DEPTH; // Memory models for the three memories - bit [MEM_DATA_WIDTH-1:0] ntt_mem_model [0:MEM_DEPTH-1]; - bit [MEM_DATA_WIDTH-1:0] pwm_a_mem_model [0:MEM_DEPTH-1]; - bit [MEM_DATA_WIDTH-1:0] pwm_b_mem_model [0:MEM_DEPTH-1]; + bit [MLDSA_MEM_DATA_WIDTH-1:0] ntt_mem_model [0:MEM_DEPTH-1]; + bit [MLDSA_MEM_DATA_WIDTH-1:0] pwm_a_mem_model [0:MEM_DEPTH-1]; + bit [MLDSA_MEM_DATA_WIDTH-1:0] pwm_b_mem_model [0:MEM_DEPTH-1]; - bit [MEM_DATA_WIDTH-1:0] ntt_model_inputs [0:MEM_DEPTH-1]; - bit [MEM_DATA_WIDTH-1:0] pwm_a_model_inputs [0:MEM_DEPTH-1]; - bit [MEM_DATA_WIDTH-1:0] pwm_b_model_inputs [0:MEM_DEPTH-1]; + bit [MLDSA_MEM_DATA_WIDTH-1:0] ntt_model_inputs [0:MEM_DEPTH-1]; + bit [MLDSA_MEM_DATA_WIDTH-1:0] pwm_a_model_inputs [0:MEM_DEPTH-1]; + bit [MLDSA_MEM_DATA_WIDTH-1:0] pwm_b_model_inputs [0:MEM_DEPTH-1]; bit [REG_SIZE-1:0] One_NTT_input [0:MLDSA_N-1]; bit [REG_SIZE-1:0] model_NTT_output [0:MLDSA_N-1]; @@ -134,8 +134,8 @@ class ntt_sb extends uvm_scoreboard; // Function to update the memory model based on the received transactions - function void update_sb_memory_model(ref bit [MEM_DATA_WIDTH-1:0] mem_from_DUT [0:MEM_DEPTH-1], - ref bit [MEM_DATA_WIDTH-1:0] mem_to_model [0:MEM_DEPTH-1], + function void update_sb_memory_model(ref bit [MLDSA_MEM_DATA_WIDTH-1:0] mem_from_DUT [0:MEM_DEPTH-1], + ref bit [MLDSA_MEM_DATA_WIDTH-1:0] mem_to_model [0:MEM_DEPTH-1], mem_txn mem_txn_i ); if (mem_txn_i.update_mem) begin @@ -156,7 +156,7 @@ class ntt_sb extends uvm_scoreboard; // Function to extract 256 coefficients from the input memory starting at src_base_addr function void extract_256_coeffs( - input bit [MEM_DATA_WIDTH-1:0] input_memory [0:MEM_DEPTH-1], + input bit [MLDSA_MEM_DATA_WIDTH-1:0] input_memory [0:MEM_DEPTH-1], input logic [MLDSA_MEM_ADDR_WIDTH-1:0] base_addr, input int stage_idx, output bit [REG_SIZE-1:0] NTT_coeffs [0:MLDSA_N-1] From a7d5e54d99e27593ae1b49711835142ea670ad4c Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Mon, 23 Sep 2024 12:02:56 -0700 Subject: [PATCH 03/17] WIP INTT shuffling --- src/ntt_top/rtl/ntt_ctrl.sv | 104 ++++++++++++++++++++++++------------ src/ntt_top/rtl/ntt_top.sv | 8 +-- 2 files changed, 75 insertions(+), 37 deletions(-) diff --git a/src/ntt_top/rtl/ntt_ctrl.sv b/src/ntt_top/rtl/ntt_ctrl.sv index d58de71..cbe1a04 100644 --- a/src/ntt_top/rtl/ntt_ctrl.sv +++ b/src/ntt_top/rtl/ntt_ctrl.sv @@ -32,7 +32,7 @@ module ntt_ctrl parameter MLDSA_N = 256, parameter MLDSA_LOGN = 8, parameter MEM_ADDR_WIDTH = 15, - parameter BF_LATENCY = 11, //5 cycles per butterfly * 2 instances in serial = 10 clks + parameter BF_LATENCY = 11, //TODO: change back to 10 and test shuffling //5 cycles per butterfly * 2 instances in serial = 10 clks parameter NTT_BUF_LATENCY = 4 ) ( @@ -93,12 +93,13 @@ localparam PWO_READ_ADDR_STEP = 1; localparam PWO_WRITE_ADDR_STEP = 1; localparam [MEM_ADDR_WIDTH-1:0] MEM_LAST_ADDR = 63; +localparam INTT_WRBUF_LATENCY = 13; //includes BF latency + mem latency for shuffled reads to begin //FSM states ntt_read_state_t read_fsm_state_ps, read_fsm_state_ns; ntt_write_state_t write_fsm_state_ps, write_fsm_state_ns; //BF enable flags -logic bf_enable_fsm, bf_enable_reg; +logic bf_enable_fsm, bf_enable_reg, bf_enable_reg_d2; //Buffer signals logic buf_wr_rst_count_ntt, buf_rd_rst_count_ntt; @@ -107,14 +108,17 @@ logic [1:0] buf_count; //Shuffle buffer signals logic [3:0] chunk_rand_offset; -logic [3:0] chunk_count; //goes from 1 to 16 -logic [1:0] index_rand_offset; +logic [3:0] chunk_count; +logic [1:0] index_rand_offset, index_count, mem_rd_index_ofst; logic [1:0] buf_rdptr_int; +logic [1:0] buf_rdptr_f; logic [BF_LATENCY-1:0][1:0] buf_rdptr_reg; +logic [INTT_WRBUF_LATENCY-1:0][1:0] buf_wrptr_reg; logic [BF_LATENCY-1:0][3:0] chunk_count_reg; logic latch_chunk_rand_offset, latch_index_rand_offset; logic last_rd_addr, last_wr_addr; logic mem_wr_en_fsm, mem_wr_en_reg; +logic mem_rd_en_fsm, mem_rd_en_reg; //Mode flags logic ct_mode, gs_mode, pwo_mode; //point-wise operations mode @@ -139,7 +143,7 @@ logic rst_pw_addr; //Twiddle ROM wires logic incr_twiddle_addr, incr_twiddle_addr_fsm, incr_twiddle_addr_reg; logic twiddle_mode, rst_twiddle_addr; -logic [6:0] twiddle_end_addr, twiddle_addr_reg, twiddle_offset; +logic [6:0] twiddle_end_addr, twiddle_addr_reg, twiddle_addr_reg_d2, twiddle_addr_int, twiddle_offset; //FSM round signals logic [$clog2(NTT_NUM_ROUNDS):0] num_rounds; @@ -264,9 +268,9 @@ end always_comb begin mem_rd_base_addr = (rounds_count == 'h0) ? src_base_addr : rounds_count[0] ? interim_base_addr : dest_base_addr; mem_wr_base_addr = rounds_count[0] ? dest_base_addr : interim_base_addr; - mem_rd_addr_nxt = mem_rd_addr + rd_addr_step; //TODO gs, pwo modes + mem_rd_addr_nxt = gs_mode ? (4*chunk_count) + (rd_addr_step*mem_rd_index_ofst) + mem_rd_base_addr : mem_rd_addr + rd_addr_step; //TODO pwo modes // mem_wr_addr_nxt = mem_wr_addr + wr_addr_step; - mem_wr_addr_nxt = ct_mode ? (4*(chunk_count_reg[0])) + (wr_addr_step*buf_rdptr_reg[0]) + mem_wr_base_addr : mem_wr_addr + wr_addr_step; //TODO: gs, pwo modes + mem_wr_addr_nxt = ct_mode ? (4*(chunk_count_reg[0])) + (wr_addr_step*buf_rdptr_reg[0]) + mem_wr_base_addr : mem_wr_addr + wr_addr_step; //TODO: pwo modes rd_addr_wraparound = mem_rd_addr_nxt > {1'b0,mem_rd_base_addr} + MEM_LAST_ADDR; wr_addr_wraparound = mem_wr_addr_nxt > {1'b0,mem_wr_base_addr} + MEM_LAST_ADDR; end @@ -280,7 +284,7 @@ always_ff @(posedge clk or negedge reset_n) begin mem_rd_addr <= 'h0; end else if (rst_rd_addr) begin - mem_rd_addr <= ct_mode ? mem_rd_base_addr + chunk_rand_offset : mem_rd_base_addr; //TODO: gs, pwo + mem_rd_addr <= ct_mode ? mem_rd_base_addr + chunk_rand_offset : gs_mode ? mem_rd_base_addr + (4*chunk_rand_offset) : mem_rd_base_addr; //TODO: pwo end else if (incr_mem_rd_addr) begin mem_rd_addr <= last_rd_addr ? mem_rd_base_addr : rd_addr_wraparound ? MEM_ADDR_WIDTH'(mem_rd_addr_nxt - MEM_LAST_ADDR) : mem_rd_addr_nxt[MEM_ADDR_WIDTH-1:0]; @@ -296,10 +300,10 @@ always_ff @(posedge clk or negedge reset_n) begin mem_wr_addr <= 'h0; end else if (rst_wr_addr) begin - mem_wr_addr <= ct_mode ? mem_wr_base_addr + (4*chunk_rand_offset) : mem_wr_base_addr; //TODO: gs, pwo + mem_wr_addr <= ct_mode ? mem_wr_base_addr + (4*chunk_rand_offset) : gs_mode ? mem_wr_base_addr + chunk_rand_offset : mem_wr_base_addr; //TODO: pwo end else if (incr_mem_wr_addr) begin - mem_wr_addr <= wr_addr_wraparound ? MEM_ADDR_WIDTH'(mem_wr_addr_nxt - MEM_LAST_ADDR) : mem_wr_addr_nxt[MEM_ADDR_WIDTH-1:0]; + mem_wr_addr <= (gs_mode & last_wr_addr) ? mem_wr_base_addr : wr_addr_wraparound ? MEM_ADDR_WIDTH'(mem_wr_addr_nxt - MEM_LAST_ADDR) : mem_wr_addr_nxt[MEM_ADDR_WIDTH-1:0]; end end @@ -346,22 +350,22 @@ always_comb begin 'h0: begin twiddle_end_addr = ct_mode ? 'd0 : 'd63; twiddle_offset = 'h0; - twiddle_rand_offset = 'h0; //gs mode: (chunk_rand_offset)*4 + index_rand_offset + twiddle_rand_offset = ct_mode ? 'h0 : (chunk_count)*4 + buf_wrptr_reg[10]; end 'h1: begin twiddle_end_addr = ct_mode ? 'd3 : 'd15; twiddle_offset = ct_mode ? 'd1 : 'd64; - twiddle_rand_offset = ct_mode ? /*(chunk_count_reg[BF_LATENCY-1] % 'd4)*/ /*index_rand_offset*/buf_rdptr_int : 'h0; //gs mode: (chunk_rand_offset % 4)*4 + index_rand_offset + twiddle_rand_offset = ct_mode ? buf_rdptr_int : (chunk_count % 4)*4 + buf_wrptr_reg[10]; end 'h2: begin twiddle_end_addr = ct_mode ? 'd15 : 'd3; twiddle_offset = ct_mode ? 'd5 : 'd80; - twiddle_rand_offset = ct_mode ? (chunk_count % 'd4)*'d4 + /*index_rand_offset*/buf_rdptr_int : 'h0; //gs mode: index_rand_offset + twiddle_rand_offset = ct_mode ? (chunk_count % 'd4)*'d4 + buf_rdptr_int : buf_wrptr_reg[10]; end 'h3: begin twiddle_end_addr = ct_mode ? 'd63 : 'd0; twiddle_offset = ct_mode ? 'd21 : 'd84; - twiddle_rand_offset = ct_mode ? (chunk_count % 'd16)*4 + /*index_rand_offset*/buf_rdptr_int : 'h0; //gs mode: 0 + twiddle_rand_offset = ct_mode ? (chunk_count % 'd16)*4 + buf_rdptr_int : 'h0; //gs mode: 0 end default: begin twiddle_end_addr = 'h0; @@ -388,12 +392,12 @@ always_ff @(posedge clk or negedge reset_n) begin else if (zeroize) twiddle_addr_reg <= 'h0; else if (incr_twiddle_addr) - twiddle_addr_reg <= /*buf0_valid ? twiddle_rand_offset :*/ gs_mode & (twiddle_addr_reg == twiddle_end_addr) ? 'h0 : ct_mode ? twiddle_rand_offset : twiddle_addr_reg + 'd1; + twiddle_addr_reg <= /*gs_mode & (twiddle_addr_reg == twiddle_end_addr) ? 'h0 :*/ ct_mode ? twiddle_rand_offset : twiddle_rand_offset; //twiddle_addr_reg + 'd1; else if (rst_twiddle_addr) twiddle_addr_reg <= /*ct_mode ? twiddle_rand_offset :*/ 'h0; //TODO: gs mode end -assign twiddle_addr = twiddle_addr_reg + twiddle_offset; +assign twiddle_addr_int = twiddle_addr_reg + twiddle_offset; //------------------------------------------ //Busy logic @@ -465,9 +469,9 @@ always_ff @(posedge clk or negedge reset_n) begin end else if (latch_chunk_rand_offset) begin chunk_rand_offset <= random[5:2]; - chunk_count <= {1'b0, random[5:2]}; + chunk_count <= random[5:2]; end - else if (buf_count == 'h3) begin //update chunk after every 4 cycles + else if ((ct_mode & (buf_count == 'h3)) | (gs_mode & (buf_wrptr_reg[10] == 'h3))) begin //update chunk after every 4 cycles chunk_count <= (chunk_count == 'hf) ? 'h0 : chunk_count + 'h1; end end @@ -475,6 +479,7 @@ end always_ff @(posedge clk or negedge reset_n) begin if (!reset_n) begin index_rand_offset <= 'h0; + end else if (zeroize) begin index_rand_offset <= 'h0; @@ -482,34 +487,55 @@ always_ff @(posedge clk or negedge reset_n) begin end else if (latch_index_rand_offset) begin index_rand_offset <= random[1:0]; - end + end always_ff @(posedge clk or negedge reset_n) begin if (!reset_n) begin buf_rdptr_reg <= 'h0; + buf_wrptr_reg <= 'h0; end else if (zeroize) begin buf_rdptr_reg <= 'h0; + buf_wrptr_reg <= 'h0; end - else if (buf_rden_ntt | butterfly_ready) begin //TODO gs + else if (ct_mode & (buf_rden_ntt | butterfly_ready)) begin buf_rdptr_reg <= {buf_rdptr_int, buf_rdptr_reg[BF_LATENCY-1:1]}; end + else if (gs_mode & (incr_mem_rd_addr | butterfly_ready)) begin + buf_wrptr_reg <= {mem_rd_index_ofst, buf_wrptr_reg[INTT_WRBUF_LATENCY-1:1]}; + end + else begin + buf_rdptr_reg <= 'h0; + buf_wrptr_reg <= 'h0; + end end -logic [1:0] buf_rdptr_f; + always_ff @(posedge clk or negedge reset_n) begin if (!reset_n) begin buf_rdptr_f <= 'h0; end else if (zeroize) begin buf_rdptr_f <= 'h0; - end + end else begin buf_rdptr_f <= buf_rdptr_int; end end +always_ff @(posedge clk or negedge reset_n) begin + if (!reset_n) begin + index_count <= 'h0; + end + else if (zeroize) begin + index_count <= 'h0; + end + else if (gs_mode & incr_mem_rd_addr) begin + index_count <= index_count + 'h1; + end +end + always_ff @(posedge clk or negedge reset_n) begin if (!reset_n) begin chunk_count_reg <= 'h0; @@ -517,7 +543,7 @@ always_ff @(posedge clk or negedge reset_n) begin else if (zeroize) begin chunk_count_reg <= 'h0; end - else if (buf_rden_ntt | butterfly_ready) begin + else if (buf_rden_ntt | butterfly_ready | (gs_mode & incr_mem_rd_addr)) begin //TODO: replace gs condition with an fsm generated flag perhaps? chunk_count_reg <= {chunk_count, chunk_count_reg[BF_LATENCY-1:1]}; end end @@ -532,19 +558,20 @@ always_ff @(posedge clk or negedge reset_n) begin else if (buf_wren & ct_mode) begin //ct mode - buf writes are in order buf_wrptr <= (buf_wrptr == 'h3) ? 'h0 : buf_wrptr + 'h1; end - else if (buf_wren & gs_mode) begin // gs mode - TODO: shuffling - buf_wrptr <= (buf_wrptr == 'h3) ? 'h0 : buf_wrptr + 'h1; + else if (buf_wren & gs_mode) begin // gs mode + buf_wrptr <= buf_wrptr_reg[1]; //equivalent to [0] due to this flop //(buf_wrptr == 'h3) ? 'h0 : buf_wrptr + 'h1; end - //TODO: gs_mode end always_comb begin - last_rd_addr = ct_mode & (mem_rd_addr == mem_rd_base_addr + MEM_LAST_ADDR); //TODO: other modes + last_rd_addr = /*ct_mode &*/ (mem_rd_addr == mem_rd_base_addr + MEM_LAST_ADDR); //TODO: other modes + last_wr_addr = /*ct_mode &*/ (mem_wr_addr == mem_wr_base_addr + MEM_LAST_ADDR); //TODO: other modes buf_rdptr_int = ct_mode ? index_rand_offset + buf_count : buf_count; //TODO: flop buf_rdptr = ct_mode ? buf_rdptr_f : buf_count; // buf_wrptr = gs_mode ? index_rand_offset + buf_count : buf_count; latch_chunk_rand_offset = arc_IDLE_WR_STAGE | arc_WR_MEM_WR_STAGE | arc_WR_WAIT_WR_STAGE; - latch_index_rand_offset = ct_mode & (buf_wrptr == 'h3); //buf0_valid; //TODO: INTT + latch_index_rand_offset = ct_mode ? (buf_wrptr == 'h3) : gs_mode & (arc_RD_STAGE_RD_EXEC | (index_count == 'h3)); //TODO pwo mode + mem_rd_index_ofst = gs_mode ? (index_count + index_rand_offset) : 'h0; //TODO: pwo mode, not used in ct mode end @@ -609,7 +636,7 @@ always_comb begin buf_rden_ntt = 1'b0; incr_mem_rd_addr = 1'b0; bf_enable_fsm = 1'b0; - mem_rd_en = 1'b0; + mem_rd_en_fsm = 1'b0; incr_twiddle_addr_fsm = 1'b0; rd_addr_step = 'h0; rst_rd_addr = 1'b0; @@ -644,7 +671,7 @@ always_comb begin buf_wren_ntt = 1'b1; buf_rden_ntt = buf0_valid; incr_mem_rd_addr = 1'b1; - mem_rd_en = 1'b1; + mem_rd_en_fsm = 1'b1; bf_enable_fsm = buf0_valid; //Enable bf2x2 as soon as buf is valid incr_twiddle_addr_fsm = buf0_valid; rd_addr_step = NTT_READ_ADDR_STEP; @@ -656,7 +683,7 @@ always_comb begin buf_wren_ntt = ct_mode; buf_rden_ntt = ct_mode; incr_mem_rd_addr = (ntt_mode inside {ct, gs}); - mem_rd_en = (ntt_mode inside {ct, gs}) ? (mem_rd_addr <= MEM_LAST_ADDR + mem_rd_base_addr) : 1'b0; + mem_rd_en_fsm = (ntt_mode inside {ct, gs}) ? (mem_rd_addr <= MEM_LAST_ADDR + mem_rd_base_addr) : 1'b0; bf_enable_fsm = pwo_mode ? sampler_valid : 1'b1; incr_twiddle_addr_fsm = ntt_mode inside {ct, gs}; //1'b1; rd_addr_step = ct_mode ? NTT_READ_ADDR_STEP : INTT_READ_ADDR_STEP; @@ -733,7 +760,7 @@ always_comb begin arc_WR_WAIT_WR_MEM = (write_fsm_state_ps == WR_WAIT) && (pwo_mode && butterfly_ready); //When valid_count is 64 and buf_count is 3 (meaning all 4 buffers have been used), move to WR_STAGE indicating that round is done - arc_WR_WAIT_WR_STAGE = (write_fsm_state_ps == WR_WAIT) && (((!pwo_mode && gs_mode) && (buf_count == 'h3)) || ct_mode); + arc_WR_WAIT_WR_STAGE = (write_fsm_state_ps == WR_WAIT) && ((gs_mode && (buf_count == 'h3)) || ct_mode); end always_comb begin @@ -808,29 +835,40 @@ assign rst_rounds = (read_fsm_state_ps == RD_IDLE) && (write_fsm_state_ps assign incr_rounds = arc_WR_MEM_WR_STAGE | arc_WR_WAIT_WR_STAGE; //TODO: revisit for high-perf mode (if we go with above opt) assign buf_wren = pwo_mode ? 1'b0 : buf_wren_ntt_reg | buf_wren_intt; assign buf_rden = pwo_mode ? 1'b0 : ct_mode ? buf_rden_ntt_reg : /*buf_rden_ntt |*/ buf_rden_intt; -assign bf_enable = (gs_mode || pwo_mode) ? bf_enable_reg : bf_enable_reg; //bf_enable_fsm; //In gs mode, memory is directly feeding bf2x2, so we need to enable it one cycle later +assign bf_enable = (gs_mode || pwo_mode) ? bf_enable_reg_d2 : bf_enable_reg; //bf_enable_fsm; //In gs mode, memory is directly feeding bf2x2, so we need to enable it one cycle later assign buf_wr_rst_count = pwo_mode ? 1'b1 : buf_wr_rst_count_ntt | buf_wr_rst_count_intt; assign buf_rd_rst_count = pwo_mode ? 1'b1 : buf_rd_rst_count_ntt | buf_rd_rst_count_intt; assign mem_wr_en = gs_mode ? mem_wr_en_fsm : mem_wr_en_reg; //TODO pwo mode, GS mode + shuffling +assign mem_rd_en = gs_mode ? mem_rd_en_reg : mem_rd_en_fsm; //TODO pwo mode +assign twiddle_addr = gs_mode ? twiddle_addr_reg_d2 : twiddle_addr_int; always_ff @(posedge clk or negedge reset_n) begin if (!reset_n) begin buf_wren_ntt_reg <= 'b0; buf_rden_ntt_reg <= 'b0; bf_enable_reg <= 'b0; + bf_enable_reg_d2 <= 'b0; mem_wr_en_reg <= 'b0; + mem_rd_en_reg <= 'b0; + twiddle_addr_reg_d2 <= 'h0; end else if (zeroize) begin buf_wren_ntt_reg <= 'b0; buf_rden_ntt_reg <= 'b0; bf_enable_reg <= 'b0; + bf_enable_reg_d2 <= 'b0; mem_wr_en_reg <= 'b0; + mem_rd_en_reg <= 'b0; + twiddle_addr_reg_d2 <= 'h0; end else begin buf_wren_ntt_reg <= buf_wren_ntt; buf_rden_ntt_reg <= buf_rden_ntt; bf_enable_reg <= bf_enable_fsm; + bf_enable_reg_d2 <= bf_enable_reg; mem_wr_en_reg <= mem_wr_en_fsm; + mem_rd_en_reg <= mem_rd_en_fsm; + twiddle_addr_reg_d2 <= twiddle_addr_int; end end diff --git a/src/ntt_top/rtl/ntt_top.sv b/src/ntt_top/rtl/ntt_top.sv index 4b7a866..68b6490 100644 --- a/src/ntt_top/rtl/ntt_top.sv +++ b/src/ntt_top/rtl/ntt_top.sv @@ -245,10 +245,10 @@ module ntt_top uvw_i.w11_i = twiddle_factor[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; end gs: begin - uvw_i.w11_i = twiddle_factor_reg[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; - uvw_i.w10_i = twiddle_factor_reg[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; - uvw_i.w01_i = twiddle_factor_reg[(2*NTT_REG_SIZE)-1:NTT_REG_SIZE]; - uvw_i.w00_i = twiddle_factor_reg[NTT_REG_SIZE-1:0]; + uvw_i.w11_i = twiddle_factor/*_reg*/[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; + uvw_i.w10_i = twiddle_factor/*_reg*/[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; + uvw_i.w01_i = twiddle_factor/*_reg*/[(2*NTT_REG_SIZE)-1:NTT_REG_SIZE]; + uvw_i.w00_i = twiddle_factor/*_reg*/[NTT_REG_SIZE-1:0]; end default: begin uvw_i.w11_i = 'h0; From 2d106206ea4d5218bcb5bf5c9078439211279d48 Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Thu, 3 Oct 2024 11:41:59 -0700 Subject: [PATCH 04/17] Add shuffle_en wip --- src/ntt_top/rtl/ntt_ctrl.sv | 49 +++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/src/ntt_top/rtl/ntt_ctrl.sv b/src/ntt_top/rtl/ntt_ctrl.sv index cbe1a04..2273913 100644 --- a/src/ntt_top/rtl/ntt_ctrl.sv +++ b/src/ntt_top/rtl/ntt_ctrl.sv @@ -144,6 +144,7 @@ logic rst_pw_addr; logic incr_twiddle_addr, incr_twiddle_addr_fsm, incr_twiddle_addr_reg; logic twiddle_mode, rst_twiddle_addr; logic [6:0] twiddle_end_addr, twiddle_addr_reg, twiddle_addr_reg_d2, twiddle_addr_int, twiddle_offset; +logic [6:0] twiddle_rand_offset; //FSM round signals logic [$clog2(NTT_NUM_ROUNDS):0] num_rounds; @@ -268,11 +269,17 @@ end always_comb begin mem_rd_base_addr = (rounds_count == 'h0) ? src_base_addr : rounds_count[0] ? interim_base_addr : dest_base_addr; mem_wr_base_addr = rounds_count[0] ? dest_base_addr : interim_base_addr; - mem_rd_addr_nxt = gs_mode ? (4*chunk_count) + (rd_addr_step*mem_rd_index_ofst) + mem_rd_base_addr : mem_rd_addr + rd_addr_step; //TODO pwo modes - // mem_wr_addr_nxt = mem_wr_addr + wr_addr_step; - mem_wr_addr_nxt = ct_mode ? (4*(chunk_count_reg[0])) + (wr_addr_step*buf_rdptr_reg[0]) + mem_wr_base_addr : mem_wr_addr + wr_addr_step; //TODO: pwo modes rd_addr_wraparound = mem_rd_addr_nxt > {1'b0,mem_rd_base_addr} + MEM_LAST_ADDR; wr_addr_wraparound = mem_wr_addr_nxt > {1'b0,mem_wr_base_addr} + MEM_LAST_ADDR; + + if (shuffle_en) begin + mem_rd_addr_nxt = gs_mode ? (4*chunk_count) + (rd_addr_step*mem_rd_index_ofst) + mem_rd_base_addr : mem_rd_addr + rd_addr_step; + mem_wr_addr_nxt = ct_mode ? (4*(chunk_count_reg[0])) + (wr_addr_step*buf_rdptr_reg[0]) + mem_wr_base_addr : mem_wr_addr + wr_addr_step; + end + else begin + mem_rd_addr_nxt = mem_rd_addr + rd_addr_step; + mem_wr_addr_nxt = mem_wr_addr + wr_addr_step; + end end //Read addr @@ -284,10 +291,16 @@ always_ff @(posedge clk or negedge reset_n) begin mem_rd_addr <= 'h0; end else if (rst_rd_addr) begin - mem_rd_addr <= ct_mode ? mem_rd_base_addr + chunk_rand_offset : gs_mode ? mem_rd_base_addr + (4*chunk_rand_offset) : mem_rd_base_addr; //TODO: pwo + if (shuffle_en) + mem_rd_addr <= ct_mode ? mem_rd_base_addr + chunk_rand_offset : gs_mode ? mem_rd_base_addr + (4*chunk_rand_offset) : mem_rd_base_addr; + else + mem_rd_addr <= mem_rd_base_addr; end else if (incr_mem_rd_addr) begin - mem_rd_addr <= last_rd_addr ? mem_rd_base_addr : rd_addr_wraparound ? MEM_ADDR_WIDTH'(mem_rd_addr_nxt - MEM_LAST_ADDR) : mem_rd_addr_nxt[MEM_ADDR_WIDTH-1:0]; + if (shuffle_en) + mem_rd_addr <= last_rd_addr ? mem_rd_base_addr : rd_addr_wraparound ? MEM_ADDR_WIDTH'(mem_rd_addr_nxt - MEM_LAST_ADDR) : mem_rd_addr_nxt[MEM_ADDR_WIDTH-1:0]; + else + mem_rd_addr <= rd_addr_wraparound ? MEM_ADDR_WIDTH'(mem_rd_addr_nxt - MEM_LAST_ADDR) : mem_rd_addr_nxt[MEM_ADDR_WIDTH-1:0]; end end @@ -300,10 +313,16 @@ always_ff @(posedge clk or negedge reset_n) begin mem_wr_addr <= 'h0; end else if (rst_wr_addr) begin - mem_wr_addr <= ct_mode ? mem_wr_base_addr + (4*chunk_rand_offset) : gs_mode ? mem_wr_base_addr + chunk_rand_offset : mem_wr_base_addr; //TODO: pwo + if (shuffle_en) + mem_wr_addr <= ct_mode ? mem_wr_base_addr + (4*chunk_rand_offset) : gs_mode ? mem_wr_base_addr + chunk_rand_offset : mem_wr_base_addr; //TODO: pwo + else + mem_wr_addr <= mem_wr_base_addr; end else if (incr_mem_wr_addr) begin - mem_wr_addr <= (gs_mode & last_wr_addr) ? mem_wr_base_addr : wr_addr_wraparound ? MEM_ADDR_WIDTH'(mem_wr_addr_nxt - MEM_LAST_ADDR) : mem_wr_addr_nxt[MEM_ADDR_WIDTH-1:0]; + if (shuffle_en) + mem_wr_addr <= (gs_mode & last_wr_addr) ? mem_wr_base_addr : wr_addr_wraparound ? MEM_ADDR_WIDTH'(mem_wr_addr_nxt - MEM_LAST_ADDR) : mem_wr_addr_nxt[MEM_ADDR_WIDTH-1:0]; + else + mem_wr_addr <= wr_addr_wraparound ? MEM_ADDR_WIDTH'(mem_wr_addr_nxt - MEM_LAST_ADDR) : mem_wr_addr_nxt[MEM_ADDR_WIDTH-1:0]; end end @@ -344,7 +363,6 @@ end //------------------------------------------ //Twiddle addr logic //------------------------------------------ -logic [6:0] twiddle_rand_offset; always_comb begin unique case(rounds_count) 'h0: begin @@ -392,9 +410,9 @@ always_ff @(posedge clk or negedge reset_n) begin else if (zeroize) twiddle_addr_reg <= 'h0; else if (incr_twiddle_addr) - twiddle_addr_reg <= /*gs_mode & (twiddle_addr_reg == twiddle_end_addr) ? 'h0 :*/ ct_mode ? twiddle_rand_offset : twiddle_rand_offset; //twiddle_addr_reg + 'd1; + twiddle_addr_reg <= shuffle_en ? twiddle_rand_offset : (twiddle_addr_reg == twiddle_end_addr) ? 'h0 : twiddle_addr_reg + 'd1; else if (rst_twiddle_addr) - twiddle_addr_reg <= /*ct_mode ? twiddle_rand_offset :*/ 'h0; //TODO: gs mode + twiddle_addr_reg <= 'h0; end assign twiddle_addr_int = twiddle_addr_reg + twiddle_offset; @@ -754,13 +772,14 @@ always_comb begin //Move to WR_WAIT state when the last outputs from bf2x2 have been captured in the buffers. They still need to be shifted out of the buffers and into memory, so keep buf_wren 1 here //Assumption - no bubbles in NTT or INTT. If bubbles, need to consider sampler_valid //TODO: can WR_WAIT state be removed? fsm can finish all 64 addr in WR_MEM state? - arc_WR_MEM_WR_WAIT = (write_fsm_state_ps == WR_MEM) && ((gs_mode && (buf0_valid && (wr_valid_count == 'h3c))) || (pwo_mode && !butterfly_ready && (wr_valid_count < 'h3f))); // || (ct_mode && (wr_valid_count == 'h3f))); + arc_WR_MEM_WR_WAIT = (write_fsm_state_ps == WR_MEM) && ((gs_mode && (buf0_valid && (wr_valid_count == 'h3c))) || (pwo_mode && !butterfly_ready && (wr_valid_count < 'h3f))); //This arc is only for pwo mode. Move back from wait to write state when there's a valid BFU output arc_WR_WAIT_WR_MEM = (write_fsm_state_ps == WR_WAIT) && (pwo_mode && butterfly_ready); //When valid_count is 64 and buf_count is 3 (meaning all 4 buffers have been used), move to WR_STAGE indicating that round is done - arc_WR_WAIT_WR_STAGE = (write_fsm_state_ps == WR_WAIT) && ((gs_mode && (buf_count == 'h3)) || ct_mode); + arc_WR_WAIT_WR_STAGE = shuffle_en ? (write_fsm_state_ps == WR_WAIT) && ((gs_mode && (buf_count == 'h3)) || ct_mode) + : (write_fsm_state_ps == WR_WAIT) && (!pwo_mode && (buf_count == 'h3)); end always_comb begin @@ -817,10 +836,10 @@ always_comb begin end WR_WAIT: begin write_fsm_state_ns = arc_WR_WAIT_WR_STAGE ? WR_STAGE : arc_WR_WAIT_WR_MEM ? WR_MEM : WR_WAIT; - buf_wren_intt = gs_mode & (buf_count <= 'h3); //1'b0; - buf_rden_intt = gs_mode; + buf_wren_intt = shuffle_en ? gs_mode & (buf_count <= 'h3) : (buf_count <= 'h3); //1'b0; + buf_rden_intt = shuffle_en ? gs_mode : 'b1; incr_mem_wr_addr = (ct_mode | gs_mode); //1'b1; - mem_wr_en_fsm = gs_mode; //1'b1; + mem_wr_en_fsm = shuffle_en ? gs_mode : (ct_mode | gs_mode); //1'b1; wr_addr_step = gs_mode ? INTT_WRITE_ADDR_STEP : NTT_WRITE_ADDR_STEP; incr_pw_wr_addr = arc_WR_WAIT_WR_MEM; pw_wren = arc_WR_WAIT_WR_MEM; From ca825ba5f61f91ec255e8ab4eb88345e699d04c8 Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Thu, 3 Oct 2024 12:05:43 -0700 Subject: [PATCH 05/17] shuffling updates for pwo, intt --- src/mldsa_top/rtl/mldsa_top.sv | 27 ++- .../mldsa/tb/testbench/hdl_top.sv | 6 +- src/ntt_top/Model/maksed_gadgets.py | 25 ++- src/ntt_top/Model/testForMasking.py | 31 +++ src/ntt_top/rtl/ntt_ctrl.sv | 185 ++++++++++++------ src/ntt_top/rtl/ntt_top.sv | 39 ++-- src/ntt_top/tb/ntt_top_tb.sv | 49 ++--- 7 files changed, 254 insertions(+), 108 deletions(-) diff --git a/src/mldsa_top/rtl/mldsa_top.sv b/src/mldsa_top/rtl/mldsa_top.sv index 56c5eea..3fcdd71 100644 --- a/src/mldsa_top/rtl/mldsa_top.sv +++ b/src/mldsa_top/rtl/mldsa_top.sv @@ -42,6 +42,7 @@ module mldsa_top input logic hready_i, input logic [1:0] htrans_i, input logic [2:0] hsize_i, + input logic [5:0] random, //ahb output output logic hresp_o, @@ -78,10 +79,10 @@ module mldsa_top logic [MLDSA_MEM_DATA_WIDTH-1:0] sampler_mem_data; logic [MLDSA_MEM_ADDR_WIDTH-1:0] sampler_mem_addr; - logic [1:0] sampler_ntt_dv; + logic [1:0] sampler_ntt_dv, sampler_ntt_dv_f; logic [1:0] sampler_ntt_mode; logic [1:0] sampler_valid; - logic [COEFF_PER_CLK-1:0][MLDSA_Q_WIDTH-1:0] sampler_ntt_data; + logic [COEFF_PER_CLK-1:0][MLDSA_Q_WIDTH-1:0] sampler_ntt_data, sampler_ntt_data_reg; mldsa_ntt_mode_e [1:0] ntt_mode; mode_t [1:0] mode; @@ -384,6 +385,21 @@ mldsa_sampler_top sampler_top_inst .sampler_state_data_o(sampler_state_data) ); +always_ff @(posedge clk or negedge rst_b) begin + if (!rst_b) begin + sampler_ntt_data_reg <= 0; + sampler_ntt_dv_f <= 0; + end + else if (zeroize_reg) begin + sampler_ntt_data_reg <= 0; + sampler_ntt_dv_f <= 0; + end + else begin + sampler_ntt_data_reg <= sampler_ntt_data; + sampler_ntt_dv_f <= sampler_ntt_dv; + end +end + assign sampler_ntt_dv[1] = 0; //no sampler interface to secondary ntt generate @@ -455,6 +471,7 @@ generate .pwo_mem_base_addr(pwo_mem_base_addr[g_inst]), .accumulate(accumulate[g_inst]), .sampler_valid(sampler_valid[g_inst]), + .random(random), //NTT mem IF .mem_wr_req(ntt_mem_wr_req[g_inst]), .mem_rd_req(ntt_mem_rd_req[g_inst]), @@ -464,7 +481,7 @@ generate .pwm_a_rd_req(pwm_a_rd_req[g_inst]), .pwm_b_rd_req(pwm_b_rd_req[g_inst]), .pwm_a_rd_data(pwm_a_rd_data[g_inst]), - .pwm_b_rd_data(sampler_ntt_mode[g_inst] ? sampler_ntt_data : pwm_b_rd_data[g_inst]), + .pwm_b_rd_data(sampler_ntt_mode[g_inst] ? sampler_ntt_data_reg : pwm_b_rd_data[g_inst]), .ntt_busy(ntt_busy[g_inst]), .ntt_done(ntt_done[g_inst]) ); @@ -859,7 +876,7 @@ always_comb begin for (int bank = 0; bank < 2; bank++) begin ntt_mem_re0_bank[0][bank] = (ntt_mem_rd_req[0].rd_wr_en == RW_READ) & (ntt_mem_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (ntt_mem_rd_req[0].addr[0] == bank); pwo_a_mem_re0_bank[0][bank] = (pwm_a_rd_req[0].rd_wr_en == RW_READ) & (pwm_a_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_a_rd_req[0].addr[0] == bank); - pwo_b_mem_re0_bank[0][bank] = ~sampler_ntt_dv & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_b_rd_req[0].addr[0] == bank); + pwo_b_mem_re0_bank[0][bank] = ~sampler_ntt_dv_f & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_b_rd_req[0].addr[0] == bank); ntt_mem_re0_bank[1][bank] = (ntt_mem_rd_req[1].rd_wr_en == RW_READ) & (ntt_mem_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (ntt_mem_rd_req[1].addr[0] == bank); pwo_a_mem_re0_bank[1][bank] = (pwm_a_rd_req[1].rd_wr_en == RW_READ) & (pwm_a_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_a_rd_req[1].addr[0] == bank); @@ -893,7 +910,7 @@ always_comb begin end else begin ntt_mem_re[0][i] = (ntt_mem_rd_req[0].rd_wr_en == RW_READ) & (ntt_mem_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); pwo_a_mem_re[0][i] = (pwm_a_rd_req[0].rd_wr_en == RW_READ) & (pwm_a_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); - pwo_b_mem_re[0][i] = ~sampler_ntt_dv & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); + pwo_b_mem_re[0][i] = ~sampler_ntt_dv_f & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); ntt_mem_re[1][i] = (ntt_mem_rd_req[1].rd_wr_en == RW_READ) & (ntt_mem_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); pwo_a_mem_re[1][i] = (pwm_a_rd_req[1].rd_wr_en == RW_READ) & (pwm_a_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); diff --git a/src/mldsa_top/uvmf/uvmf_template_output/project_benches/mldsa/tb/testbench/hdl_top.sv b/src/mldsa_top/uvmf/uvmf_template_output/project_benches/mldsa/tb/testbench/hdl_top.sv index b04b3b2..e7eb9d2 100644 --- a/src/mldsa_top/uvmf/uvmf_template_output/project_benches/mldsa/tb/testbench/hdl_top.sv +++ b/src/mldsa_top/uvmf/uvmf_template_output/project_benches/mldsa/tb/testbench/hdl_top.sv @@ -54,13 +54,16 @@ import uvmf_base_pkg_hdl::*; // pragma uvmf custom clock_generator begin bit clk; + logic [5:0] random_tb; // Instantiate a clk driver // tbx clkgen initial begin clk = 0; + random_tb = 0; #0ns; forever begin clk = ~clk; + random_tb = $urandom(); #5ns; end end @@ -101,7 +104,8 @@ import uvmf_base_pkg_hdl::*; .hsize_i (uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HSIZE ), .hresp_o (uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HRESP ), .hreadyout_o(uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HREADY ), - .hrdata_o (uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HRDATA ) + .hrdata_o (uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HRDATA ), + .random (random_tb) ); assign uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HBURST = 3'b0; diff --git a/src/ntt_top/Model/maksed_gadgets.py b/src/ntt_top/Model/maksed_gadgets.py index eaea8d8..9403a52 100644 --- a/src/ntt_top/Model/maksed_gadgets.py +++ b/src/ntt_top/Model/maksed_gadgets.py @@ -20,7 +20,7 @@ def one_share_mult(a0, a1, b): randomness = CustomUnsignedInteger(0, 0, MultMod-1) #get a random number ranging [0, MultMod-1] randomness.generate_random() - r1 = int(randomness.value) + r1 = int(randomness.value) #optional #refresh the shares a00 = int(a0+r1) % MultMod a10 = int(a1-r1) % MultMod @@ -28,6 +28,27 @@ def one_share_mult(a0, a1, b): c1 = int(a10*b) % MultMod return c0, c1 +def two_share_mult(a0, a1, b0, b1): + # Construct randomness class + randomness = CustomUnsignedInteger(0, 0, MultMod-1) + #get a random number ranging [0, MultMod-1] + randomness.generate_random() + r1 = int(randomness.value) + # #refresh the shares + # a00 = int(a0+r1) % MultMod + # a10 = int(a1-r1) % MultMod + # c0 = int(a00*b) % MultMod + # c1 = int(a10*b) % MultMod + c0 = int(a0*b0) % MultMod + d0 = int(a1*b0) % MultMod + c1 = int(a0*b1) % MultMod + d1 = int(a1*b1) % MultMod + e0 = int(c1+r1) % MultMod + e1 = int(d0-r1) % MultMod + final_res0 = int(c0+e0) % MultMod + final_res1 = int(d1+e1) % MultMod + return final_res0, final_res1 + def maskedAdder(a0, a1, b0, b1): # Construct randomness class randomness = CustomUnsignedInteger(0, 0, MultMod-1) @@ -185,7 +206,7 @@ def unMaskedModAddition(x0, x1, y0, y1): y = y0 ^ y1 z = (x+y) % DILITHIUM_Q randomness = CustomUnsignedInteger(0, 0, (2**23)-1) - # Refresh the shares + # Refresh the shares -- not optional randomness.generate_random() r1 = int(randomness.value) z0 = z ^ r1 diff --git a/src/ntt_top/Model/testForMasking.py b/src/ntt_top/Model/testForMasking.py index 601330e..9c88c6b 100644 --- a/src/ntt_top/Model/testForMasking.py +++ b/src/ntt_top/Model/testForMasking.py @@ -376,6 +376,37 @@ def test_maskedBFU_CT(numTest = 10): if vNew != exp_v: print(f"CT Lower branch gives an Error; gotten = {vNew}, while exp = {exp_v}") +def test_twoshare_mult(numTest = 10): + randomness = CustomUnsignedInteger(0, 0, MultMod-1) + operands = CustomUnsignedInteger(0, 0, DILITHIUM_Q-1) + for i in range(0, numTest): + #get a random number ranging [0, DILITHIUM_Q-1] + #generate inputs + operands.generate_random() + u = int(operands.value) + operands.generate_random() + v = int(operands.value) + #calculate expected result + exp_uv = u*v + #Split inputs to shares + randomness.generate_random() + r0 = int(randomness.value) + u0 = int(u-r0) % MultMod + u1 = r0 + randomness.generate_random() + r1 = int(randomness.value) + v0 = int(v-r1) % MultMod + v1 = r1 + #Test two share mult + uv0, uv1 = two_share_mult(u0, u1, v0, v1) + uv = int(uv0 + uv1) % MultMod + #Check result + if uv != exp_uv: + print(f"Incorrect mult op. Operands: {u, v}, Shares: {u0, u1, v0, v1} Exp = {exp_uv}, actual = {uv}") + + + test_maskedBFU_CT(numTest = 100000) test_maskedBFU_GS(numTest = 100000) +test_twoshare_mult(numTest = 100000) diff --git a/src/ntt_top/rtl/ntt_ctrl.sv b/src/ntt_top/rtl/ntt_ctrl.sv index 2273913..a370cb0 100644 --- a/src/ntt_top/rtl/ntt_ctrl.sv +++ b/src/ntt_top/rtl/ntt_ctrl.sv @@ -32,7 +32,7 @@ module ntt_ctrl parameter MLDSA_N = 256, parameter MLDSA_LOGN = 8, parameter MEM_ADDR_WIDTH = 15, - parameter BF_LATENCY = 11, //TODO: change back to 10 and test shuffling //5 cycles per butterfly * 2 instances in serial = 10 clks + parameter BF_LATENCY = 10, //11 //TODO: change back to 10 and test shuffling //5 cycles per butterfly * 2 instances in serial = 10 clks parameter NTT_BUF_LATENCY = 4 ) ( @@ -91,6 +91,7 @@ localparam INTT_READ_ADDR_STEP = 1; localparam INTT_WRITE_ADDR_STEP = 16; localparam PWO_READ_ADDR_STEP = 1; localparam PWO_WRITE_ADDR_STEP = 1; +localparam PWM_LATENCY = 5; localparam [MEM_ADDR_WIDTH-1:0] MEM_LAST_ADDR = 63; localparam INTT_WRBUF_LATENCY = 13; //includes BF latency + mem latency for shuffled reads to begin @@ -112,13 +113,15 @@ logic [3:0] chunk_count; logic [1:0] index_rand_offset, index_count, mem_rd_index_ofst; logic [1:0] buf_rdptr_int; logic [1:0] buf_rdptr_f; -logic [BF_LATENCY-1:0][1:0] buf_rdptr_reg; +logic [BF_LATENCY:0][1:0] buf_rdptr_reg; logic [INTT_WRBUF_LATENCY-1:0][1:0] buf_wrptr_reg; -logic [BF_LATENCY-1:0][3:0] chunk_count_reg; +logic [BF_LATENCY:0][3:0] chunk_count_reg; logic latch_chunk_rand_offset, latch_index_rand_offset; logic last_rd_addr, last_wr_addr; logic mem_wr_en_fsm, mem_wr_en_reg; logic mem_rd_en_fsm, mem_rd_en_reg; +logic pw_rden_fsm, pw_rden_reg; +logic pw_wren_fsm, pw_wren_reg, pw_wren_reg_d2; //Mode flags logic ct_mode, gs_mode, pwo_mode; //point-wise operations mode @@ -127,6 +130,7 @@ logic pwm_mode, pwa_mode, pws_mode; //Addr internal wires logic [MEM_ADDR_WIDTH-1:0] src_base_addr, interim_base_addr, dest_base_addr; logic [MEM_ADDR_WIDTH-1:0] pw_base_addr_a, pw_base_addr_b, pw_base_addr_c; +logic [MEM_ADDR_WIDTH-1:0] pw_mem_rd_addr_a_nxt, pw_mem_rd_addr_b_nxt, pw_mem_rd_addr_c_nxt, pw_mem_wr_addr_c_nxt; logic incr_mem_rd_addr; logic incr_mem_wr_addr; logic rst_rd_addr, rst_wr_addr; //TODO: need both? @@ -141,10 +145,10 @@ logic incr_pw_rd_addr, incr_pw_wr_addr; //TODO: need both? logic rst_pw_addr; //Twiddle ROM wires -logic incr_twiddle_addr, incr_twiddle_addr_fsm, incr_twiddle_addr_reg; +logic incr_twiddle_addr, incr_twiddle_addr_fsm, incr_twiddle_addr_reg, incr_twiddle_addr_reg_d2; logic twiddle_mode, rst_twiddle_addr; -logic [6:0] twiddle_end_addr, twiddle_addr_reg, twiddle_addr_reg_d2, twiddle_addr_int, twiddle_offset; logic [6:0] twiddle_rand_offset; +logic [6:0] twiddle_end_addr, twiddle_addr_reg, twiddle_addr_reg_d2, twiddle_addr_reg_d3, twiddle_addr_int, twiddle_offset; //FSM round signals logic [$clog2(NTT_NUM_ROUNDS):0] num_rounds; @@ -189,7 +193,7 @@ logic arc_EXEC_WAIT_RD_EXEC; //Other signals logic buf_wren_ntt, buf_wren_ntt_reg; -logic buf_wren_intt; +logic buf_wren_intt, buf_wren_intt_reg; logic buf_rden_ntt, buf_rden_ntt_reg; logic buf_rden_intt; @@ -273,8 +277,8 @@ always_comb begin wr_addr_wraparound = mem_wr_addr_nxt > {1'b0,mem_wr_base_addr} + MEM_LAST_ADDR; if (shuffle_en) begin - mem_rd_addr_nxt = gs_mode ? (4*chunk_count) + (rd_addr_step*mem_rd_index_ofst) + mem_rd_base_addr : mem_rd_addr + rd_addr_step; - mem_wr_addr_nxt = ct_mode ? (4*(chunk_count_reg[0])) + (wr_addr_step*buf_rdptr_reg[0]) + mem_wr_base_addr : mem_wr_addr + wr_addr_step; + mem_rd_addr_nxt = (gs_mode | pwo_mode) ? (4*chunk_count) + (rd_addr_step*mem_rd_index_ofst) + mem_rd_base_addr : mem_rd_addr + rd_addr_step; //TODO pwo modes + mem_wr_addr_nxt = ct_mode ? (4*(chunk_count_reg[0])) + (wr_addr_step*buf_rdptr_reg[0]) + mem_wr_base_addr : gs_mode ? mem_wr_addr + wr_addr_step : (4*(chunk_count_reg[4])) + (wr_addr_step*buf_rdptr_reg[4]); //TODO: pwo modes end else begin mem_rd_addr_nxt = mem_rd_addr + rd_addr_step; @@ -292,15 +296,15 @@ always_ff @(posedge clk or negedge reset_n) begin end else if (rst_rd_addr) begin if (shuffle_en) - mem_rd_addr <= ct_mode ? mem_rd_base_addr + chunk_rand_offset : gs_mode ? mem_rd_base_addr + (4*chunk_rand_offset) : mem_rd_base_addr; + mem_rd_addr <= ct_mode ? mem_rd_base_addr + chunk_rand_offset : (gs_mode | pwo_mode) ? mem_rd_base_addr + (4*chunk_rand_offset) : mem_rd_base_addr; //TODO: pwo else mem_rd_addr <= mem_rd_base_addr; end else if (incr_mem_rd_addr) begin if (shuffle_en) - mem_rd_addr <= last_rd_addr ? mem_rd_base_addr : rd_addr_wraparound ? MEM_ADDR_WIDTH'(mem_rd_addr_nxt - MEM_LAST_ADDR) : mem_rd_addr_nxt[MEM_ADDR_WIDTH-1:0]; + mem_rd_addr <= (ct_mode & last_rd_addr) ? mem_rd_base_addr : rd_addr_wraparound ? MEM_ADDR_WIDTH'(mem_rd_addr_nxt - MEM_LAST_ADDR) : mem_rd_addr_nxt[MEM_ADDR_WIDTH-1:0]; else - mem_rd_addr <= rd_addr_wraparound ? MEM_ADDR_WIDTH'(mem_rd_addr_nxt - MEM_LAST_ADDR) : mem_rd_addr_nxt[MEM_ADDR_WIDTH-1:0]; + mem_rd_addr <= rd_addr_wraparound ? MEM_ADDR_WIDTH'(mem_rd_addr_nxt - MEM_LAST_ADDR) : mem_rd_addr_nxt[MEM_ADDR_WIDTH-1:0]; end end @@ -314,9 +318,9 @@ always_ff @(posedge clk or negedge reset_n) begin end else if (rst_wr_addr) begin if (shuffle_en) - mem_wr_addr <= ct_mode ? mem_wr_base_addr + (4*chunk_rand_offset) : gs_mode ? mem_wr_base_addr + chunk_rand_offset : mem_wr_base_addr; //TODO: pwo + mem_wr_addr <= (ct_mode | pwo_mode) ? mem_wr_base_addr + (4*chunk_rand_offset) : gs_mode ? mem_wr_base_addr + chunk_rand_offset : mem_wr_base_addr; else - mem_wr_addr <= mem_wr_base_addr; + mem_wr_addr <= mem_wr_base_addr; end else if (incr_mem_wr_addr) begin if (shuffle_en) @@ -326,6 +330,15 @@ always_ff @(posedge clk or negedge reset_n) begin end end +always_comb begin + pw_mem_rd_addr_a_nxt = pw_base_addr_a + (4*chunk_count) + (PWO_READ_ADDR_STEP*mem_rd_index_ofst); + pw_mem_rd_addr_b_nxt = pw_base_addr_b + (4*chunk_count) + (PWO_READ_ADDR_STEP*mem_rd_index_ofst); + pw_mem_rd_addr_c_nxt = accumulate ? pw_base_addr_c + ((4*chunk_count)+(PWO_READ_ADDR_STEP*mem_rd_index_ofst)) : 'h0; //TODO check timing + pw_mem_wr_addr_c_nxt = accumulate ? pw_base_addr_c + (4*chunk_count_reg[PWM_LATENCY-2]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[PWM_LATENCY-2]) + : (pwa_mode | pws_mode) ? pw_base_addr_c + (4*chunk_count_reg[7]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[7]) + : pw_base_addr_c + (4*chunk_count_reg[PWM_LATENCY-1]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[PWM_LATENCY-1]); //2 +end + //PWO addr always_ff @(posedge clk or negedge reset_n) begin if (!reset_n) begin @@ -341,20 +354,20 @@ always_ff @(posedge clk or negedge reset_n) begin pw_mem_wr_addr_c <= '0; end else if (rst_pw_addr) begin - pw_mem_rd_addr_a <= pw_base_addr_a; - pw_mem_rd_addr_b <= pw_base_addr_b; - pw_mem_rd_addr_c <= pw_base_addr_c; - pw_mem_wr_addr_c <= pw_base_addr_c; + pw_mem_rd_addr_a <= (4*chunk_rand_offset) + pw_base_addr_a; + pw_mem_rd_addr_b <= (4*chunk_rand_offset) + pw_base_addr_b; + pw_mem_rd_addr_c <= (4*chunk_rand_offset) + pw_base_addr_c; + pw_mem_wr_addr_c <= (4*chunk_rand_offset) + pw_base_addr_c; end else begin if (incr_pw_rd_addr) begin - pw_mem_rd_addr_a <= pw_mem_rd_addr_a + PWO_READ_ADDR_STEP; - pw_mem_rd_addr_b <= pw_mem_rd_addr_b + PWO_READ_ADDR_STEP; - pw_mem_rd_addr_c <= accumulate ? pw_mem_rd_addr_c + PWO_READ_ADDR_STEP : 'h0; //addr in sync with a, b. However, the data is flopped 4 cycles inside BF to align with mul result + pw_mem_rd_addr_a <= pw_mem_rd_addr_a_nxt;//pw_mem_rd_addr_a + PWO_READ_ADDR_STEP; + pw_mem_rd_addr_b <= pw_mem_rd_addr_b_nxt;//pw_mem_rd_addr_b + PWO_READ_ADDR_STEP; + pw_mem_rd_addr_c <= accumulate ? pw_mem_rd_addr_c_nxt : 'h0;//accumulate ? pw_mem_rd_addr_c + PWO_READ_ADDR_STEP : 'h0; //addr in sync with a, b. However, the data is flopped 4 cycles inside BF to align with mul result end if (incr_pw_wr_addr) begin - pw_mem_wr_addr_c <= pw_mem_wr_addr_c + PWO_WRITE_ADDR_STEP; + pw_mem_wr_addr_c <= pw_mem_wr_addr_c_nxt; //pw_mem_wr_addr_c + PWO_WRITE_ADDR_STEP; end end end @@ -368,17 +381,17 @@ always_comb begin 'h0: begin twiddle_end_addr = ct_mode ? 'd0 : 'd63; twiddle_offset = 'h0; - twiddle_rand_offset = ct_mode ? 'h0 : (chunk_count)*4 + buf_wrptr_reg[10]; + twiddle_rand_offset = ct_mode ? 'h0 : (chunk_count_reg[BF_LATENCY])*4 + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]; end 'h1: begin twiddle_end_addr = ct_mode ? 'd3 : 'd15; twiddle_offset = ct_mode ? 'd1 : 'd64; - twiddle_rand_offset = ct_mode ? buf_rdptr_int : (chunk_count % 4)*4 + buf_wrptr_reg[10]; + twiddle_rand_offset = ct_mode ? buf_rdptr_int : (chunk_count_reg[BF_LATENCY] % 4)*4 + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]; end 'h2: begin twiddle_end_addr = ct_mode ? 'd15 : 'd3; twiddle_offset = ct_mode ? 'd5 : 'd80; - twiddle_rand_offset = ct_mode ? (chunk_count % 'd4)*'d4 + buf_rdptr_int : buf_wrptr_reg[10]; + twiddle_rand_offset = ct_mode ? (chunk_count % 'd4)*'d4 + buf_rdptr_int : buf_wrptr_reg[INTT_WRBUF_LATENCY-1]; end 'h3: begin twiddle_end_addr = ct_mode ? 'd63 : 'd0; @@ -395,13 +408,17 @@ end //Flop the incr and twiddle_addr to align with memory read latency always_ff @(posedge clk or negedge reset_n) begin - if (!reset_n) + if (!reset_n) begin incr_twiddle_addr_reg <= 'b0; - else + incr_twiddle_addr_reg_d2 <= 'b0; + end + else begin incr_twiddle_addr_reg <= incr_twiddle_addr_fsm; + incr_twiddle_addr_reg_d2 <= incr_twiddle_addr_reg; + end end -assign incr_twiddle_addr = ct_mode ? incr_twiddle_addr_fsm : incr_twiddle_addr_reg; +assign incr_twiddle_addr = ct_mode ? incr_twiddle_addr_fsm : incr_twiddle_addr_reg; //_d2; always_ff @(posedge clk or negedge reset_n) begin @@ -415,7 +432,7 @@ always_ff @(posedge clk or negedge reset_n) begin twiddle_addr_reg <= 'h0; end -assign twiddle_addr_int = twiddle_addr_reg + twiddle_offset; +assign twiddle_addr_int = ct_mode ? twiddle_addr_reg + twiddle_offset : twiddle_rand_offset + twiddle_offset; //------------------------------------------ //Busy logic @@ -489,7 +506,7 @@ always_ff @(posedge clk or negedge reset_n) begin chunk_rand_offset <= random[5:2]; chunk_count <= random[5:2]; end - else if ((ct_mode & (buf_count == 'h3)) | (gs_mode & (buf_wrptr_reg[10] == 'h3))) begin //update chunk after every 4 cycles + else if ((ct_mode & (buf_count == 'h3)) | ((gs_mode | (pwo_mode & incr_pw_rd_addr)) & /*(buf_wrptr_reg[10] == 'h3)*/(index_count == 'h3))) begin //update chunk after every 4 cycles - TODO: stop chunk counting when there's no incr_rd_addr in ntt/intt modes chunk_count <= (chunk_count == 'hf) ? 'h0 : chunk_count + 'h1; end end @@ -519,11 +536,14 @@ always_ff @(posedge clk or negedge reset_n) begin buf_wrptr_reg <= 'h0; end else if (ct_mode & (buf_rden_ntt | butterfly_ready)) begin - buf_rdptr_reg <= {buf_rdptr_int, buf_rdptr_reg[BF_LATENCY-1:1]}; + buf_rdptr_reg <= {buf_rdptr_int, buf_rdptr_reg[BF_LATENCY:1]}; end else if (gs_mode & (incr_mem_rd_addr | butterfly_ready)) begin buf_wrptr_reg <= {mem_rd_index_ofst, buf_wrptr_reg[INTT_WRBUF_LATENCY-1:1]}; end + else if (pwo_mode & (incr_pw_rd_addr | butterfly_ready)) begin + buf_rdptr_reg <= {mem_rd_index_ofst, buf_rdptr_reg[BF_LATENCY:1]}; //TODO: create new reg with apt name for PWO + end else begin buf_rdptr_reg <= 'h0; buf_wrptr_reg <= 'h0; @@ -549,7 +569,7 @@ always_ff @(posedge clk or negedge reset_n) begin else if (zeroize) begin index_count <= 'h0; end - else if (gs_mode & incr_mem_rd_addr) begin + else if ((gs_mode & (incr_mem_rd_addr)) | (pwo_mode & incr_pw_rd_addr)) begin index_count <= index_count + 'h1; end end @@ -561,8 +581,8 @@ always_ff @(posedge clk or negedge reset_n) begin else if (zeroize) begin chunk_count_reg <= 'h0; end - else if (buf_rden_ntt | butterfly_ready | (gs_mode & incr_mem_rd_addr)) begin //TODO: replace gs condition with an fsm generated flag perhaps? - chunk_count_reg <= {chunk_count, chunk_count_reg[BF_LATENCY-1:1]}; + else if (buf_rden_ntt | butterfly_ready | (gs_mode & incr_mem_rd_addr) | (pwo_mode & incr_pw_rd_addr)) begin //TODO: replace gs condition with an fsm generated flag perhaps? + chunk_count_reg <= {chunk_count, chunk_count_reg[BF_LATENCY:1]}; end end @@ -576,8 +596,8 @@ always_ff @(posedge clk or negedge reset_n) begin else if (buf_wren & ct_mode) begin //ct mode - buf writes are in order buf_wrptr <= (buf_wrptr == 'h3) ? 'h0 : buf_wrptr + 'h1; end - else if (buf_wren & gs_mode) begin // gs mode - buf_wrptr <= buf_wrptr_reg[1]; //equivalent to [0] due to this flop //(buf_wrptr == 'h3) ? 'h0 : buf_wrptr + 'h1; + else if (buf_wren_intt & gs_mode) begin // gs mode + buf_wrptr <= buf_wrptr_reg[0]; end end @@ -588,8 +608,8 @@ always_comb begin buf_rdptr = ct_mode ? buf_rdptr_f : buf_count; // buf_wrptr = gs_mode ? index_rand_offset + buf_count : buf_count; latch_chunk_rand_offset = arc_IDLE_WR_STAGE | arc_WR_MEM_WR_STAGE | arc_WR_WAIT_WR_STAGE; - latch_index_rand_offset = ct_mode ? (buf_wrptr == 'h3) : gs_mode & (arc_RD_STAGE_RD_EXEC | (index_count == 'h3)); //TODO pwo mode - mem_rd_index_ofst = gs_mode ? (index_count + index_rand_offset) : 'h0; //TODO: pwo mode, not used in ct mode + latch_index_rand_offset = ct_mode ? (buf_wrptr == 'h3) : (gs_mode | (pwo_mode & incr_pw_rd_addr)) & (arc_RD_STAGE_RD_EXEC | (index_count == 'h3)); //TODO pwo mode + mem_rd_index_ofst = (pwo_mode | gs_mode) ? (index_count + index_rand_offset) : 'h0; //TODO: pwo mode, not used in ct mode end @@ -663,7 +683,7 @@ always_comb begin buf_rd_rst_count_ntt = 1'b0; rst_twiddle_addr = 1'b0; incr_pw_rd_addr = 1'b0; - pw_rden = 1'b0; + pw_rden_fsm = 1'b0; unique case(read_fsm_state_ps) RD_IDLE: begin read_fsm_state_ns = arc_IDLE_RD_STAGE ? RD_STAGE : RD_IDLE; @@ -701,12 +721,12 @@ always_comb begin buf_wren_ntt = ct_mode; buf_rden_ntt = ct_mode; incr_mem_rd_addr = (ntt_mode inside {ct, gs}); - mem_rd_en_fsm = (ntt_mode inside {ct, gs}) ? (mem_rd_addr <= MEM_LAST_ADDR + mem_rd_base_addr) : 1'b0; + mem_rd_en_fsm = (ntt_mode inside {ct, gs}) ? (mem_rd_addr <= MEM_LAST_ADDR + mem_rd_base_addr) & ~arc_RD_EXEC_EXEC_WAIT : 1'b0; bf_enable_fsm = pwo_mode ? sampler_valid : 1'b1; incr_twiddle_addr_fsm = ntt_mode inside {ct, gs}; //1'b1; rd_addr_step = ct_mode ? NTT_READ_ADDR_STEP : INTT_READ_ADDR_STEP; incr_pw_rd_addr = sampler_valid & pwo_mode; - pw_rden = sampler_valid & pwo_mode; + pw_rden_fsm = sampler_valid & pwo_mode; end EXEC_WAIT: begin read_fsm_state_ns = arc_EXEC_WAIT_RD_STAGE ? RD_STAGE : arc_EXEC_WAIT_RD_EXEC ? RD_EXEC : EXEC_WAIT; @@ -718,7 +738,7 @@ always_comb begin incr_twiddle_addr_fsm = (ct_mode | gs_mode); rd_addr_step = NTT_READ_ADDR_STEP; incr_pw_rd_addr = (pwo_mode & sampler_valid); - pw_rden = (pwo_mode & sampler_valid); + pw_rden_fsm = (pwo_mode & sampler_valid); end default: begin read_fsm_state_ns = RD_IDLE; @@ -744,13 +764,13 @@ always_comb begin arc_IDLE_WR_STAGE = (write_fsm_state_ps == WR_IDLE) && ntt_enable ; //This arc is only for ct mode. No buffer in the path, so wait for all addr to be written (0-63) before transitioning to WR_STAGE - arc_WR_MEM_WR_STAGE = (write_fsm_state_ps == WR_MEM) && ((ct_mode || pwo_mode) && (wr_valid_count == 'h3f)); //(mem_wr_addr == (mem_wr_base_addr + MEM_LAST_ADDR)); //this arc is for ct mode, + arc_WR_MEM_WR_STAGE = (write_fsm_state_ps == WR_MEM) && ((ct_mode) && (wr_valid_count == 'h3f)); //(mem_wr_addr == (mem_wr_base_addr + MEM_LAST_ADDR)); //this arc is for ct mode, //All rounds of NTT or INTT are done. Go to IDLE and wait for next command arc_WR_STAGE_IDLE = (write_fsm_state_ps == WR_STAGE) && (ntt_done || intt_done || pwo_done); //This arc is only for ct mode since there's no output buffer - arc_WR_STAGE_WR_MEM = (write_fsm_state_ps == WR_STAGE) && ((ct_mode && !ntt_done)); // || (pwo_mode && (!pwo_done /*|| ntt_enable*/))); + arc_WR_STAGE_WR_MEM = (write_fsm_state_ps == WR_STAGE) && ((ct_mode && !ntt_done) || (pwo_mode && !pwo_done)); // || (pwo_mode && (!pwo_done /*|| ntt_enable*/))); //pwm arc. If in WR_STAGE, read fsm is executing, go back to WR_MEM state to perform current round's writes arc_WR_STAGE_WR_MEM_OPT = (write_fsm_state_ps == WR_STAGE) && (read_fsm_state_ps == RD_EXEC) && (pwo_mode && pwo_busy); @@ -772,13 +792,14 @@ always_comb begin //Move to WR_WAIT state when the last outputs from bf2x2 have been captured in the buffers. They still need to be shifted out of the buffers and into memory, so keep buf_wren 1 here //Assumption - no bubbles in NTT or INTT. If bubbles, need to consider sampler_valid //TODO: can WR_WAIT state be removed? fsm can finish all 64 addr in WR_MEM state? - arc_WR_MEM_WR_WAIT = (write_fsm_state_ps == WR_MEM) && ((gs_mode && (buf0_valid && (wr_valid_count == 'h3c))) || (pwo_mode && !butterfly_ready && (wr_valid_count < 'h3f))); + arc_WR_MEM_WR_WAIT = shuffle_en ? (write_fsm_state_ps == WR_MEM) && ((gs_mode && (buf0_valid && (wr_valid_count == 'h3c))) || (pwo_mode && butterfly_ready && (wr_valid_count == 'h3f))) + : (write_fsm_state_ps == WR_MEM) && ((gs_mode && (buf0_valid && (wr_valid_count == 'h3c))) || (pwo_mode && !butterfly_ready && (wr_valid_count < 'h3f))); // || (ct_mode && (wr_valid_count == 'h3f))); //This arc is only for pwo mode. Move back from wait to write state when there's a valid BFU output arc_WR_WAIT_WR_MEM = (write_fsm_state_ps == WR_WAIT) && (pwo_mode && butterfly_ready); //When valid_count is 64 and buf_count is 3 (meaning all 4 buffers have been used), move to WR_STAGE indicating that round is done - arc_WR_WAIT_WR_STAGE = shuffle_en ? (write_fsm_state_ps == WR_WAIT) && ((gs_mode && (buf_count == 'h3)) || ct_mode) + arc_WR_WAIT_WR_STAGE = shuffle_en ? (write_fsm_state_ps == WR_WAIT) && ((gs_mode && (buf_count == 'h3)) || ct_mode || pwo_mode) : (write_fsm_state_ps == WR_WAIT) && (!pwo_mode && (buf_count == 'h3)); end @@ -794,7 +815,7 @@ always_comb begin buf_wr_rst_count_intt = 1'b0; buf_rd_rst_count_intt = 1'b0; incr_pw_wr_addr = 1'b0; - pw_wren = 1'b0; + pw_wren_fsm = 1'b0; rst_pw_addr = 1'b0; unique case(write_fsm_state_ps) WR_IDLE: begin @@ -806,7 +827,7 @@ always_comb begin WR_STAGE: begin write_fsm_state_ns = arc_WR_STAGE_WR_MEM ? WR_MEM : arc_WR_STAGE_WR_BUF ? WR_BUF : - arc_WR_STAGE_WR_WAIT? WR_WAIT : + // arc_WR_STAGE_WR_WAIT? WR_WAIT : arc_WR_STAGE_IDLE ? WR_IDLE : WR_STAGE; rst_wr_addr = 1'b1; rst_wr_valid_count = 1'b1; @@ -832,17 +853,24 @@ always_comb begin mem_wr_en_fsm = ct_mode ? butterfly_ready : gs_mode ? 1'b1 : 1'b0; wr_addr_step = ct_mode ? NTT_WRITE_ADDR_STEP : INTT_WRITE_ADDR_STEP; incr_pw_wr_addr = pwo_mode & butterfly_ready; - pw_wren = pwo_mode & butterfly_ready; + pw_wren_fsm = pwo_mode & butterfly_ready; end WR_WAIT: begin - write_fsm_state_ns = arc_WR_WAIT_WR_STAGE ? WR_STAGE : arc_WR_WAIT_WR_MEM ? WR_MEM : WR_WAIT; - buf_wren_intt = shuffle_en ? gs_mode & (buf_count <= 'h3) : (buf_count <= 'h3); //1'b0; + if (shuffle_en) begin + write_fsm_state_ns = arc_WR_WAIT_WR_STAGE ? WR_STAGE : /*arc_WR_WAIT_WR_MEM ? WR_MEM :*/ WR_WAIT; + wr_addr_step = gs_mode ? INTT_WRITE_ADDR_STEP : NTT_WRITE_ADDR_STEP; + end + else begin + write_fsm_state_ns = arc_WR_WAIT_WR_STAGE ? WR_STAGE : arc_WR_WAIT_WR_MEM ? WR_MEM : WR_WAIT; + wr_addr_step = INTT_WRITE_ADDR_STEP; + end + buf_wren_intt = shuffle_en ? 'b0 : (buf_count <= 'h3); //1'b0; buf_rden_intt = shuffle_en ? gs_mode : 'b1; incr_mem_wr_addr = (ct_mode | gs_mode); //1'b1; mem_wr_en_fsm = shuffle_en ? gs_mode : (ct_mode | gs_mode); //1'b1; - wr_addr_step = gs_mode ? INTT_WRITE_ADDR_STEP : NTT_WRITE_ADDR_STEP; - incr_pw_wr_addr = arc_WR_WAIT_WR_MEM; - pw_wren = arc_WR_WAIT_WR_MEM; + + incr_pw_wr_addr = pwo_mode & arc_WR_WAIT_WR_STAGE; //MEM; + // pw_wren_fsm = arc_WR_WAIT_WR_STAGE; //MEM; end default: begin write_fsm_state_ns = WR_IDLE; @@ -850,44 +878,75 @@ always_comb begin endcase end -assign rst_rounds = (read_fsm_state_ps == RD_IDLE) && (write_fsm_state_ps == WR_IDLE); -assign incr_rounds = arc_WR_MEM_WR_STAGE | arc_WR_WAIT_WR_STAGE; //TODO: revisit for high-perf mode (if we go with above opt) -assign buf_wren = pwo_mode ? 1'b0 : buf_wren_ntt_reg | buf_wren_intt; -assign buf_rden = pwo_mode ? 1'b0 : ct_mode ? buf_rden_ntt_reg : /*buf_rden_ntt |*/ buf_rden_intt; -assign bf_enable = (gs_mode || pwo_mode) ? bf_enable_reg_d2 : bf_enable_reg; //bf_enable_fsm; //In gs mode, memory is directly feeding bf2x2, so we need to enable it one cycle later -assign buf_wr_rst_count = pwo_mode ? 1'b1 : buf_wr_rst_count_ntt | buf_wr_rst_count_intt; -assign buf_rd_rst_count = pwo_mode ? 1'b1 : buf_rd_rst_count_ntt | buf_rd_rst_count_intt; -assign mem_wr_en = gs_mode ? mem_wr_en_fsm : mem_wr_en_reg; //TODO pwo mode, GS mode + shuffling -assign mem_rd_en = gs_mode ? mem_rd_en_reg : mem_rd_en_fsm; //TODO pwo mode -assign twiddle_addr = gs_mode ? twiddle_addr_reg_d2 : twiddle_addr_int; +always_comb begin + rst_rounds = (read_fsm_state_ps == RD_IDLE) && (write_fsm_state_ps == WR_IDLE); + incr_rounds = arc_WR_MEM_WR_STAGE | arc_WR_WAIT_WR_STAGE; //TODO: revisit for high-perf mode (if we go with above opt) + if (shuffle_en) begin + buf_wren = pwo_mode ? 1'b0 : buf_wren_ntt_reg | buf_wren_intt_reg; + buf_rden = pwo_mode ? 1'b0 : ct_mode ? buf_rden_ntt_reg : /*buf_rden_ntt |*/ buf_rden_intt; + bf_enable = (gs_mode || pwo_mode) ? bf_enable_reg_d2 : bf_enable_reg; //bf_enable_fsm; //In gs mode, memory is directly feeding bf2x2, so we need to enable it one cycle later + mem_wr_en = gs_mode ? mem_wr_en_fsm : mem_wr_en_reg; //TODO pwo mode, GS mode + shuffling + mem_rd_en = (gs_mode | pwo_mode) ? mem_rd_en_reg : mem_rd_en_fsm; + twiddle_addr = gs_mode ? twiddle_addr_reg_d3 : twiddle_addr_int; + pw_rden = pw_rden_reg; + pw_wren = pwm_mode ? pw_wren_reg/*_d2*/ : pw_wren_reg; + end + else begin + buf_wren = pwo_mode ? 1'b0 : buf_wren_ntt_reg | buf_wren_intt; + buf_rden = pwo_mode ? 1'b0 : buf_rden_ntt | buf_rden_intt; + bf_enable = (gs_mode | pwo_mode) ? bf_enable_reg : bf_enable_fsm; //In gs mode, memory is directly feeding bf2x2, so we need to enable it one cycle later + mem_wr_en = mem_wr_en_fsm; + mem_rd_en = mem_rd_en_fsm; + twiddle_addr = twiddle_addr_int; + end + buf_wr_rst_count = pwo_mode ? 1'b1 : buf_wr_rst_count_ntt | buf_wr_rst_count_intt; + buf_rd_rst_count = pwo_mode ? 1'b1 : buf_rd_rst_count_ntt | buf_rd_rst_count_intt; + + +end always_ff @(posedge clk or negedge reset_n) begin if (!reset_n) begin buf_wren_ntt_reg <= 'b0; + buf_wren_intt_reg <= 'b0; buf_rden_ntt_reg <= 'b0; bf_enable_reg <= 'b0; bf_enable_reg_d2 <= 'b0; mem_wr_en_reg <= 'b0; mem_rd_en_reg <= 'b0; twiddle_addr_reg_d2 <= 'h0; + twiddle_addr_reg_d3 <= 'h0; + pw_rden_reg <= '0; + pw_wren_reg <= '0; + pw_wren_reg_d2 <= '0; end else if (zeroize) begin buf_wren_ntt_reg <= 'b0; + buf_wren_intt_reg <= 'b0; buf_rden_ntt_reg <= 'b0; bf_enable_reg <= 'b0; bf_enable_reg_d2 <= 'b0; mem_wr_en_reg <= 'b0; mem_rd_en_reg <= 'b0; twiddle_addr_reg_d2 <= 'h0; + twiddle_addr_reg_d3 <= 'h0; + pw_rden_reg <= '0; + pw_wren_reg <= '0; + pw_wren_reg_d2 <= '0; end else begin buf_wren_ntt_reg <= buf_wren_ntt; + buf_wren_intt_reg <= buf_wren_intt; buf_rden_ntt_reg <= buf_rden_ntt; bf_enable_reg <= bf_enable_fsm; bf_enable_reg_d2 <= bf_enable_reg; mem_wr_en_reg <= mem_wr_en_fsm; mem_rd_en_reg <= mem_rd_en_fsm; twiddle_addr_reg_d2 <= twiddle_addr_int; + twiddle_addr_reg_d3 <= twiddle_addr_reg_d2; + pw_rden_reg <= pw_rden_fsm; + pw_wren_reg <= pw_wren_fsm; + pw_wren_reg_d2 <= pw_wren_reg; end end diff --git a/src/ntt_top/rtl/ntt_top.sv b/src/ntt_top/rtl/ntt_top.sv index 68b6490..48ca3a3 100644 --- a/src/ntt_top/rtl/ntt_top.sv +++ b/src/ntt_top/rtl/ntt_top.sv @@ -98,7 +98,7 @@ module ntt_top logic mem_wren, mem_wren_reg, mem_wren_mux; logic [MLDSA_MEM_ADDR_WIDTH-1:0] mem_wr_addr, mem_wr_addr_reg, mem_wr_addr_mux; // logic [(4*REG_SIZE)-1:0] mem_wr_data; - logic [MEM_DATA_WIDTH-1:0] mem_wr_data_int, mem_wr_data_reg; + logic [MEM_DATA_WIDTH-1:0] mem_wr_data_int, mem_wr_data_reg, mem_wr_data_reg_d2; //Read IF logic mem_rden; @@ -129,6 +129,8 @@ module ntt_top pwo_t pwo_uv_o; logic pw_wren, pw_wren_reg; logic pw_rden, pw_rden_dest_mem; + logic sampler_valid_reg; + logic [MEM_DATA_WIDTH-1:0] pwm_b_rd_data_reg; //Flop ntt_ctrl pwm output wr addr to align with BFU output flop logic [MLDSA_MEM_ADDR_WIDTH-1:0] pwm_wr_addr_c_reg; @@ -165,7 +167,7 @@ module ntt_top assign mem_wr_req.addr = !pwo_mode ? mem_wr_addr_mux : pwm_wr_addr_c_reg; assign mem_wr_data_int = !pwo_mode ? (ct_mode ? {1'b0, uv_o_reg.v21_o, 1'b0, uv_o_reg.u21_o, 1'b0, uv_o_reg.v20_o, 1'b0, uv_o_reg.u20_o} : buf_data_o) : pwm_wr_data_reg; - assign mem_wr_data = mem_wr_data_int; //ct_mode ? mem_wr_data_reg : mem_wr_data_int; //TODO: gs, pwo modes + assign mem_wr_data = pwm_mode ? mem_wr_data_reg/*_d2*/ : (pwa_mode | pws_mode) ? mem_wr_data_reg : mem_wr_data_int; //ct_mode ? mem_wr_data_reg : mem_wr_data_int; //TODO: gs, pwo modes //mem rd - NTT/INTT mode, read ntt data. PWM mode, read accumulate data from c mem. PWA/S mode, unused assign mem_rd_req.rd_wr_en = (ct_mode || gs_mode) ? (mem_rden ? RW_READ : RW_IDLE) : pwm_mode ? (pw_rden_dest_mem ? RW_READ : RW_IDLE) : RW_IDLE; @@ -178,9 +180,9 @@ module ntt_top assign pwm_rd_data_a = pwo_mode ? pwm_a_rd_data : 'h0; //TODO: clean up mux. Just connect input directly to logic //pwm rd b - PWO mode - read b operand from mem. Or operand b can also be connected directly to sampler, so in that case, addr/rden are not used - assign pwm_b_rd_req.rd_wr_en = sampler_valid & pwo_mode ? (pw_rden ? RW_READ : RW_IDLE) : RW_IDLE; - assign pwm_b_rd_req.addr = sampler_valid & pwo_mode ? pw_mem_rd_addr_b : 'h0; - assign pwm_rd_data_b = pwm_b_rd_data; + assign pwm_b_rd_req.rd_wr_en = sampler_valid_reg & pwo_mode ? (pw_rden ? RW_READ : RW_IDLE) : RW_IDLE; //pw_rden is delayed a clk due to shuffling, so use delayed sampler_valid to line it up + assign pwm_b_rd_req.addr = sampler_valid_reg & pwo_mode ? pw_mem_rd_addr_b : 'h0; + assign pwm_rd_data_b = pwm_b_rd_data_reg; //sampler_valid_reg ? pwm_b_rd_data_reg : pwm_b_rd_data; ntt_ctrl #( @@ -300,6 +302,9 @@ module ntt_top pw_wren_reg <= 'b0; mem_wr_data_reg <= 'h0; + mem_wr_data_reg_d2 <= 'h0; + sampler_valid_reg <= 'h0; + pwm_b_rd_data_reg <= 'h0; end else if (zeroize) begin @@ -321,6 +326,9 @@ module ntt_top pw_wren_reg <= 'b0; mem_wr_data_reg <= 'h0; + mem_wr_data_reg_d2 <= 'h0; + sampler_valid_reg <= 'h0; + pwm_b_rd_data_reg <= 'h0; end else begin mem_rd_data_reg <= mem_rd_data; @@ -342,11 +350,14 @@ module ntt_top pw_wren_reg <= pw_wren; mem_wr_data_reg <= mem_wr_data_int; + mem_wr_data_reg_d2 <= mem_wr_data_reg; + sampler_valid_reg <= sampler_valid; + pwm_b_rd_data_reg <= pwm_b_rd_data; end end //Buffer (input or output side) - assign buf_data_i = ct_mode ? mem_rd_data : {1'b0, uv_o.v21_o, 1'b0, uv_o.v20_o, 1'b0, uv_o.u21_o, 1'b0, uv_o.u20_o}; + assign buf_data_i = ct_mode ? mem_rd_data : {1'b0, uv_o_reg.v21_o, 1'b0, uv_o_reg.v20_o, 1'b0, uv_o_reg.u21_o, 1'b0, uv_o_reg.u20_o}; always_comb begin unique case(mode) @@ -377,10 +388,10 @@ module ntt_top pw_uvw_i.u2_i = pwm_rd_data_a_reg[(3*REG_SIZE)-2:(2*REG_SIZE)]; pw_uvw_i.u3_i = pwm_rd_data_a_reg[(4*REG_SIZE)-2:(3*REG_SIZE)]; - pw_uvw_i.v0_i = pwm_rd_data_b_reg[REG_SIZE-2:0]; - pw_uvw_i.v1_i = pwm_rd_data_b_reg[(2*REG_SIZE)-2:REG_SIZE]; - pw_uvw_i.v2_i = pwm_rd_data_b_reg[(3*REG_SIZE)-2:(2*REG_SIZE)]; - pw_uvw_i.v3_i = pwm_rd_data_b_reg[(4*REG_SIZE)-2:(3*REG_SIZE)]; + pw_uvw_i.v0_i = pwm_rd_data_b/*_reg*/[REG_SIZE-2:0]; + pw_uvw_i.v1_i = pwm_rd_data_b/*_reg*/[(2*REG_SIZE)-2:REG_SIZE]; + pw_uvw_i.v2_i = pwm_rd_data_b/*_reg*/[(3*REG_SIZE)-2:(2*REG_SIZE)]; + pw_uvw_i.v3_i = pwm_rd_data_b/*_reg*/[(4*REG_SIZE)-2:(3*REG_SIZE)]; pw_uvw_i.w0_i = pwm_rd_data_c_reg[REG_SIZE-2:0]; pw_uvw_i.w1_i = pwm_rd_data_c_reg[(2*REG_SIZE)-2:REG_SIZE]; @@ -398,10 +409,10 @@ module ntt_top pw_uvw_i.u2_i = pwm_rd_data_a_reg[(3*REG_SIZE)-2:(2*REG_SIZE)]; pw_uvw_i.u3_i = pwm_rd_data_a_reg[(4*REG_SIZE)-2:(3*REG_SIZE)]; - pw_uvw_i.v0_i = pwm_rd_data_b_reg[REG_SIZE-2:0]; - pw_uvw_i.v1_i = pwm_rd_data_b_reg[(2*REG_SIZE)-2:REG_SIZE]; - pw_uvw_i.v2_i = pwm_rd_data_b_reg[(3*REG_SIZE)-2:(2*REG_SIZE)]; - pw_uvw_i.v3_i = pwm_rd_data_b_reg[(4*REG_SIZE)-2:(3*REG_SIZE)]; + pw_uvw_i.v0_i = pwm_rd_data_b/*_reg*/[REG_SIZE-2:0]; + pw_uvw_i.v1_i = pwm_rd_data_b/*_reg*/[(2*REG_SIZE)-2:REG_SIZE]; + pw_uvw_i.v2_i = pwm_rd_data_b/*_reg*/[(3*REG_SIZE)-2:(2*REG_SIZE)]; + pw_uvw_i.v3_i = pwm_rd_data_b/*_reg*/[(4*REG_SIZE)-2:(3*REG_SIZE)]; pw_uvw_i.w0_i = 'h0; pw_uvw_i.w1_i = 'h0; diff --git a/src/ntt_top/tb/ntt_top_tb.sv b/src/ntt_top/tb/ntt_top_tb.sv index eb26036..dba8d67 100644 --- a/src/ntt_top/tb/ntt_top_tb.sv +++ b/src/ntt_top/tb/ntt_top_tb.sv @@ -26,6 +26,7 @@ module ntt_top_tb import ntt_defines_pkg::*; + import mldsa_params_pkg::*; #( parameter TEST_VECTOR_NUM = 10, @@ -63,7 +64,7 @@ reg [23:0] zeta_inv [255:0]; reg [(4*(REG_SIZE+1))-1:0] ntt_mem_tb [63:0]; reg load_tb_values; -reg [MEM_ADDR_WIDTH-1:0] load_tb_addr; +reg [MLDSA_MEM_ADDR_WIDTH-1:0] load_tb_addr; reg [7:0] src_base_addr, interim_base_addr, dest_base_addr; reg acc_tb, svalid_tb, sampler_mode_tb; @@ -258,7 +259,7 @@ task init_sim; acc_tb = 1'b0; svalid_tb = 1'b0; sampler_mode_tb = 1'b0; - random_tb = 'h0; + random_tb <= 'h0; $display("End of init\n"); end @@ -328,8 +329,8 @@ endtask task ntt_top_test(); fork begin - while(ntt_done_tb == 1'b0) begin - random_tb = $urandom(); + while(1) begin + random_tb <= $urandom(); @(posedge clk_tb); end end @@ -368,16 +369,16 @@ task ntt_top_test(); // $display("Error: NTT data mismatch at index %0d (dest_base addr = %0d). Actual data = %h, expected data = %h", i, dest_base_addr, dut.ntt_mem.mem[i+dest_base_addr], ntt_mem_tb[i]); // @(posedge clk_tb); // end - end - join - fork - begin - while(ntt_done_tb == 1'b0) begin - random_tb = $urandom(); - @(posedge clk_tb); - end - end - begin + // end + // join + // fork + // begin + // while(ntt_done_tb == 1'b0) begin + // random_tb = $urandom(); + // @(posedge clk_tb); + // end + // end + // begin $display("INTT operation\n"); operation = "INTT"; mode_tb = gs; @@ -392,7 +393,7 @@ task ntt_top_test(); while(ntt_done_tb == 1'b0) @(posedge clk_tb); $display("Received intt_done\n"); - /* + $display("PWM operation 1\n"); operation = "PWM 1 no acc"; // $readmemh("pwm_iter1.hex", ntt_mem_tb); @@ -519,6 +520,8 @@ task ntt_top_test(); svalid_tb = 1'b0; @(posedge clk_tb); + + $display("PWM + sampler operation 1\n"); operation = "PWM sampler"; mode_tb = pwm; @@ -526,26 +529,26 @@ task ntt_top_test(); acc_tb = 1'b0; sampler_mode_tb = 1'b1; repeat(2) @(posedge clk_tb); - svalid_tb = 1'b1; + svalid_tb <= 1'b1; @(posedge clk_tb); enable_tb = 1'b0; repeat(10) @(posedge clk_tb); - svalid_tb = 1'b0; + svalid_tb <= 1'b0; repeat(10) @(posedge clk_tb); - svalid_tb = 1'b1; + svalid_tb <= 1'b1; repeat(10) @(posedge clk_tb); - svalid_tb = 1'b0; + svalid_tb <= 1'b0; repeat(10) @(posedge clk_tb); - svalid_tb = 1'b1; + svalid_tb <= 1'b1; repeat(45) @(posedge clk_tb); - svalid_tb = 1'b0; + svalid_tb <= 1'b0; $display("Waiting for pwo_done\n"); while(ntt_done_tb == 1'b0) @(posedge clk_tb); $display("Received pwo_done\n"); - */ + end - join + join_any $display("End of test\n"); endtask From e6423dc6fbf1d7c8a210e214384805c3223366f8 Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Fri, 4 Oct 2024 08:29:49 -0700 Subject: [PATCH 06/17] Add shuffle_en (wip) --- src/mldsa_top/rtl/mldsa_top.sv | 1 + src/norm_check/rtl/norm_check_ctrl.sv | 2 + src/norm_check/rtl/norm_check_top.sv | 4 ++ src/ntt_top/rtl/ntt_ctrl.sv | 75 +++++++++++++++++--------- src/ntt_top/rtl/ntt_shuffle_buffer.sv | 3 +- src/ntt_top/rtl/ntt_top.sv | 77 ++++++++++++++++++++------- src/ntt_top/tb/ntt_top_tb.sv | 1 + src/ntt_top/tb/ntt_wrapper.sv | 2 + 8 files changed, 119 insertions(+), 46 deletions(-) diff --git a/src/mldsa_top/rtl/mldsa_top.sv b/src/mldsa_top/rtl/mldsa_top.sv index 3fcdd71..37de6cb 100644 --- a/src/mldsa_top/rtl/mldsa_top.sv +++ b/src/mldsa_top/rtl/mldsa_top.sv @@ -471,6 +471,7 @@ generate .pwo_mem_base_addr(pwo_mem_base_addr[g_inst]), .accumulate(accumulate[g_inst]), .sampler_valid(sampler_valid[g_inst]), + .shuffle_en(1'b0), .random(random), //NTT mem IF .mem_wr_req(ntt_mem_wr_req[g_inst]), diff --git a/src/norm_check/rtl/norm_check_ctrl.sv b/src/norm_check/rtl/norm_check_ctrl.sv index d95158b..ef94b38 100644 --- a/src/norm_check/rtl/norm_check_ctrl.sv +++ b/src/norm_check/rtl/norm_check_ctrl.sv @@ -36,6 +36,8 @@ module norm_check_ctrl input wire norm_check_enable, input chk_norm_mode_t mode, + input wire shuffle_en, + input wire [5:0] random, input wire [MLDSA_MEM_ADDR_WIDTH-1:0] mem_base_addr, output mem_if_t mem_rd_req, output logic check_enable, diff --git a/src/norm_check/rtl/norm_check_top.sv b/src/norm_check/rtl/norm_check_top.sv index 7bf3d33..f347628 100644 --- a/src/norm_check/rtl/norm_check_top.sv +++ b/src/norm_check/rtl/norm_check_top.sv @@ -45,6 +45,8 @@ module norm_check_top input wire [MLDSA_MEM_ADDR_WIDTH-1:0] mem_base_addr, output mem_if_t mem_rd_req, input [4*REG_SIZE-1:0] mem_rd_data, + input wire shuffle_en, + input wire [5:0] random, output logic invalid, output logic norm_check_ready, output logic norm_check_done @@ -108,6 +110,8 @@ module norm_check_top .zeroize(zeroize), .norm_check_enable(norm_check_enable), .mode(mode), + .shuffle_en(shuffle_en), + .random(random), .mem_base_addr(mem_base_addr), .mem_rd_req(mem_rd_req), .norm_check_done(norm_check_done_int), diff --git a/src/ntt_top/rtl/ntt_ctrl.sv b/src/ntt_top/rtl/ntt_ctrl.sv index a370cb0..3357697 100644 --- a/src/ntt_top/rtl/ntt_ctrl.sv +++ b/src/ntt_top/rtl/ntt_ctrl.sv @@ -55,6 +55,7 @@ module ntt_ctrl // input wire [MEM_ADDR_WIDTH-1:0] pw_base_addr_b, // input wire [MEM_ADDR_WIDTH-1:0] pw_base_addr_c, //result input pwo_mem_addr_t pwo_mem_base_addr, + input wire shuffle_en, input wire [5:0] random, //4+2 bits output logic bf_enable, @@ -273,8 +274,6 @@ end always_comb begin mem_rd_base_addr = (rounds_count == 'h0) ? src_base_addr : rounds_count[0] ? interim_base_addr : dest_base_addr; mem_wr_base_addr = rounds_count[0] ? dest_base_addr : interim_base_addr; - rd_addr_wraparound = mem_rd_addr_nxt > {1'b0,mem_rd_base_addr} + MEM_LAST_ADDR; - wr_addr_wraparound = mem_wr_addr_nxt > {1'b0,mem_wr_base_addr} + MEM_LAST_ADDR; if (shuffle_en) begin mem_rd_addr_nxt = (gs_mode | pwo_mode) ? (4*chunk_count) + (rd_addr_step*mem_rd_index_ofst) + mem_rd_base_addr : mem_rd_addr + rd_addr_step; //TODO pwo modes @@ -284,6 +283,9 @@ always_comb begin mem_rd_addr_nxt = mem_rd_addr + rd_addr_step; mem_wr_addr_nxt = mem_wr_addr + wr_addr_step; end + + rd_addr_wraparound = mem_rd_addr_nxt > {1'b0,mem_rd_base_addr} + MEM_LAST_ADDR; + wr_addr_wraparound = mem_wr_addr_nxt > {1'b0,mem_wr_base_addr} + MEM_LAST_ADDR; end //Read addr @@ -354,20 +356,26 @@ always_ff @(posedge clk or negedge reset_n) begin pw_mem_wr_addr_c <= '0; end else if (rst_pw_addr) begin - pw_mem_rd_addr_a <= (4*chunk_rand_offset) + pw_base_addr_a; - pw_mem_rd_addr_b <= (4*chunk_rand_offset) + pw_base_addr_b; - pw_mem_rd_addr_c <= (4*chunk_rand_offset) + pw_base_addr_c; - pw_mem_wr_addr_c <= (4*chunk_rand_offset) + pw_base_addr_c; + pw_mem_rd_addr_a <= shuffle_en ? (4*chunk_rand_offset) + pw_base_addr_a : pw_base_addr_a; + pw_mem_rd_addr_b <= shuffle_en ? (4*chunk_rand_offset) + pw_base_addr_b : pw_base_addr_b; + pw_mem_rd_addr_c <= shuffle_en ? (4*chunk_rand_offset) + pw_base_addr_c : pw_base_addr_c; + pw_mem_wr_addr_c <= shuffle_en ? (4*chunk_rand_offset) + pw_base_addr_c : pw_base_addr_c; end else begin if (incr_pw_rd_addr) begin - pw_mem_rd_addr_a <= pw_mem_rd_addr_a_nxt;//pw_mem_rd_addr_a + PWO_READ_ADDR_STEP; - pw_mem_rd_addr_b <= pw_mem_rd_addr_b_nxt;//pw_mem_rd_addr_b + PWO_READ_ADDR_STEP; - pw_mem_rd_addr_c <= accumulate ? pw_mem_rd_addr_c_nxt : 'h0;//accumulate ? pw_mem_rd_addr_c + PWO_READ_ADDR_STEP : 'h0; //addr in sync with a, b. However, the data is flopped 4 cycles inside BF to align with mul result - + if (shuffle_en) begin + pw_mem_rd_addr_a <= pw_mem_rd_addr_a_nxt;//pw_mem_rd_addr_a + PWO_READ_ADDR_STEP; + pw_mem_rd_addr_b <= pw_mem_rd_addr_b_nxt;//pw_mem_rd_addr_b + PWO_READ_ADDR_STEP; + pw_mem_rd_addr_c <= accumulate ? pw_mem_rd_addr_c_nxt : 'h0;//accumulate ? pw_mem_rd_addr_c + PWO_READ_ADDR_STEP : 'h0; //addr in sync with a, b. However, the data is flopped 4 cycles inside BF to align with mul result + end + else begin + pw_mem_rd_addr_a <= pw_mem_rd_addr_a + PWO_READ_ADDR_STEP; + pw_mem_rd_addr_b <= pw_mem_rd_addr_b + PWO_READ_ADDR_STEP; + pw_mem_rd_addr_c <= accumulate ? pw_mem_rd_addr_c + PWO_READ_ADDR_STEP : 'h0; //addr in sync with a, b. However, the data is flopped 4 cycles inside BF to align with mul result + end end if (incr_pw_wr_addr) begin - pw_mem_wr_addr_c <= pw_mem_wr_addr_c_nxt; //pw_mem_wr_addr_c + PWO_WRITE_ADDR_STEP; + pw_mem_wr_addr_c <= shuffle_en ? pw_mem_wr_addr_c_nxt : pw_mem_wr_addr_c + PWO_WRITE_ADDR_STEP; end end end @@ -432,7 +440,7 @@ always_ff @(posedge clk or negedge reset_n) begin twiddle_addr_reg <= 'h0; end -assign twiddle_addr_int = ct_mode ? twiddle_addr_reg + twiddle_offset : twiddle_rand_offset + twiddle_offset; +assign twiddle_addr_int = (~shuffle_en | ct_mode) ? twiddle_addr_reg + twiddle_offset : twiddle_rand_offset + twiddle_offset; //------------------------------------------ //Busy logic @@ -593,10 +601,10 @@ always_ff @(posedge clk or negedge reset_n) begin else if (zeroize) begin buf_wrptr <= 'h0; end - else if (buf_wren & ct_mode) begin //ct mode - buf writes are in order + else if (buf_wren & (ct_mode | ~shuffle_en)) begin //ct mode - buf writes are in order buf_wrptr <= (buf_wrptr == 'h3) ? 'h0 : buf_wrptr + 'h1; end - else if (buf_wren_intt & gs_mode) begin // gs mode + else if (buf_wren_intt & gs_mode & shuffle_en) begin // gs mode buf_wrptr <= buf_wrptr_reg[0]; end end @@ -604,8 +612,8 @@ end always_comb begin last_rd_addr = /*ct_mode &*/ (mem_rd_addr == mem_rd_base_addr + MEM_LAST_ADDR); //TODO: other modes last_wr_addr = /*ct_mode &*/ (mem_wr_addr == mem_wr_base_addr + MEM_LAST_ADDR); //TODO: other modes - buf_rdptr_int = ct_mode ? index_rand_offset + buf_count : buf_count; //TODO: flop - buf_rdptr = ct_mode ? buf_rdptr_f : buf_count; + buf_rdptr_int = (shuffle_en & ct_mode) ? index_rand_offset + buf_count : buf_count; //TODO: flop + buf_rdptr = (shuffle_en & ct_mode) ? buf_rdptr_f : buf_count; // buf_wrptr = gs_mode ? index_rand_offset + buf_count : buf_count; latch_chunk_rand_offset = arc_IDLE_WR_STAGE | arc_WR_MEM_WR_STAGE | arc_WR_WAIT_WR_STAGE; latch_index_rand_offset = ct_mode ? (buf_wrptr == 'h3) : (gs_mode | (pwo_mode & incr_pw_rd_addr)) & (arc_RD_STAGE_RD_EXEC | (index_count == 'h3)); //TODO pwo mode @@ -721,7 +729,10 @@ always_comb begin buf_wren_ntt = ct_mode; buf_rden_ntt = ct_mode; incr_mem_rd_addr = (ntt_mode inside {ct, gs}); - mem_rd_en_fsm = (ntt_mode inside {ct, gs}) ? (mem_rd_addr <= MEM_LAST_ADDR + mem_rd_base_addr) & ~arc_RD_EXEC_EXEC_WAIT : 1'b0; + if (shuffle_en) + mem_rd_en_fsm = (ntt_mode inside {ct, gs}) ? (mem_rd_addr <= MEM_LAST_ADDR + mem_rd_base_addr) & ~arc_RD_EXEC_EXEC_WAIT : 1'b0; + else + mem_rd_en_fsm = (ntt_mode inside {ct, gs}) ? (mem_rd_addr <= MEM_LAST_ADDR + mem_rd_base_addr) : 1'b0; bf_enable_fsm = pwo_mode ? sampler_valid : 1'b1; incr_twiddle_addr_fsm = ntt_mode inside {ct, gs}; //1'b1; rd_addr_step = ct_mode ? NTT_READ_ADDR_STEP : INTT_READ_ADDR_STEP; @@ -763,15 +774,21 @@ always_comb begin //Start NTT/INTT op when fsm is in IDLE state and there's an enable coming in arc_IDLE_WR_STAGE = (write_fsm_state_ps == WR_IDLE) && ntt_enable ; - //This arc is only for ct mode. No buffer in the path, so wait for all addr to be written (0-63) before transitioning to WR_STAGE - arc_WR_MEM_WR_STAGE = (write_fsm_state_ps == WR_MEM) && ((ct_mode) && (wr_valid_count == 'h3f)); //(mem_wr_addr == (mem_wr_base_addr + MEM_LAST_ADDR)); //this arc is for ct mode, + if (shuffle_en) begin + //This arc is only for ct mode. No buffer in the path, so wait for all addr to be written (0-63) before transitioning to WR_STAGE + arc_WR_MEM_WR_STAGE = (write_fsm_state_ps == WR_MEM) && ((ct_mode) && (wr_valid_count == 'h3f)); //(mem_wr_addr == (mem_wr_base_addr + MEM_LAST_ADDR)); //this arc is for ct mode, + + //This arc is only for ct mode since there's no output buffer + arc_WR_STAGE_WR_MEM = (write_fsm_state_ps == WR_STAGE) && ((ct_mode && !ntt_done) || (pwo_mode && !pwo_done)); // || (pwo_mode && (!pwo_done /*|| ntt_enable*/))); + end + else begin + arc_WR_MEM_WR_STAGE = (write_fsm_state_ps == WR_MEM) && ((ct_mode || pwo_mode) && (wr_valid_count == 'h3f)); //(mem_wr_addr == (mem_wr_base_addr + MEM_LAST_ADDR)); //this arc is for ct mode, + arc_WR_STAGE_WR_MEM = (write_fsm_state_ps == WR_STAGE) && (ct_mode && !ntt_done); + end //All rounds of NTT or INTT are done. Go to IDLE and wait for next command arc_WR_STAGE_IDLE = (write_fsm_state_ps == WR_STAGE) && (ntt_done || intt_done || pwo_done); - //This arc is only for ct mode since there's no output buffer - arc_WR_STAGE_WR_MEM = (write_fsm_state_ps == WR_STAGE) && ((ct_mode && !ntt_done) || (pwo_mode && !pwo_done)); // || (pwo_mode && (!pwo_done /*|| ntt_enable*/))); - //pwm arc. If in WR_STAGE, read fsm is executing, go back to WR_MEM state to perform current round's writes arc_WR_STAGE_WR_MEM_OPT = (write_fsm_state_ps == WR_STAGE) && (read_fsm_state_ps == RD_EXEC) && (pwo_mode && pwo_busy); @@ -825,10 +842,16 @@ always_comb begin rst_pw_addr = 1'b1; end WR_STAGE: begin - write_fsm_state_ns = arc_WR_STAGE_WR_MEM ? WR_MEM : + if (shuffle_en) + write_fsm_state_ns = arc_WR_STAGE_WR_MEM ? WR_MEM : arc_WR_STAGE_WR_BUF ? WR_BUF : // arc_WR_STAGE_WR_WAIT? WR_WAIT : arc_WR_STAGE_IDLE ? WR_IDLE : WR_STAGE; + else + write_fsm_state_ns = arc_WR_STAGE_WR_MEM ? WR_MEM : + arc_WR_STAGE_WR_BUF ? WR_BUF : + arc_WR_STAGE_WR_WAIT? WR_WAIT : + arc_WR_STAGE_IDLE ? WR_IDLE : WR_STAGE; rst_wr_addr = 1'b1; rst_wr_valid_count = 1'b1; buf_wr_rst_count_intt = gs_mode; @@ -869,8 +892,8 @@ always_comb begin incr_mem_wr_addr = (ct_mode | gs_mode); //1'b1; mem_wr_en_fsm = shuffle_en ? gs_mode : (ct_mode | gs_mode); //1'b1; - incr_pw_wr_addr = pwo_mode & arc_WR_WAIT_WR_STAGE; //MEM; - // pw_wren_fsm = arc_WR_WAIT_WR_STAGE; //MEM; + incr_pw_wr_addr = shuffle_en ? pwo_mode & arc_WR_WAIT_WR_STAGE : arc_WR_WAIT_WR_MEM; //MEM; + pw_wren_fsm = shuffle_en ? 'b0 : arc_WR_WAIT_WR_MEM; end default: begin write_fsm_state_ns = WR_IDLE; @@ -898,6 +921,8 @@ always_comb begin mem_wr_en = mem_wr_en_fsm; mem_rd_en = mem_rd_en_fsm; twiddle_addr = twiddle_addr_int; + pw_rden = pw_rden_fsm; + pw_wren = pw_wren_fsm; end buf_wr_rst_count = pwo_mode ? 1'b1 : buf_wr_rst_count_ntt | buf_wr_rst_count_intt; buf_rd_rst_count = pwo_mode ? 1'b1 : buf_rd_rst_count_ntt | buf_rd_rst_count_intt; diff --git a/src/ntt_top/rtl/ntt_shuffle_buffer.sv b/src/ntt_top/rtl/ntt_shuffle_buffer.sv index cf970c7..ee2855b 100644 --- a/src/ntt_top/rtl/ntt_shuffle_buffer.sv +++ b/src/ntt_top/rtl/ntt_shuffle_buffer.sv @@ -39,6 +39,7 @@ module ntt_shuffle_buffer input wire reset_n, input wire zeroize, input mode_t mode, + input wire shuffle_en, input wire wren, input wire rden, input wire [1:0] wrptr, @@ -92,7 +93,7 @@ module ntt_shuffle_buffer always_comb begin buf_valid = (data_i_count_reg == 'd3); lo_hi = buf_valid ^ lo_hi_reg; - lo_hi_rd = (mode == 0) ? lo_hi_reg : lo_hi; + lo_hi_rd = (shuffle_en & (mode == 0)) ? lo_hi_reg : lo_hi; //shuffling delays logic by a cycle, so that needs to be accounted for here as well end //lo hi diff --git a/src/ntt_top/rtl/ntt_top.sv b/src/ntt_top/rtl/ntt_top.sv index 48ca3a3..567a4d5 100644 --- a/src/ntt_top/rtl/ntt_top.sv +++ b/src/ntt_top/rtl/ntt_top.sv @@ -73,6 +73,7 @@ module ntt_top //Sampler IF input wire sampler_valid, + input wire shuffle_en, input wire [5:0] random, //Memory if @@ -167,7 +168,8 @@ module ntt_top assign mem_wr_req.addr = !pwo_mode ? mem_wr_addr_mux : pwm_wr_addr_c_reg; assign mem_wr_data_int = !pwo_mode ? (ct_mode ? {1'b0, uv_o_reg.v21_o, 1'b0, uv_o_reg.u21_o, 1'b0, uv_o_reg.v20_o, 1'b0, uv_o_reg.u20_o} : buf_data_o) : pwm_wr_data_reg; - assign mem_wr_data = pwm_mode ? mem_wr_data_reg/*_d2*/ : (pwa_mode | pws_mode) ? mem_wr_data_reg : mem_wr_data_int; //ct_mode ? mem_wr_data_reg : mem_wr_data_int; //TODO: gs, pwo modes + assign mem_wr_data = shuffle_en ? pwm_mode ? mem_wr_data_reg/*_d2*/ : (pwa_mode | pws_mode) ? mem_wr_data_reg : mem_wr_data_int + : mem_wr_data_int; //ct_mode ? mem_wr_data_reg : mem_wr_data_int; //TODO: gs, pwo modes //mem rd - NTT/INTT mode, read ntt data. PWM mode, read accumulate data from c mem. PWA/S mode, unused assign mem_rd_req.rd_wr_en = (ct_mode || gs_mode) ? (mem_rden ? RW_READ : RW_IDLE) : pwm_mode ? (pw_rden_dest_mem ? RW_READ : RW_IDLE) : RW_IDLE; @@ -180,9 +182,18 @@ module ntt_top assign pwm_rd_data_a = pwo_mode ? pwm_a_rd_data : 'h0; //TODO: clean up mux. Just connect input directly to logic //pwm rd b - PWO mode - read b operand from mem. Or operand b can also be connected directly to sampler, so in that case, addr/rden are not used - assign pwm_b_rd_req.rd_wr_en = sampler_valid_reg & pwo_mode ? (pw_rden ? RW_READ : RW_IDLE) : RW_IDLE; //pw_rden is delayed a clk due to shuffling, so use delayed sampler_valid to line it up - assign pwm_b_rd_req.addr = sampler_valid_reg & pwo_mode ? pw_mem_rd_addr_b : 'h0; - assign pwm_rd_data_b = pwm_b_rd_data_reg; //sampler_valid_reg ? pwm_b_rd_data_reg : pwm_b_rd_data; + always_comb begin + if (shuffle_en) begin + pwm_b_rd_req.rd_wr_en = sampler_valid_reg & pwo_mode ? (pw_rden ? RW_READ : RW_IDLE) : RW_IDLE; //pw_rden is delayed a clk due to shuffling, so use delayed sampler_valid to line it up + pwm_b_rd_req.addr = sampler_valid_reg & pwo_mode ? pw_mem_rd_addr_b : 'h0; + pwm_rd_data_b = pwm_b_rd_data_reg; //sampler_valid_reg ? pwm_b_rd_data_reg : pwm_b_rd_data; + end + else begin + pwm_b_rd_req.rd_wr_en = sampler_valid & pwo_mode ? (pw_rden ? RW_READ : RW_IDLE) : RW_IDLE; + pwm_b_rd_req.addr = sampler_valid & pwo_mode ? pw_mem_rd_addr_b : 'h0; + pwm_rd_data_b = pwm_b_rd_data; + end + end ntt_ctrl #( @@ -197,6 +208,7 @@ module ntt_top .butterfly_ready(bf_ready), .buf0_valid(buf0_valid), .sampler_valid(sampler_valid), + .shuffle_en(shuffle_en), .random(random), .ntt_mem_base_addr(ntt_mem_base_addr), @@ -240,17 +252,24 @@ module ntt_top always_comb begin unique case(mode) ct: begin - //with shuffling, twiddle factor needs to be delayed uvw_i.w00_i = twiddle_factor[NTT_REG_SIZE-1:0]; uvw_i.w01_i = twiddle_factor[NTT_REG_SIZE-1:0]; uvw_i.w10_i = twiddle_factor[(2*NTT_REG_SIZE)-1:NTT_REG_SIZE]; uvw_i.w11_i = twiddle_factor[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; end gs: begin - uvw_i.w11_i = twiddle_factor/*_reg*/[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; - uvw_i.w10_i = twiddle_factor/*_reg*/[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; - uvw_i.w01_i = twiddle_factor/*_reg*/[(2*NTT_REG_SIZE)-1:NTT_REG_SIZE]; - uvw_i.w00_i = twiddle_factor/*_reg*/[NTT_REG_SIZE-1:0]; + if (shuffle_en) begin + uvw_i.w11_i = twiddle_factor/*_reg*/[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; + uvw_i.w10_i = twiddle_factor/*_reg*/[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; + uvw_i.w01_i = twiddle_factor/*_reg*/[(2*NTT_REG_SIZE)-1:NTT_REG_SIZE]; + uvw_i.w00_i = twiddle_factor/*_reg*/[NTT_REG_SIZE-1:0]; + end + else begin + uvw_i.w11_i = twiddle_factor_reg[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; + uvw_i.w10_i = twiddle_factor_reg[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; + uvw_i.w01_i = twiddle_factor_reg[(2*NTT_REG_SIZE)-1:NTT_REG_SIZE]; + uvw_i.w00_i = twiddle_factor_reg[NTT_REG_SIZE-1:0]; + end end default: begin uvw_i.w11_i = 'h0; @@ -357,7 +376,8 @@ module ntt_top end //Buffer (input or output side) - assign buf_data_i = ct_mode ? mem_rd_data : {1'b0, uv_o_reg.v21_o, 1'b0, uv_o_reg.v20_o, 1'b0, uv_o_reg.u21_o, 1'b0, uv_o_reg.u20_o}; + assign buf_data_i = ct_mode ? mem_rd_data : shuffle_en ? {1'b0, uv_o_reg.v21_o, 1'b0, uv_o_reg.v20_o, 1'b0, uv_o_reg.u21_o, 1'b0, uv_o_reg.u20_o} + : {1'b0, uv_o.v21_o, 1'b0, uv_o.v20_o, 1'b0, uv_o.u21_o, 1'b0, uv_o.u20_o}; always_comb begin unique case(mode) @@ -388,10 +408,18 @@ module ntt_top pw_uvw_i.u2_i = pwm_rd_data_a_reg[(3*REG_SIZE)-2:(2*REG_SIZE)]; pw_uvw_i.u3_i = pwm_rd_data_a_reg[(4*REG_SIZE)-2:(3*REG_SIZE)]; - pw_uvw_i.v0_i = pwm_rd_data_b/*_reg*/[REG_SIZE-2:0]; - pw_uvw_i.v1_i = pwm_rd_data_b/*_reg*/[(2*REG_SIZE)-2:REG_SIZE]; - pw_uvw_i.v2_i = pwm_rd_data_b/*_reg*/[(3*REG_SIZE)-2:(2*REG_SIZE)]; - pw_uvw_i.v3_i = pwm_rd_data_b/*_reg*/[(4*REG_SIZE)-2:(3*REG_SIZE)]; + if (shuffle_en) begin + pw_uvw_i.v0_i = pwm_rd_data_b/*_reg*/[REG_SIZE-2:0]; + pw_uvw_i.v1_i = pwm_rd_data_b/*_reg*/[(2*REG_SIZE)-2:REG_SIZE]; + pw_uvw_i.v2_i = pwm_rd_data_b/*_reg*/[(3*REG_SIZE)-2:(2*REG_SIZE)]; + pw_uvw_i.v3_i = pwm_rd_data_b/*_reg*/[(4*REG_SIZE)-2:(3*REG_SIZE)]; + end + else begin + pw_uvw_i.v0_i = pwm_rd_data_b_reg[REG_SIZE-2:0]; + pw_uvw_i.v1_i = pwm_rd_data_b_reg[(2*REG_SIZE)-2:REG_SIZE]; + pw_uvw_i.v2_i = pwm_rd_data_b_reg[(3*REG_SIZE)-2:(2*REG_SIZE)]; + pw_uvw_i.v3_i = pwm_rd_data_b_reg[(4*REG_SIZE)-2:(3*REG_SIZE)]; + end pw_uvw_i.w0_i = pwm_rd_data_c_reg[REG_SIZE-2:0]; pw_uvw_i.w1_i = pwm_rd_data_c_reg[(2*REG_SIZE)-2:REG_SIZE]; @@ -409,10 +437,18 @@ module ntt_top pw_uvw_i.u2_i = pwm_rd_data_a_reg[(3*REG_SIZE)-2:(2*REG_SIZE)]; pw_uvw_i.u3_i = pwm_rd_data_a_reg[(4*REG_SIZE)-2:(3*REG_SIZE)]; - pw_uvw_i.v0_i = pwm_rd_data_b/*_reg*/[REG_SIZE-2:0]; - pw_uvw_i.v1_i = pwm_rd_data_b/*_reg*/[(2*REG_SIZE)-2:REG_SIZE]; - pw_uvw_i.v2_i = pwm_rd_data_b/*_reg*/[(3*REG_SIZE)-2:(2*REG_SIZE)]; - pw_uvw_i.v3_i = pwm_rd_data_b/*_reg*/[(4*REG_SIZE)-2:(3*REG_SIZE)]; + if (shuffle_en) begin + pw_uvw_i.v0_i = pwm_rd_data_b/*_reg*/[REG_SIZE-2:0]; + pw_uvw_i.v1_i = pwm_rd_data_b/*_reg*/[(2*REG_SIZE)-2:REG_SIZE]; + pw_uvw_i.v2_i = pwm_rd_data_b/*_reg*/[(3*REG_SIZE)-2:(2*REG_SIZE)]; + pw_uvw_i.v3_i = pwm_rd_data_b/*_reg*/[(4*REG_SIZE)-2:(3*REG_SIZE)]; + end + else begin + pw_uvw_i.v0_i = pwm_rd_data_b_reg[REG_SIZE-2:0]; + pw_uvw_i.v1_i = pwm_rd_data_b_reg[(2*REG_SIZE)-2:REG_SIZE]; + pw_uvw_i.v2_i = pwm_rd_data_b_reg[(3*REG_SIZE)-2:(2*REG_SIZE)]; + pw_uvw_i.v3_i = pwm_rd_data_b_reg[(4*REG_SIZE)-2:(3*REG_SIZE)]; + end pw_uvw_i.w0_i = 'h0; pw_uvw_i.w1_i = 'h0; @@ -443,8 +479,8 @@ module ntt_top endcase end assign bf_enable_mux = ct_mode ? bf_enable : bf_enable_reg; - assign mem_wren_mux = mem_wren; //ct_mode ? mem_wren_reg : mem_wren; - assign mem_wr_addr_mux = mem_wr_addr; //ct_mode ? mem_wr_addr_reg : mem_wr_addr; + assign mem_wren_mux = ~shuffle_en & ct_mode ? mem_wren_reg : mem_wren; + assign mem_wr_addr_mux = ~shuffle_en & ct_mode ? mem_wr_addr_reg : mem_wr_addr; /* ntt_buffer #( @@ -471,6 +507,7 @@ module ntt_top .reset_n(reset_n), .zeroize(zeroize), .mode(mode), + .shuffle_en(shuffle_en), .wren(buf_wren), .rden(buf_rden), .wrptr(buf_wrptr), diff --git a/src/ntt_top/tb/ntt_top_tb.sv b/src/ntt_top/tb/ntt_top_tb.sv index dba8d67..9ff0616 100644 --- a/src/ntt_top/tb/ntt_top_tb.sv +++ b/src/ntt_top/tb/ntt_top_tb.sv @@ -158,6 +158,7 @@ ntt_wrapper dut ( .ntt_enable(enable_tb), .load_tb_values(load_tb_values), .load_tb_addr(load_tb_addr), + .shuffle_en(1'b0), .random(random_tb), // .src_base_addr(src_base_addr), // .interim_base_addr(interim_base_addr), diff --git a/src/ntt_top/tb/ntt_wrapper.sv b/src/ntt_top/tb/ntt_wrapper.sv index 0e15f61..dc511d4 100644 --- a/src/ntt_top/tb/ntt_wrapper.sv +++ b/src/ntt_top/tb/ntt_wrapper.sv @@ -37,6 +37,7 @@ module ntt_wrapper input mode_t mode, input wire ntt_enable, + input wire shuffle_en, input wire [5:0] random, //TB purpose - remove later TODO @@ -179,6 +180,7 @@ module ntt_wrapper .pwo_mem_base_addr(pwo_mem_base_addr), .accumulate(accumulate), .sampler_valid(sampler_valid), + .shuffle_en(shuffle_en), .random(random), //NTT mem IF .mem_wr_req(mem_wr_req), From cac0b8b3016895bf8dc16b344f1e8d4cc1de9174 Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Mon, 28 Oct 2024 12:30:21 -0700 Subject: [PATCH 07/17] Add shuffle_en input to ntt --- src/mldsa_top/rtl/mldsa_top.sv | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/mldsa_top/rtl/mldsa_top.sv b/src/mldsa_top/rtl/mldsa_top.sv index 7beb195..f906497 100644 --- a/src/mldsa_top/rtl/mldsa_top.sv +++ b/src/mldsa_top/rtl/mldsa_top.sv @@ -109,6 +109,7 @@ module mldsa_top logic [1:0][MLDSA_MEM_DATA_WIDTH-1:0] pwm_b_rd_data; logic [1:0] ntt_done; logic [1:0] ntt_busy; + logic [1:0] shuffle_en; mem_if_t w1_mem_wr_req; logic [3:0] w1_mem_wr_data; @@ -439,15 +440,18 @@ generate accumulate[g_inst] = '0; sampler_valid[g_inst] = 0; sampler_ntt_mode[g_inst] = 0; + shuffle_en[g_inst] = 0; unique case (ntt_mode[g_inst]) inside MLDSA_NTT_NONE: begin end MLDSA_NTT: begin mode[g_inst] = ct; + shuffle_en[g_inst] = 1; end MLDSA_INTT: begin mode[g_inst] = gs; + shuffle_en[g_inst] = 1; end MLDSA_PWM_SMPL: begin mode[g_inst] = pwm; @@ -463,19 +467,23 @@ generate MLDSA_PWM: begin mode[g_inst] = pwm; sampler_valid[g_inst] = 1; + shuffle_en[g_inst] = 1; end MLDSA_PWM_ACCUM: begin mode[g_inst] = pwm; accumulate[g_inst] = 1; sampler_valid[g_inst] = 1; + shuffle_en[g_inst] = 1; end MLDSA_PWA: begin mode[g_inst] = pwa; sampler_valid[g_inst] = 1; + shuffle_en[g_inst] = 1; end MLDSA_PWS: begin mode[g_inst] = pws; sampler_valid[g_inst] = 1; + shuffle_en[g_inst] = 1; end default: begin end @@ -499,7 +507,7 @@ generate .pwo_mem_base_addr(pwo_mem_base_addr[g_inst]), .accumulate(accumulate[g_inst]), .sampler_valid(sampler_valid[g_inst]), - .shuffle_en(1'b0), + .shuffle_en(shuffle_en[g_inst]), .random(random), //NTT mem IF .mem_wr_req(ntt_mem_wr_req[g_inst]), @@ -510,7 +518,7 @@ generate .pwm_a_rd_req(pwm_a_rd_req[g_inst]), .pwm_b_rd_req(pwm_b_rd_req[g_inst]), .pwm_a_rd_data(pwm_a_rd_data[g_inst]), - .pwm_b_rd_data(sampler_ntt_mode[g_inst] ? sampler_ntt_data_reg : pwm_b_rd_data[g_inst]), + .pwm_b_rd_data(sampler_ntt_mode[g_inst] ? sampler_ntt_data : pwm_b_rd_data[g_inst]), .ntt_busy(ntt_busy[g_inst]), .ntt_done(ntt_done[g_inst]) ); @@ -935,7 +943,7 @@ always_comb begin for (int bank = 0; bank < 2; bank++) begin ntt_mem_re0_bank[0][bank] = (ntt_mem_rd_req[0].rd_wr_en == RW_READ) & (ntt_mem_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (ntt_mem_rd_req[0].addr[0] == bank); pwo_a_mem_re0_bank[0][bank] = (pwm_a_rd_req[0].rd_wr_en == RW_READ) & (pwm_a_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_a_rd_req[0].addr[0] == bank); - pwo_b_mem_re0_bank[0][bank] = ~sampler_ntt_dv_f & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_b_rd_req[0].addr[0] == bank); + pwo_b_mem_re0_bank[0][bank] = (shuffle_en[0] ? ~sampler_ntt_dv_f : ~sampler_ntt_dv) & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_b_rd_req[0].addr[0] == bank); ntt_mem_re0_bank[1][bank] = (ntt_mem_rd_req[1].rd_wr_en == RW_READ) & (ntt_mem_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (ntt_mem_rd_req[1].addr[0] == bank); pwo_a_mem_re0_bank[1][bank] = (pwm_a_rd_req[1].rd_wr_en == RW_READ) & (pwm_a_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]) & (pwm_a_rd_req[1].addr[0] == bank); @@ -969,7 +977,7 @@ always_comb begin end else begin ntt_mem_re[0][i] = (ntt_mem_rd_req[0].rd_wr_en == RW_READ) & (ntt_mem_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); pwo_a_mem_re[0][i] = (pwm_a_rd_req[0].rd_wr_en == RW_READ) & (pwm_a_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); - pwo_b_mem_re[0][i] = ~sampler_ntt_dv_f & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); + pwo_b_mem_re[0][i] = (shuffle_en[0] ? ~sampler_ntt_dv_f : ~sampler_ntt_dv) & (pwm_b_rd_req[0].rd_wr_en == RW_READ) & (pwm_b_rd_req[0].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); ntt_mem_re[1][i] = (ntt_mem_rd_req[1].rd_wr_en == RW_READ) & (ntt_mem_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); pwo_a_mem_re[1][i] = (pwm_a_rd_req[1].rd_wr_en == RW_READ) & (pwm_a_rd_req[1].addr[MLDSA_MEM_ADDR_WIDTH-1:MLDSA_MEM_ADDR_WIDTH-3] == i[2:0]); From 6b2cf17d9e0646866e6ca5db330cca9bfd40bd20 Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Mon, 28 Oct 2024 14:35:28 -0700 Subject: [PATCH 08/17] Remove random input and use lfsr bits --- src/mldsa_top/rtl/mldsa_top.sv | 3 +-- .../project_benches/mldsa/tb/testbench/hdl_top.sv | 6 +----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/mldsa_top/rtl/mldsa_top.sv b/src/mldsa_top/rtl/mldsa_top.sv index f906497..5b1774f 100644 --- a/src/mldsa_top/rtl/mldsa_top.sv +++ b/src/mldsa_top/rtl/mldsa_top.sv @@ -51,7 +51,6 @@ module mldsa_top input logic hready_i, input logic [1:0] htrans_i, input logic [2:0] hsize_i, - input logic [5:0] random, //ahb output output logic hresp_o, @@ -508,7 +507,7 @@ generate .accumulate(accumulate[g_inst]), .sampler_valid(sampler_valid[g_inst]), .shuffle_en(shuffle_en[g_inst]), - .random(random), + .random(rand_bits[g_inst*6+5:g_inst*6]), //NTT mem IF .mem_wr_req(ntt_mem_wr_req[g_inst]), .mem_rd_req(ntt_mem_rd_req[g_inst]), diff --git a/src/mldsa_top/uvmf/uvmf_template_output/project_benches/mldsa/tb/testbench/hdl_top.sv b/src/mldsa_top/uvmf/uvmf_template_output/project_benches/mldsa/tb/testbench/hdl_top.sv index e7eb9d2..b04b3b2 100644 --- a/src/mldsa_top/uvmf/uvmf_template_output/project_benches/mldsa/tb/testbench/hdl_top.sv +++ b/src/mldsa_top/uvmf/uvmf_template_output/project_benches/mldsa/tb/testbench/hdl_top.sv @@ -54,16 +54,13 @@ import uvmf_base_pkg_hdl::*; // pragma uvmf custom clock_generator begin bit clk; - logic [5:0] random_tb; // Instantiate a clk driver // tbx clkgen initial begin clk = 0; - random_tb = 0; #0ns; forever begin clk = ~clk; - random_tb = $urandom(); #5ns; end end @@ -104,8 +101,7 @@ import uvmf_base_pkg_hdl::*; .hsize_i (uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HSIZE ), .hresp_o (uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HRESP ), .hreadyout_o(uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HREADY ), - .hrdata_o (uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HRDATA ), - .random (random_tb) + .hrdata_o (uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HRDATA ) ); assign uvm_test_top_environment_qvip_ahb_lite_slave_subenv_qvip_hdl.ahb_lite_slave_0_HBURST = 3'b0; From 8d56be429cc003ae9e6f617db375cb0bdefeb437 Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Tue, 29 Oct 2024 10:19:22 -0700 Subject: [PATCH 09/17] Clean up --- src/mldsa_top/rtl/mldsa_top.sv | 7 +--- src/ntt_top/rtl/ntt_ctrl.sv | 68 ++++++++++++++++------------------ src/ntt_top/rtl/ntt_top.sv | 35 ++++++----------- 3 files changed, 45 insertions(+), 65 deletions(-) diff --git a/src/mldsa_top/rtl/mldsa_top.sv b/src/mldsa_top/rtl/mldsa_top.sv index 5b1774f..8228127 100644 --- a/src/mldsa_top/rtl/mldsa_top.sv +++ b/src/mldsa_top/rtl/mldsa_top.sv @@ -90,7 +90,7 @@ module mldsa_top logic [1:0] sampler_ntt_dv, sampler_ntt_dv_f; logic [1:0] sampler_ntt_mode; logic [1:0] sampler_valid; - logic [COEFF_PER_CLK-1:0][MLDSA_Q_WIDTH-1:0] sampler_ntt_data, sampler_ntt_data_reg; + logic [COEFF_PER_CLK-1:0][MLDSA_Q_WIDTH-1:0] sampler_ntt_data; mldsa_ntt_mode_e [1:0] ntt_mode; mode_t [1:0] mode; @@ -415,15 +415,12 @@ mldsa_sampler_top sampler_top_inst always_ff @(posedge clk or negedge rst_b) begin if (!rst_b) begin - sampler_ntt_data_reg <= 0; sampler_ntt_dv_f <= 0; end else if (zeroize_reg) begin - sampler_ntt_data_reg <= 0; sampler_ntt_dv_f <= 0; end else begin - sampler_ntt_data_reg <= sampler_ntt_data; sampler_ntt_dv_f <= sampler_ntt_dv; end end @@ -439,7 +436,7 @@ generate accumulate[g_inst] = '0; sampler_valid[g_inst] = 0; sampler_ntt_mode[g_inst] = 0; - shuffle_en[g_inst] = 0; + shuffle_en[g_inst] = 0; //TODO: temp change for testing, remove and add to opcodes unique case (ntt_mode[g_inst]) inside MLDSA_NTT_NONE: begin diff --git a/src/ntt_top/rtl/ntt_ctrl.sv b/src/ntt_top/rtl/ntt_ctrl.sv index 3357697..e51d14d 100644 --- a/src/ntt_top/rtl/ntt_ctrl.sv +++ b/src/ntt_top/rtl/ntt_ctrl.sv @@ -32,7 +32,7 @@ module ntt_ctrl parameter MLDSA_N = 256, parameter MLDSA_LOGN = 8, parameter MEM_ADDR_WIDTH = 15, - parameter BF_LATENCY = 10, //11 //TODO: change back to 10 and test shuffling //5 cycles per butterfly * 2 instances in serial = 10 clks + parameter BF_LATENCY = 10, //5 cycles per butterfly * 2 instances in serial = 10 clks parameter NTT_BUF_LATENCY = 4 ) ( @@ -46,14 +46,8 @@ module ntt_ctrl input wire sampler_valid, input wire accumulate, //NTT/INTT base addr - // input wire [MEM_ADDR_WIDTH-1:0] src_base_addr, - // input wire [MEM_ADDR_WIDTH-1:0] interim_base_addr, - // input wire [MEM_ADDR_WIDTH-1:0] dest_base_addr, input ntt_mem_addr_t ntt_mem_base_addr, //PWO base addr - // input wire [MEM_ADDR_WIDTH-1:0] pw_base_addr_a, - // input wire [MEM_ADDR_WIDTH-1:0] pw_base_addr_b, - // input wire [MEM_ADDR_WIDTH-1:0] pw_base_addr_c, //result input pwo_mem_addr_t pwo_mem_base_addr, input wire shuffle_en, input wire [5:0] random, //4+2 bits @@ -122,7 +116,7 @@ logic last_rd_addr, last_wr_addr; logic mem_wr_en_fsm, mem_wr_en_reg; logic mem_rd_en_fsm, mem_rd_en_reg; logic pw_rden_fsm, pw_rden_reg; -logic pw_wren_fsm, pw_wren_reg, pw_wren_reg_d2; +logic pw_wren_fsm, pw_wren_reg; //Mode flags logic ct_mode, gs_mode, pwo_mode; //point-wise operations mode @@ -364,9 +358,9 @@ always_ff @(posedge clk or negedge reset_n) begin else begin if (incr_pw_rd_addr) begin if (shuffle_en) begin - pw_mem_rd_addr_a <= pw_mem_rd_addr_a_nxt;//pw_mem_rd_addr_a + PWO_READ_ADDR_STEP; - pw_mem_rd_addr_b <= pw_mem_rd_addr_b_nxt;//pw_mem_rd_addr_b + PWO_READ_ADDR_STEP; - pw_mem_rd_addr_c <= accumulate ? pw_mem_rd_addr_c_nxt : 'h0;//accumulate ? pw_mem_rd_addr_c + PWO_READ_ADDR_STEP : 'h0; //addr in sync with a, b. However, the data is flopped 4 cycles inside BF to align with mul result + pw_mem_rd_addr_a <= pw_mem_rd_addr_a_nxt; + pw_mem_rd_addr_b <= pw_mem_rd_addr_b_nxt; + pw_mem_rd_addr_c <= accumulate ? pw_mem_rd_addr_c_nxt : 'h0; //addr in sync with a, b. However, the data is flopped 4 cycles inside BF to align with mul result end else begin pw_mem_rd_addr_a <= pw_mem_rd_addr_a + PWO_READ_ADDR_STEP; @@ -404,7 +398,7 @@ always_comb begin 'h3: begin twiddle_end_addr = ct_mode ? 'd63 : 'd0; twiddle_offset = ct_mode ? 'd21 : 'd84; - twiddle_rand_offset = ct_mode ? (chunk_count % 'd16)*4 + buf_rdptr_int : 'h0; //gs mode: 0 + twiddle_rand_offset = ct_mode ? (chunk_count % 'd16)*4 + buf_rdptr_int : 'h0; end default: begin twiddle_end_addr = 'h0; @@ -420,13 +414,17 @@ always_ff @(posedge clk or negedge reset_n) begin incr_twiddle_addr_reg <= 'b0; incr_twiddle_addr_reg_d2 <= 'b0; end + else if (zeroize) begin + incr_twiddle_addr_reg <= 'b0; + incr_twiddle_addr_reg_d2 <= 'b0; + end else begin incr_twiddle_addr_reg <= incr_twiddle_addr_fsm; incr_twiddle_addr_reg_d2 <= incr_twiddle_addr_reg; end end -assign incr_twiddle_addr = ct_mode ? incr_twiddle_addr_fsm : incr_twiddle_addr_reg; //_d2; +assign incr_twiddle_addr = ct_mode ? incr_twiddle_addr_fsm : incr_twiddle_addr_reg; always_ff @(posedge clk or negedge reset_n) begin @@ -514,7 +512,7 @@ always_ff @(posedge clk or negedge reset_n) begin chunk_rand_offset <= random[5:2]; chunk_count <= random[5:2]; end - else if ((ct_mode & (buf_count == 'h3)) | ((gs_mode | (pwo_mode & incr_pw_rd_addr)) & /*(buf_wrptr_reg[10] == 'h3)*/(index_count == 'h3))) begin //update chunk after every 4 cycles - TODO: stop chunk counting when there's no incr_rd_addr in ntt/intt modes + else if ((ct_mode & (buf_count == 'h3)) | ((gs_mode | (pwo_mode & incr_pw_rd_addr)) & (index_count == 'h3))) begin //update chunk after every 4 cycles - TODO: stop chunk counting when there's no incr_rd_addr in ntt/intt modes chunk_count <= (chunk_count == 'hf) ? 'h0 : chunk_count + 'h1; end end @@ -610,14 +608,13 @@ always_ff @(posedge clk or negedge reset_n) begin end always_comb begin - last_rd_addr = /*ct_mode &*/ (mem_rd_addr == mem_rd_base_addr + MEM_LAST_ADDR); //TODO: other modes - last_wr_addr = /*ct_mode &*/ (mem_wr_addr == mem_wr_base_addr + MEM_LAST_ADDR); //TODO: other modes - buf_rdptr_int = (shuffle_en & ct_mode) ? index_rand_offset + buf_count : buf_count; //TODO: flop + last_rd_addr = (mem_rd_addr == mem_rd_base_addr + MEM_LAST_ADDR); + last_wr_addr = (mem_wr_addr == mem_wr_base_addr + MEM_LAST_ADDR); + buf_rdptr_int = (shuffle_en & ct_mode) ? index_rand_offset + buf_count : buf_count; //TODO: flop? buf_rdptr = (shuffle_en & ct_mode) ? buf_rdptr_f : buf_count; - // buf_wrptr = gs_mode ? index_rand_offset + buf_count : buf_count; latch_chunk_rand_offset = arc_IDLE_WR_STAGE | arc_WR_MEM_WR_STAGE | arc_WR_WAIT_WR_STAGE; - latch_index_rand_offset = ct_mode ? (buf_wrptr == 'h3) : (gs_mode | (pwo_mode & incr_pw_rd_addr)) & (arc_RD_STAGE_RD_EXEC | (index_count == 'h3)); //TODO pwo mode - mem_rd_index_ofst = (pwo_mode | gs_mode) ? (index_count + index_rand_offset) : 'h0; //TODO: pwo mode, not used in ct mode + latch_index_rand_offset = ct_mode ? (buf_wrptr == 'h3) : (gs_mode | (pwo_mode & incr_pw_rd_addr)) & (arc_RD_STAGE_RD_EXEC | (index_count == 'h3)); + mem_rd_index_ofst = (pwo_mode | gs_mode) ? (index_count + index_rand_offset) : 'h0; end @@ -734,7 +731,7 @@ always_comb begin else mem_rd_en_fsm = (ntt_mode inside {ct, gs}) ? (mem_rd_addr <= MEM_LAST_ADDR + mem_rd_base_addr) : 1'b0; bf_enable_fsm = pwo_mode ? sampler_valid : 1'b1; - incr_twiddle_addr_fsm = ntt_mode inside {ct, gs}; //1'b1; + incr_twiddle_addr_fsm = ntt_mode inside {ct, gs}; rd_addr_step = ct_mode ? NTT_READ_ADDR_STEP : INTT_READ_ADDR_STEP; incr_pw_rd_addr = sampler_valid & pwo_mode; pw_rden_fsm = sampler_valid & pwo_mode; @@ -852,11 +849,11 @@ always_comb begin arc_WR_STAGE_WR_BUF ? WR_BUF : arc_WR_STAGE_WR_WAIT? WR_WAIT : arc_WR_STAGE_IDLE ? WR_IDLE : WR_STAGE; - rst_wr_addr = 1'b1; - rst_wr_valid_count = 1'b1; - buf_wr_rst_count_intt = gs_mode; - buf_rd_rst_count_intt = gs_mode; - rst_pw_addr = pwo_mode; + rst_wr_addr = 1'b1; + rst_wr_valid_count = 1'b1; + buf_wr_rst_count_intt = gs_mode; + buf_rd_rst_count_intt = gs_mode; + rst_pw_addr = pwo_mode; end WR_BUF: begin write_fsm_state_ns = arc_WR_BUF_WR_MEM ? WR_MEM : WR_BUF; @@ -880,19 +877,19 @@ always_comb begin end WR_WAIT: begin if (shuffle_en) begin - write_fsm_state_ns = arc_WR_WAIT_WR_STAGE ? WR_STAGE : /*arc_WR_WAIT_WR_MEM ? WR_MEM :*/ WR_WAIT; + write_fsm_state_ns = arc_WR_WAIT_WR_STAGE ? WR_STAGE : WR_WAIT; wr_addr_step = gs_mode ? INTT_WRITE_ADDR_STEP : NTT_WRITE_ADDR_STEP; end else begin write_fsm_state_ns = arc_WR_WAIT_WR_STAGE ? WR_STAGE : arc_WR_WAIT_WR_MEM ? WR_MEM : WR_WAIT; wr_addr_step = INTT_WRITE_ADDR_STEP; end - buf_wren_intt = shuffle_en ? 'b0 : (buf_count <= 'h3); //1'b0; + buf_wren_intt = shuffle_en ? 'b0 : (buf_count <= 'h3); buf_rden_intt = shuffle_en ? gs_mode : 'b1; - incr_mem_wr_addr = (ct_mode | gs_mode); //1'b1; - mem_wr_en_fsm = shuffle_en ? gs_mode : (ct_mode | gs_mode); //1'b1; + incr_mem_wr_addr = (ct_mode | gs_mode); + mem_wr_en_fsm = shuffle_en ? gs_mode : (ct_mode | gs_mode); - incr_pw_wr_addr = shuffle_en ? pwo_mode & arc_WR_WAIT_WR_STAGE : arc_WR_WAIT_WR_MEM; //MEM; + incr_pw_wr_addr = shuffle_en ? pwo_mode & arc_WR_WAIT_WR_STAGE : arc_WR_WAIT_WR_MEM; pw_wren_fsm = shuffle_en ? 'b0 : arc_WR_WAIT_WR_MEM; end default: begin @@ -906,13 +903,13 @@ always_comb begin incr_rounds = arc_WR_MEM_WR_STAGE | arc_WR_WAIT_WR_STAGE; //TODO: revisit for high-perf mode (if we go with above opt) if (shuffle_en) begin buf_wren = pwo_mode ? 1'b0 : buf_wren_ntt_reg | buf_wren_intt_reg; - buf_rden = pwo_mode ? 1'b0 : ct_mode ? buf_rden_ntt_reg : /*buf_rden_ntt |*/ buf_rden_intt; + buf_rden = pwo_mode ? 1'b0 : ct_mode ? buf_rden_ntt_reg : buf_rden_intt; bf_enable = (gs_mode || pwo_mode) ? bf_enable_reg_d2 : bf_enable_reg; //bf_enable_fsm; //In gs mode, memory is directly feeding bf2x2, so we need to enable it one cycle later - mem_wr_en = gs_mode ? mem_wr_en_fsm : mem_wr_en_reg; //TODO pwo mode, GS mode + shuffling + mem_wr_en = gs_mode ? mem_wr_en_fsm : mem_wr_en_reg; mem_rd_en = (gs_mode | pwo_mode) ? mem_rd_en_reg : mem_rd_en_fsm; twiddle_addr = gs_mode ? twiddle_addr_reg_d3 : twiddle_addr_int; pw_rden = pw_rden_reg; - pw_wren = pwm_mode ? pw_wren_reg/*_d2*/ : pw_wren_reg; + pw_wren = pwm_mode ? pw_wren_reg : pw_wren_reg; end else begin buf_wren = pwo_mode ? 1'b0 : buf_wren_ntt_reg | buf_wren_intt; @@ -943,7 +940,6 @@ always_ff @(posedge clk or negedge reset_n) begin twiddle_addr_reg_d3 <= 'h0; pw_rden_reg <= '0; pw_wren_reg <= '0; - pw_wren_reg_d2 <= '0; end else if (zeroize) begin buf_wren_ntt_reg <= 'b0; @@ -957,7 +953,6 @@ always_ff @(posedge clk or negedge reset_n) begin twiddle_addr_reg_d3 <= 'h0; pw_rden_reg <= '0; pw_wren_reg <= '0; - pw_wren_reg_d2 <= '0; end else begin buf_wren_ntt_reg <= buf_wren_ntt; @@ -971,7 +966,6 @@ always_ff @(posedge clk or negedge reset_n) begin twiddle_addr_reg_d3 <= twiddle_addr_reg_d2; pw_rden_reg <= pw_rden_fsm; pw_wren_reg <= pw_wren_fsm; - pw_wren_reg_d2 <= pw_wren_reg; end end diff --git a/src/ntt_top/rtl/ntt_top.sv b/src/ntt_top/rtl/ntt_top.sv index 567a4d5..fa0b97e 100644 --- a/src/ntt_top/rtl/ntt_top.sv +++ b/src/ntt_top/rtl/ntt_top.sv @@ -51,20 +51,11 @@ module ntt_top //Ctrl signal ports input mode_t mode, input wire ntt_enable, - //TB purpose - remove and refine tb - // input wire load_tb_values, - // input wire [5:0] load_tb_addr, //NTT base addr ports - // input wire [MEM_ADDR_WIDTH-1:0] src_base_addr, - // input wire [MEM_ADDR_WIDTH-1:0] interim_base_addr, - // input wire [MEM_ADDR_WIDTH-1:0] dest_base_addr, input ntt_mem_addr_t ntt_mem_base_addr, //PWM base addr ports - // input wire [MEM_ADDR_WIDTH-1:0] pw_base_addr_a, - // input wire [MEM_ADDR_WIDTH-1:0] pw_base_addr_b, - // input wire [MEM_ADDR_WIDTH-1:0] pw_base_addr_c, input pwo_mem_addr_t pwo_mem_base_addr, //PWM control @@ -98,7 +89,6 @@ module ntt_top //Write IF logic mem_wren, mem_wren_reg, mem_wren_mux; logic [MLDSA_MEM_ADDR_WIDTH-1:0] mem_wr_addr, mem_wr_addr_reg, mem_wr_addr_mux; - // logic [(4*REG_SIZE)-1:0] mem_wr_data; logic [MEM_DATA_WIDTH-1:0] mem_wr_data_int, mem_wr_data_reg, mem_wr_data_reg_d2; //Read IF @@ -107,7 +97,6 @@ module ntt_top logic [(4*REG_SIZE)-1:0] mem_rd_data_reg; //Butterfly IF signals - // logic bf_mode; bf_uvwi_t uvw_i; bf_uvo_t uv_o, uv_o_reg; logic bf_enable, bf_enable_reg, bf_enable_mux; @@ -168,8 +157,8 @@ module ntt_top assign mem_wr_req.addr = !pwo_mode ? mem_wr_addr_mux : pwm_wr_addr_c_reg; assign mem_wr_data_int = !pwo_mode ? (ct_mode ? {1'b0, uv_o_reg.v21_o, 1'b0, uv_o_reg.u21_o, 1'b0, uv_o_reg.v20_o, 1'b0, uv_o_reg.u20_o} : buf_data_o) : pwm_wr_data_reg; - assign mem_wr_data = shuffle_en ? pwm_mode ? mem_wr_data_reg/*_d2*/ : (pwa_mode | pws_mode) ? mem_wr_data_reg : mem_wr_data_int - : mem_wr_data_int; //ct_mode ? mem_wr_data_reg : mem_wr_data_int; //TODO: gs, pwo modes + assign mem_wr_data = shuffle_en ? pwm_mode ? mem_wr_data_reg : (pwa_mode | pws_mode) ? mem_wr_data_reg : mem_wr_data_int + : mem_wr_data_int; //mem rd - NTT/INTT mode, read ntt data. PWM mode, read accumulate data from c mem. PWA/S mode, unused assign mem_rd_req.rd_wr_en = (ct_mode || gs_mode) ? (mem_rden ? RW_READ : RW_IDLE) : pwm_mode ? (pw_rden_dest_mem ? RW_READ : RW_IDLE) : RW_IDLE; @@ -186,7 +175,7 @@ module ntt_top if (shuffle_en) begin pwm_b_rd_req.rd_wr_en = sampler_valid_reg & pwo_mode ? (pw_rden ? RW_READ : RW_IDLE) : RW_IDLE; //pw_rden is delayed a clk due to shuffling, so use delayed sampler_valid to line it up pwm_b_rd_req.addr = sampler_valid_reg & pwo_mode ? pw_mem_rd_addr_b : 'h0; - pwm_rd_data_b = pwm_b_rd_data_reg; //sampler_valid_reg ? pwm_b_rd_data_reg : pwm_b_rd_data; + pwm_rd_data_b = pwm_b_rd_data_reg; end else begin pwm_b_rd_req.rd_wr_en = sampler_valid & pwo_mode ? (pw_rden ? RW_READ : RW_IDLE) : RW_IDLE; @@ -259,10 +248,10 @@ module ntt_top end gs: begin if (shuffle_en) begin - uvw_i.w11_i = twiddle_factor/*_reg*/[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; - uvw_i.w10_i = twiddle_factor/*_reg*/[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; - uvw_i.w01_i = twiddle_factor/*_reg*/[(2*NTT_REG_SIZE)-1:NTT_REG_SIZE]; - uvw_i.w00_i = twiddle_factor/*_reg*/[NTT_REG_SIZE-1:0]; + uvw_i.w11_i = twiddle_factor[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; + uvw_i.w10_i = twiddle_factor[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; + uvw_i.w01_i = twiddle_factor[(2*NTT_REG_SIZE)-1:NTT_REG_SIZE]; + uvw_i.w00_i = twiddle_factor[NTT_REG_SIZE-1:0]; end else begin uvw_i.w11_i = twiddle_factor_reg[(3*NTT_REG_SIZE)-1:(2*NTT_REG_SIZE)]; @@ -390,7 +379,7 @@ module ntt_top pw_uvw_i = 'h0; end gs: begin - uvw_i.u00_i = mem_rd_data_reg[REG_SIZE-2:0]; //[22:0] + uvw_i.u00_i = mem_rd_data_reg[REG_SIZE-2:0]; uvw_i.u01_i = mem_rd_data_reg[(3*REG_SIZE)-2:(2*REG_SIZE)]; uvw_i.v00_i = mem_rd_data_reg[(2*REG_SIZE)-2:REG_SIZE]; uvw_i.v01_i = mem_rd_data_reg[(4*REG_SIZE)-2:(3*REG_SIZE)]; @@ -409,10 +398,10 @@ module ntt_top pw_uvw_i.u3_i = pwm_rd_data_a_reg[(4*REG_SIZE)-2:(3*REG_SIZE)]; if (shuffle_en) begin - pw_uvw_i.v0_i = pwm_rd_data_b/*_reg*/[REG_SIZE-2:0]; - pw_uvw_i.v1_i = pwm_rd_data_b/*_reg*/[(2*REG_SIZE)-2:REG_SIZE]; - pw_uvw_i.v2_i = pwm_rd_data_b/*_reg*/[(3*REG_SIZE)-2:(2*REG_SIZE)]; - pw_uvw_i.v3_i = pwm_rd_data_b/*_reg*/[(4*REG_SIZE)-2:(3*REG_SIZE)]; + pw_uvw_i.v0_i = pwm_rd_data_b[REG_SIZE-2:0]; + pw_uvw_i.v1_i = pwm_rd_data_b[(2*REG_SIZE)-2:REG_SIZE]; + pw_uvw_i.v2_i = pwm_rd_data_b[(3*REG_SIZE)-2:(2*REG_SIZE)]; + pw_uvw_i.v3_i = pwm_rd_data_b[(4*REG_SIZE)-2:(3*REG_SIZE)]; end else begin pw_uvw_i.v0_i = pwm_rd_data_b_reg[REG_SIZE-2:0]; From 9c3d3e89f1d071755e387af7c27e1e3e5ef06af9 Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Tue, 29 Oct 2024 12:18:23 -0700 Subject: [PATCH 10/17] Lint cleanup --- src/norm_check/rtl/norm_check_ctrl.sv | 2 -- src/norm_check/rtl/norm_check_top.sv | 2 -- src/ntt_top/rtl/ntt_ctrl.sv | 10 +++++----- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/norm_check/rtl/norm_check_ctrl.sv b/src/norm_check/rtl/norm_check_ctrl.sv index ef94b38..d95158b 100644 --- a/src/norm_check/rtl/norm_check_ctrl.sv +++ b/src/norm_check/rtl/norm_check_ctrl.sv @@ -36,8 +36,6 @@ module norm_check_ctrl input wire norm_check_enable, input chk_norm_mode_t mode, - input wire shuffle_en, - input wire [5:0] random, input wire [MLDSA_MEM_ADDR_WIDTH-1:0] mem_base_addr, output mem_if_t mem_rd_req, output logic check_enable, diff --git a/src/norm_check/rtl/norm_check_top.sv b/src/norm_check/rtl/norm_check_top.sv index f347628..b291ca0 100644 --- a/src/norm_check/rtl/norm_check_top.sv +++ b/src/norm_check/rtl/norm_check_top.sv @@ -45,8 +45,6 @@ module norm_check_top input wire [MLDSA_MEM_ADDR_WIDTH-1:0] mem_base_addr, output mem_if_t mem_rd_req, input [4*REG_SIZE-1:0] mem_rd_data, - input wire shuffle_en, - input wire [5:0] random, output logic invalid, output logic norm_check_ready, output logic norm_check_done diff --git a/src/ntt_top/rtl/ntt_ctrl.sv b/src/ntt_top/rtl/ntt_ctrl.sv index e51d14d..90abee1 100644 --- a/src/ntt_top/rtl/ntt_ctrl.sv +++ b/src/ntt_top/rtl/ntt_ctrl.sv @@ -271,7 +271,7 @@ always_comb begin if (shuffle_en) begin mem_rd_addr_nxt = (gs_mode | pwo_mode) ? (4*chunk_count) + (rd_addr_step*mem_rd_index_ofst) + mem_rd_base_addr : mem_rd_addr + rd_addr_step; //TODO pwo modes - mem_wr_addr_nxt = ct_mode ? (4*(chunk_count_reg[0])) + (wr_addr_step*buf_rdptr_reg[0]) + mem_wr_base_addr : gs_mode ? mem_wr_addr + wr_addr_step : (4*(chunk_count_reg[4])) + (wr_addr_step*buf_rdptr_reg[4]); //TODO: pwo modes + mem_wr_addr_nxt = ct_mode ? (MEM_ADDR_WIDTH+1)'((4*chunk_count_reg[0]) + (wr_addr_step*buf_rdptr_reg[0]) + mem_wr_base_addr) : gs_mode ? mem_wr_addr + wr_addr_step : (MEM_ADDR_WIDTH+1)'((4*chunk_count_reg[4]) + (wr_addr_step*buf_rdptr_reg[4])); end else begin mem_rd_addr_nxt = mem_rd_addr + rd_addr_step; @@ -383,22 +383,22 @@ always_comb begin 'h0: begin twiddle_end_addr = ct_mode ? 'd0 : 'd63; twiddle_offset = 'h0; - twiddle_rand_offset = ct_mode ? 'h0 : (chunk_count_reg[BF_LATENCY])*4 + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]; + twiddle_rand_offset = ct_mode ? 'h0 : 7'((4*chunk_count_reg[BF_LATENCY]) + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]); end 'h1: begin twiddle_end_addr = ct_mode ? 'd3 : 'd15; twiddle_offset = ct_mode ? 'd1 : 'd64; - twiddle_rand_offset = ct_mode ? buf_rdptr_int : (chunk_count_reg[BF_LATENCY] % 4)*4 + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]; + twiddle_rand_offset = ct_mode ? 7'(buf_rdptr_int) : 7'((chunk_count_reg[BF_LATENCY] % 4)*4 + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]); end 'h2: begin twiddle_end_addr = ct_mode ? 'd15 : 'd3; twiddle_offset = ct_mode ? 'd5 : 'd80; - twiddle_rand_offset = ct_mode ? (chunk_count % 'd4)*'d4 + buf_rdptr_int : buf_wrptr_reg[INTT_WRBUF_LATENCY-1]; + twiddle_rand_offset = ct_mode ? 7'((chunk_count % 'd4)*'d4 + buf_rdptr_int) : 7'(buf_wrptr_reg[INTT_WRBUF_LATENCY-1]); end 'h3: begin twiddle_end_addr = ct_mode ? 'd63 : 'd0; twiddle_offset = ct_mode ? 'd21 : 'd84; - twiddle_rand_offset = ct_mode ? (chunk_count % 'd16)*4 + buf_rdptr_int : 'h0; + twiddle_rand_offset = ct_mode ? 7'((chunk_count % 'd16)*4 + buf_rdptr_int) : 'h0; end default: begin twiddle_end_addr = 'h0; From 1be066366ef573e97f1e61fe31fa684fe53d1ba3 Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Tue, 29 Oct 2024 12:31:23 -0700 Subject: [PATCH 11/17] Fix lint --- src/norm_check/rtl/norm_check_top.sv | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/norm_check/rtl/norm_check_top.sv b/src/norm_check/rtl/norm_check_top.sv index b291ca0..7bf3d33 100644 --- a/src/norm_check/rtl/norm_check_top.sv +++ b/src/norm_check/rtl/norm_check_top.sv @@ -108,8 +108,6 @@ module norm_check_top .zeroize(zeroize), .norm_check_enable(norm_check_enable), .mode(mode), - .shuffle_en(shuffle_en), - .random(random), .mem_base_addr(mem_base_addr), .mem_rd_req(mem_rd_req), .norm_check_done(norm_check_done_int), From 89881b4f726adf6ca4a1b1bc97b204e93e027c1d Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Tue, 29 Oct 2024 19:50:54 +0000 Subject: [PATCH 12/17] MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/gh_ntt_shuffling' with updated timestamp and hash after successful run --- .github/workflow_metadata/pr_hash | 2 +- .github/workflow_metadata/pr_timestamp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflow_metadata/pr_hash b/.github/workflow_metadata/pr_hash index 381f645..ce98afc 100644 --- a/.github/workflow_metadata/pr_hash +++ b/.github/workflow_metadata/pr_hash @@ -1 +1 @@ -b7a9e1c8a62a99338ac4f080636de0e3920d3876368e5fc942000a11764ff6de6cbb3165eeb4de0707aa78cd3a52a7b0 \ No newline at end of file +a10a0bd46115b8e9f8dfaaf3594287f13021fa412498ebe98278970b0ecc59178a157346f02ed0c810fe7d0bcc13a367 \ No newline at end of file diff --git a/.github/workflow_metadata/pr_timestamp b/.github/workflow_metadata/pr_timestamp index 41395e9..0a43a3f 100644 --- a/.github/workflow_metadata/pr_timestamp +++ b/.github/workflow_metadata/pr_timestamp @@ -1 +1 @@ -1729808326 \ No newline at end of file +1730231452 \ No newline at end of file From 4f6ce4fc658b2244a17fa91f632c994ba42e447e Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Tue, 29 Oct 2024 13:02:06 -0700 Subject: [PATCH 13/17] Disable shuffling for now --- src/mldsa_top/rtl/mldsa_top.sv | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mldsa_top/rtl/mldsa_top.sv b/src/mldsa_top/rtl/mldsa_top.sv index 8228127..f76af93 100644 --- a/src/mldsa_top/rtl/mldsa_top.sv +++ b/src/mldsa_top/rtl/mldsa_top.sv @@ -443,11 +443,11 @@ generate end MLDSA_NTT: begin mode[g_inst] = ct; - shuffle_en[g_inst] = 1; + // shuffle_en[g_inst] = 1; end MLDSA_INTT: begin mode[g_inst] = gs; - shuffle_en[g_inst] = 1; + // shuffle_en[g_inst] = 1; end MLDSA_PWM_SMPL: begin mode[g_inst] = pwm; @@ -463,23 +463,23 @@ generate MLDSA_PWM: begin mode[g_inst] = pwm; sampler_valid[g_inst] = 1; - shuffle_en[g_inst] = 1; + // shuffle_en[g_inst] = 1; end MLDSA_PWM_ACCUM: begin mode[g_inst] = pwm; accumulate[g_inst] = 1; sampler_valid[g_inst] = 1; - shuffle_en[g_inst] = 1; + // shuffle_en[g_inst] = 1; end MLDSA_PWA: begin mode[g_inst] = pwa; sampler_valid[g_inst] = 1; - shuffle_en[g_inst] = 1; + // shuffle_en[g_inst] = 1; end MLDSA_PWS: begin mode[g_inst] = pws; sampler_valid[g_inst] = 1; - shuffle_en[g_inst] = 1; + // shuffle_en[g_inst] = 1; end default: begin end From f6347ec00c73f877e2091055eb4be37928148586 Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Tue, 29 Oct 2024 20:26:38 +0000 Subject: [PATCH 14/17] MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/gh_ntt_shuffling' with updated timestamp and hash after successful run --- .github/workflow_metadata/pr_hash | 2 +- .github/workflow_metadata/pr_timestamp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflow_metadata/pr_hash b/.github/workflow_metadata/pr_hash index ce98afc..27b723f 100644 --- a/.github/workflow_metadata/pr_hash +++ b/.github/workflow_metadata/pr_hash @@ -1 +1 @@ -a10a0bd46115b8e9f8dfaaf3594287f13021fa412498ebe98278970b0ecc59178a157346f02ed0c810fe7d0bcc13a367 \ No newline at end of file +b129cdf5497e923bdc43e6e0e49d0b7a7e995caba71d5afef78cb887d7d6ba3bef0f59747da13c80bb7646534a3e1b1c \ No newline at end of file diff --git a/.github/workflow_metadata/pr_timestamp b/.github/workflow_metadata/pr_timestamp index 0a43a3f..235075c 100644 --- a/.github/workflow_metadata/pr_timestamp +++ b/.github/workflow_metadata/pr_timestamp @@ -1 +1 @@ -1730231452 \ No newline at end of file +1730233596 \ No newline at end of file From 1c9e1b0f32ba78a068d28f54488c3e0513d2549f Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Wed, 30 Oct 2024 12:53:12 -0700 Subject: [PATCH 15/17] Remove old TODO --- src/ntt_top/rtl/ntt_ctrl.sv | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ntt_top/rtl/ntt_ctrl.sv b/src/ntt_top/rtl/ntt_ctrl.sv index 90abee1..51bb9d5 100644 --- a/src/ntt_top/rtl/ntt_ctrl.sv +++ b/src/ntt_top/rtl/ntt_ctrl.sv @@ -270,7 +270,7 @@ always_comb begin mem_wr_base_addr = rounds_count[0] ? dest_base_addr : interim_base_addr; if (shuffle_en) begin - mem_rd_addr_nxt = (gs_mode | pwo_mode) ? (4*chunk_count) + (rd_addr_step*mem_rd_index_ofst) + mem_rd_base_addr : mem_rd_addr + rd_addr_step; //TODO pwo modes + mem_rd_addr_nxt = (gs_mode | pwo_mode) ? (4*chunk_count) + (rd_addr_step*mem_rd_index_ofst) + mem_rd_base_addr : mem_rd_addr + rd_addr_step; mem_wr_addr_nxt = ct_mode ? (MEM_ADDR_WIDTH+1)'((4*chunk_count_reg[0]) + (wr_addr_step*buf_rdptr_reg[0]) + mem_wr_base_addr) : gs_mode ? mem_wr_addr + wr_addr_step : (MEM_ADDR_WIDTH+1)'((4*chunk_count_reg[4]) + (wr_addr_step*buf_rdptr_reg[4])); end else begin @@ -292,7 +292,7 @@ always_ff @(posedge clk or negedge reset_n) begin end else if (rst_rd_addr) begin if (shuffle_en) - mem_rd_addr <= ct_mode ? mem_rd_base_addr + chunk_rand_offset : (gs_mode | pwo_mode) ? mem_rd_base_addr + (4*chunk_rand_offset) : mem_rd_base_addr; //TODO: pwo + mem_rd_addr <= ct_mode ? mem_rd_base_addr + chunk_rand_offset : (gs_mode | pwo_mode) ? mem_rd_base_addr + (4*chunk_rand_offset) : mem_rd_base_addr; else mem_rd_addr <= mem_rd_base_addr; end @@ -704,7 +704,7 @@ always_comb begin arc_RD_STAGE_IDLE ? RD_IDLE : RD_STAGE; rst_rd_addr = 1'b1; rst_rd_valid_count = 1'b1; - //reset if in ntt mode, since writes won't use the buffer, it's safe to reset buffer - TODO: in shuffled NTT, the buffer is sram style, so reset may not be an option. Check on this! + //reset if in ntt mode, since writes won't use the buffer, it's safe to reset buffer buf_wr_rst_count_ntt = ct_mode; buf_rd_rst_count_ntt = ct_mode; rst_twiddle_addr = !butterfly_ready; From 4008aff4e016999363d8ce5b07edbf78fd8016ca Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Thu, 31 Oct 2024 18:08:39 +0000 Subject: [PATCH 16/17] MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/gh_ntt_shuffling' with updated timestamp and hash after successful run --- .github/workflow_metadata/pr_hash | 2 +- .github/workflow_metadata/pr_timestamp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflow_metadata/pr_hash b/.github/workflow_metadata/pr_hash index bfd0b7b..425c709 100644 --- a/.github/workflow_metadata/pr_hash +++ b/.github/workflow_metadata/pr_hash @@ -1 +1 @@ -ccfbc1a0345264c8c044b6e441c9757f47467e8478b89f57bf502ad6f3b1b29f3be1047b7ee043d5790c8a24977ceed9 +6125c7fa394b48b54b85ad7453c0c374b7853fa525492ba552f8bef665069a6d4878c5906115ca22510e83bc2c9be4dd \ No newline at end of file diff --git a/.github/workflow_metadata/pr_timestamp b/.github/workflow_metadata/pr_timestamp index 5c9100d..6dba180 100644 --- a/.github/workflow_metadata/pr_timestamp +++ b/.github/workflow_metadata/pr_timestamp @@ -1 +1 @@ -1730388545 +1730398117 \ No newline at end of file From b2ca140525b2dabd30b399bc72e0aca0f811c80c Mon Sep 17 00:00:00 2001 From: Kiran Upadhyayula Date: Thu, 31 Oct 2024 19:06:49 +0000 Subject: [PATCH 17/17] MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/gh_ntt_shuffling' with updated timestamp and hash after successful run --- .github/workflow_metadata/pr_hash | 2 +- .github/workflow_metadata/pr_timestamp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflow_metadata/pr_hash b/.github/workflow_metadata/pr_hash index 6955128..c90da91 100644 --- a/.github/workflow_metadata/pr_hash +++ b/.github/workflow_metadata/pr_hash @@ -1 +1 @@ -6125c7fa394b48b54b85ad7453c0c374b7853fa525492ba552f8bef665069a6d4878c5906115ca22510e83bc2c9be4dd +7a3eb92951b68c97aac175fd0faf18d82ba97c85e5669450c340c4ccba5752f566eee2cfbcd5cc15b632a79a57d2bea7 \ No newline at end of file diff --git a/.github/workflow_metadata/pr_timestamp b/.github/workflow_metadata/pr_timestamp index c72021f..074a8d0 100644 --- a/.github/workflow_metadata/pr_timestamp +++ b/.github/workflow_metadata/pr_timestamp @@ -1 +1 @@ -1730398117 +1730401607 \ No newline at end of file