Skip to content

Commit

Permalink
Replace AXI-Stream output with ping-pong buffer. Works for K!=1
Browse files Browse the repository at this point in the history
  • Loading branch information
Aba committed Aug 25, 2023
1 parent e7f4af2 commit b73b23b
Show file tree
Hide file tree
Showing 8 changed files with 409 additions and 191 deletions.
21 changes: 11 additions & 10 deletions rtl/axis_out_shift.sv
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,26 @@ module axis_out_shift #(
input logic m_ready,
output logic [ROWS -1:0][WORD_WIDTH -1:0] m_data,
output tuser_st m_user,
output logic m_valid, m_last
output logic m_valid, m_last, m_last_pkt
);

logic [COLS-1:0][ROWS -1:0][WORD_WIDTH-1:0] shift_data;
logic [KW_MAX/2:0][COLS-1:0] lut_valid, lut_valid_last, lut_last;
logic [COLS-1:0] shift_last, shift_valid;
logic [KW_MAX/2:0][COLS-1:0] lut_valid, lut_valid_last, lut_last_pkt, lut_last;
logic [COLS-1:0] shift_last, shift_last_pkt, shift_valid;

genvar k2, c_1;
for (k2=0; k2 <= KW_MAX/2; k2++)
for (c_1=0; c_1 < COLS; c_1++) begin
localparam k = k2*2+1, c = c_1 + 1;
assign lut_valid [k2][c_1] = (c % k == 0);
assign lut_valid_last [k2][c_1] = ((c % k > k2) || (c % k == 0)) && (c <= (COLS/k)*k);
assign lut_last [k2][c_1] = (c == k2+1);
assign lut_last [k2][c_1] = (c == k);
assign lut_last_pkt [k2][c_1] = (c == k2+1);
end

wire valid_mask = !s_user.is_w_first_kw2 && !s_user.is_config;
wire [COLS-1:0] s_valid_cols_sel = s_user.is_w_last ? lut_valid_last[s_user.kw2] : lut_valid[s_user.kw2];
wire [COLS-1:0] s_last_cols_sel = s_user.is_w_last ? lut_last_pkt [s_user.kw2] : lut_last [s_user.kw2];


logic [$clog2(COLS+1)-1:0] counter;
Expand All @@ -44,21 +46,23 @@ module axis_out_shift #(
if (!aresetn) begin
state <= IDLE;
s_ready <= 1;
{shift_valid, shift_last} <= '0;
{shift_valid, shift_last_pkt, shift_last} <= '0;
end else case (state)
IDLE : if (s_valid && valid_mask) begin
state <= SHIFT;
s_ready <= 0;

shift_data <= s_data;
shift_valid <= s_valid_cols_sel & {COLS{valid_mask}};
shift_last <= {COLS{s_last}} & lut_last[s_user.kw2];
shift_last <= s_last_cols_sel;
shift_last_pkt <= {COLS{s_last}} & lut_last_pkt[s_user.kw2];
end
SHIFT : if (m_ready) begin

shift_data <= shift_data << (ROWS * WORD_WIDTH);
shift_valid <= shift_valid << 1;
shift_last <= shift_last << 1;
shift_last_pkt <= shift_last_pkt << 1;

if (counter == 1) begin
state <= IDLE;
Expand All @@ -75,9 +79,6 @@ module axis_out_shift #(
assign m_data = shift_data [COLS-1];
assign m_valid = shift_valid[COLS-1];
assign m_last = shift_last [COLS-1];

// assign last_kw = last && ((COLS-(COLS+1-counter)) < kw);
// assign m_valid = state == SHIFT && (counter % kw == 0 || (last && (counter % kw > kw/2)));
// assign m_last = m_valid && last_kw && (counter == kw/2+1);
assign m_last_pkt = shift_last_pkt [COLS-1];

endmodule
65 changes: 29 additions & 36 deletions rtl/dnn_engine.v
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ module dnn_engine #(
Y_BITS = `Y_BITS ,
M_DATA_WIDTH_HF_CONV = COLS * ROWS * Y_BITS,
M_DATA_WIDTH_HF_CONV_DW = ROWS * Y_BITS,
DW_IN_KEEP_WIDTH = M_DATA_WIDTH_HF_CONV_DW/8,

S_PIXELS_WIDTH_LF = `S_PIXELS_WIDTH_LF ,
S_WEIGHTS_WIDTH_LF = `S_WEIGHTS_WIDTH_LF ,
M_OUTPUT_WIDTH_LF = `M_OUTPUT_WIDTH_LF

OUT_ADDR_WIDTH = 10,
OUT_BITS = 32
)(
input wire aclk,
input wire aresetn,
Expand All @@ -36,11 +37,11 @@ module dnn_engine #(
input wire [S_WEIGHTS_WIDTH_LF -1:0] s_axis_weights_tdata,
input wire [S_WEIGHTS_KEEP_WIDTH-1:0] s_axis_weights_tkeep,

input wire m_axis_tready,
output wire m_axis_tvalid,
output wire m_axis_tlast ,
output wire [M_OUTPUT_WIDTH_LF-1:0] m_axis_tdata,
output wire [M_KEEP_WIDTH -1:0] m_axis_tkeep
input wire [(OUT_ADDR_WIDTH+2)-1:0] bram_addr_a,
output wire [ OUT_BITS -1:0] bram_rddata_a,
input wire bram_en_a,
output wire done_fill,
input wire done_firmware
);

localparam TUSER_WIDTH = `TUSER_WIDTH;
Expand All @@ -54,8 +55,8 @@ module dnn_engine #(
wire [K_BITS*COLS -1:0] weights_m_data;
wire [TUSER_WIDTH -1:0] weights_m_user;

wire dw_s_axis_tready, dw_s_axis_tvalid, dw_s_axis_tlast ;
wire [M_DATA_WIDTH_HF_CONV_DW -1:0] dw_s_axis_tdata ;
wire out_s_ready, out_s_valid, out_s_last;
wire [M_DATA_WIDTH_HF_CONV_DW -1:0] out_s_data;

axis_pixels PIXELS (
.aclk (aclk ),
Expand Down Expand Up @@ -104,35 +105,25 @@ module dnn_engine #(
.s_user (weights_m_user ),
.s_data_pixels (pixels_m_data ),
.s_data_weights (weights_m_data ),
.m_ready (dw_s_axis_tready),
.m_valid (dw_s_axis_tvalid),
.m_data (dw_s_axis_tdata ),
.m_last (dw_s_axis_tlast )
.m_ready (out_s_ready ),
.m_valid (out_s_valid ),
.m_data (out_s_data ),
.m_last (out_s_last )
);

alex_axis_adapter_any #(
.S_DATA_WIDTH (M_DATA_WIDTH_HF_CONV_DW),
.M_DATA_WIDTH (M_OUTPUT_WIDTH_LF),
.S_KEEP_ENABLE (1),
.M_KEEP_ENABLE (1),
.S_KEEP_WIDTH (DW_IN_KEEP_WIDTH),
.M_KEEP_WIDTH (M_KEEP_WIDTH),
.ID_ENABLE (0),
.DEST_ENABLE (0),
.USER_ENABLE (0)
) DW_OUT (
.clk (aclk ),
.rst (~aresetn),
.s_axis_tready (dw_s_axis_tready),
.s_axis_tvalid (dw_s_axis_tvalid),
.s_axis_tdata (dw_s_axis_tdata ),
.s_axis_tkeep ({DW_IN_KEEP_WIDTH{1'b1}}),
.s_axis_tlast (dw_s_axis_tlast ),
.m_axis_tready (m_axis_tready),
.m_axis_tvalid (m_axis_tvalid),
.m_axis_tdata (m_axis_tdata ),
.m_axis_tkeep (m_axis_tkeep ),
.m_axis_tlast (m_axis_tlast )
out_ram_switch OUT_RAM (
.clk (aclk ),
.rstn (aresetn ),
.s_ready (out_s_ready ),
.s_valid (out_s_valid ),
.s_data (out_s_data ),
.s_last (out_s_last ),

.bram_addr_a (bram_addr_a ),
.bram_rddata_a(bram_rddata_a ),
.bram_en_a (bram_en_a ),
.done_fill (done_fill ),
.done_firmware(done_firmware )
);
endmodule

Expand All @@ -154,6 +145,7 @@ module proc_engine_out #(
input wire m_ready,
output wire m_valid,
output wire [M_DATA_WIDTH_HF_CONV_DW-1:0] m_data,
output wire m_last_pkt,
output wire m_last
);

Expand Down Expand Up @@ -187,6 +179,7 @@ module proc_engine_out #(
.m_ready (m_ready ),
.m_valid (m_valid ),
.m_data (m_data ),
.m_last_pkt (m_last_pkt ),
.m_last (m_last )
);

Expand Down
159 changes: 159 additions & 0 deletions rtl/out_ram_switch.sv
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
`include "../rtl/include/params.svh"

// out_ram_switch
// --------------
// Ping-pong output buffer. Incoming stream beats (ROWS words of Y_BITS each)
// are captured into a shift register and written one word per clock into the
// RAM selected by `i_write`. The RAM selected by `i_read` drives
// `bram_rddata_a`, addressed by `bram_addr_a`. Two FSMs (write / read) hand the
// RAMs over to each other through the `done_write` / `done_read` flags.
//
// Ports
//   clk, rstn      : clock and synchronous, active-low reset.
//   s_ready/s_valid/s_data/s_last : input stream; a beat is accepted when
//                    s_valid && s_ready.
//   bram_addr_a    : read address; the two LSBs are dropped to form the RAM
//                    word address (presumably a byte address of 32-bit words -
//                    TODO confirm against the bus master).
//   bram_rddata_a  : read data, sign-extended from Y_BITS to WORD_WIDTH.
//   bram_en_a      : not referenced inside this module.
//   done_fill      : high for the single clock the read FSM spends in
//                    R_DONE_FILL.
//   done_firmware  : a rising level seen in R_READ_S moves the read FSM on to
//                    release the RAM.
module out_ram_switch #(
  localparam  ROWS                    = `ROWS                    ,
              COLS                    = `COLS                    ,
              KW_MAX                  = `KW_MAX                  ,
              Y_BITS                  = `Y_BITS                  ,
              RAM_LATENCY             = 2,
              WORD_WIDTH              = 32, // always 32, byte enable available for smaller width, but complicated
              ADDR_WIDTH              = 10  // word address
)(
  input  logic clk, rstn,

  output logic s_ready,
  input  logic [ROWS -1:0][Y_BITS -1:0] s_data,
  input  logic s_valid, s_last,

  input  logic [(ADDR_WIDTH+2)-1:0]     bram_addr_a,
  output logic [ WORD_WIDTH    -1:0]    bram_rddata_a,
  input  logic                          bram_en_a,

  output logic done_fill,
  input  logic done_firmware
);

  localparam BITS_COLS = $clog2(COLS), BITS_ROWS = $clog2(ROWS);
  enum {W_IDLE_S, W_WRITE_S, W_FILL_S, W_SWITCH_S             } state_write, state_write_next;
  enum {R_IDLE_S, R_DONE_FILL, R_READ_S, R_WAIT_S, R_SWITCH_S } state_read , state_read_next;

  // i_write / i_read : index (0 or 1) of the RAM currently owned by the write / read side.
  logic i_read, i_write, s_first, en_shift, last, df_was_high, lc_rows, l_rows;

  logic [ADDR_WIDTH-1:0]        ram_w_addr, ram_r_addr;
  logic [ROWS-1:0][Y_BITS -1:0] shift_reg;
  logic [Y_BITS -1:0]           ram_din;

  logic [1:0][ADDR_WIDTH-1:0] ram_addr;
  logic [1:0][Y_BITS    -1:0] ram_dout;
  logic [1:0]                 done_read, done_write, ram_wen;

  // Switching RAMs: each side flips its RAM index while in its SWITCH state.
  always_ff @(posedge clk)
    if (!rstn) {i_write, i_read} <= 0;
    else begin
      if (state_write == W_SWITCH_S) i_write <= !i_write;
      if (state_read  == R_SWITCH_S) i_read  <= !i_read;
    end

  // State registers (synchronous reset to the IDLE states).
  always_ff @(posedge clk) begin
    state_write <= !rstn ? W_IDLE_S : state_write_next;
    state_read  <= !rstn ? R_IDLE_S : state_read_next;
  end


  // -----
  // WRITE
  // -----
  // Next-state logic. The default assignment keeps the FSM in its current
  // state when no transition condition holds; without it this always_comb
  // would infer a latch on state_write_next.
  always_comb begin
    state_write_next = state_write;
    unique case (state_write)
      W_IDLE_S    : if (done_read [i_write]) state_write_next = W_WRITE_S; // RAM has been released by the reader
      W_WRITE_S   : if (lc_rows && last    ) state_write_next = W_FILL_S;  // last word of the last beat
      W_FILL_S    :                          state_write_next = W_SWITCH_S;
      W_SWITCH_S  :                          state_write_next = W_IDLE_S;
    endcase
  end

  always_ff @(posedge clk) // Special case - first beat of a packet. Bcz lc_rows = 0 at start
    if (!rstn || (state_write == W_FILL_S)) s_first <= 1;
    else if (s_valid && s_ready)            s_first <= 0;

  always_comb begin
    s_ready   = (state_write == W_WRITE_S && state_write_next == W_WRITE_S) && (s_first || l_rows); // first or after shifting rows
    en_shift  = (state_write == W_WRITE_S) && (l_rows ? s_valid || last : 1) && !s_first;           // if last, wait for valid
    ram_din   = shift_reg[0]; // lowest word of the shift register is the one written
  end

  always_ff @(posedge clk) // SHIFT REG - write data: load on an accepted beat, otherwise shift down one word
    if      (s_valid && s_ready) shift_reg <= s_data;
    else if (en_shift)           shift_reg <= shift_reg >> Y_BITS;

  // Row counter, advanced by en_shift and reset while the write FSM is idle.
  // NOTE(review): `counter` is defined elsewhere; lc_rows / l_rows are used here
  // as "last row" indications - confirm their exact timing against that module.
  counter #(.W(BITS_ROWS)) C_ROWS (.clk(clk), .reset(state_write == W_IDLE_S), .en(en_shift), .max_in(BITS_ROWS'(ROWS-1)), .last_clk(lc_rows), .last(l_rows));

  always_ff @(posedge clk) // w_addr: restarts at 0 for every fill, increments per written word
    if (!rstn || state_write==W_IDLE_S) ram_w_addr <= 0;
    else if (en_shift)                  ram_w_addr <= ram_w_addr + 1'b1;

  always_ff @(posedge clk) // Store last: s_last of the most recently accepted beat
    if (!rstn)                   last <= 0;
    else if (s_valid && s_ready) last <= s_last;



  // -----
  // READ
  // -----
  // Next-state logic. As on the write side, the default assignment holds the
  // current state and prevents a latch on state_read_next. With a latch, a
  // reset asserted while waiting in R_READ_S left state_read_next at R_READ_S
  // (done_write is cleared by reset, so R_IDLE_S never reassigned it) and the
  // FSM skipped R_DONE_FILL once reset was released.
  always_comb begin
    state_read_next = state_read;
    unique case (state_read)
      R_IDLE_S    : if (done_write [i_read])            state_read_next = R_DONE_FILL;
      R_DONE_FILL :                                     state_read_next = R_READ_S;
      R_READ_S    : if (!df_was_high && done_firmware)  state_read_next = R_WAIT_S;
      R_WAIT_S    :                                     state_read_next = R_SWITCH_S;
      R_SWITCH_S  :                                     state_read_next = R_IDLE_S;
    endcase
  end

  assign ram_r_addr    = bram_addr_a[(ADDR_WIDTH+2)-1:2];           // drop the two LSBs
  assign bram_rddata_a = WORD_WIDTH'(signed'(ram_dout[i_read]));    // pad to 32
  assign done_fill     = state_read == R_DONE_FILL;                 // one clock

  // Done Firmware Was High
  // Prevents this case: the FSM waits in READ, firmware raises done_firmware (df),
  // the FSM leaves READ and goes around WAIT, SWITCH, ... back to READ while df is
  // still high, and would immediately move on to WAIT again.
  // df_was_high is set when df is seen high in READ and cleared only once df is
  // low, so the FSM waits in READ until !df_was_high && done_firmware.
  always_ff @(posedge clk)
    if (!rstn) df_was_high <= 0;
    else if (done_firmware==0 && df_was_high) df_was_high <= 0; // df is going to zero
    else if (state_read == R_READ_S)
      if (!df_was_high && done_firmware)      df_was_high <= 1; // df is going to high during READ_S

  // Wait for done firmware to fall before deasserting done_fill - to prevent the loop in firmware missing done_fill
  // always_ff @(posedge clk)
  //   if (!rstn)                                done_fill <= 0;
  //   else if (state_read_next == R_DONE_FILL)  done_fill <= 1;
  //   else if (df_was_high && !done_firmware )  done_fill <= 0;


  // -----
  // PING PONG
  // -----
  generate
    for (genvar i=0; i<2; i++) begin: I

      // done_write[i]: set when the writer finishes RAM i (W_SWITCH_S),
      // cleared when the writer is about to be in W_WRITE_S on RAM i.
      always_ff @(posedge clk)
        if (!rstn)     done_write[i] <= 0;
        else if (i==i_write) begin
          if      (state_write_next == W_WRITE_S)  done_write[i] <= 0;
          else if (state_write      == W_SWITCH_S) done_write[i] <= 1;
        end

      // done_read[i]: resets to 1 so both RAMs start out available to the writer;
      // cleared when the reader takes RAM i, set when the reader releases it.
      always_ff @(posedge clk)
        if (!rstn)     done_read [i] <= 1;
        else if (i==i_read) begin
          if      (state_read_next == R_DONE_FILL) done_read [i] <= 0;
          else if (state_read      == R_SWITCH_S ) done_read [i] <= 1;
        end

      // Only the RAM owned by the writer is written; it uses the write address
      // while in W_WRITE_S, every other case uses the read address.
      assign ram_wen  [i] = i == i_write && en_shift && !s_first;
      assign ram_addr [i] = (i == i_write && state_write == W_WRITE_S) ? ram_w_addr : ram_r_addr;

      localparam RAM_ADDR_BITS = $clog2(COLS*ROWS);
      ram_output #(
        .DEPTH   (COLS * ROWS),
        .WIDTH   (Y_BITS     ),
        .LATENCY (RAM_LATENCY)
      ) RAM (
        .clka  (clk),
        .ena   (1'b1),
        .wea   (ram_wen [i]   ),
        .addra (RAM_ADDR_BITS'(ram_addr[i])),
        .dina  (ram_din       ),
        .douta (ram_dout[i]   )
      );
    end
  endgenerate
endmodule
6 changes: 3 additions & 3 deletions test/py/param_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ def test_dnn_engine(COMPILE):
Config(7, 16),
Config(5, 16),
Config(3, 24),
Config(1, 50, flatten=True),
Config(1, 10, dense= True),
# Config(1, 10, flatten=True),
# Config(1, 10, dense= True),
]

'''
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_dnn_engine(COMPILE):

y_wpt = b.r.CO_PRL*b.c.ROWS
y_wpt_last = b.r.CO_PRL*b.c.ROWS*(b.r.KW//2+1)
f.write(f" '{{w_wpt:{b.we[-1].size}, w_wpt_p0:{b.we[0].size}, x_wpt:{b.xe[-1].size}, x_wpt_p0:{b.xe[0].size}, y_wpt:{y_wpt}, y_wpt_last:{y_wpt_last}, n_it:{b.r.IT}, n_p:{b.r.CP} }}")
f.write(f" '{{w_wpt:{b.we[-1].size}, w_wpt_p0:{b.we[0].size}, x_wpt:{b.xe[-1].size}, x_wpt_p0:{b.xe[0].size}, y_wpt:{y_wpt}, y_wpt_last:{y_wpt_last}, y_nl:{b.r.XN*b.r.L}, y_w:{b.r.XW-b.r.KW//2}, n_it:{b.r.IT}, n_p:{b.r.CP} }}")
if b.idx != len(bundles)-1:
f.write(',\n')
f.write(f"\n}};")
Expand Down
Loading

0 comments on commit b73b23b

Please sign in to comment.