Skip to content

Commit

Permalink
[spatz_vfu] added optimized reduction implementation for ELEN=64
Browse files Browse the repository at this point in the history
Tested with dotproduct 4096: 38% utilization (5% improvement)
  • Loading branch information
Navaneeth-KunhiPurayil committed Jan 5, 2025
1 parent 6abaa04 commit 628afd2
Showing 1 changed file with 124 additions and 47 deletions.
171 changes: 124 additions & 47 deletions hw/ip/spatz/src/spatz_vfu.sv
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,24 @@ module spatz_vfu
Reduction_NormalExecution,
Reduction_Wait,
Reduction_Init,
Reduction_Fill,
Reduction_Reduce,
Reduction_IntraLane,
Reduction_InterLane,
Reduction_WriteBack
} reduction_state_t;
reduction_state_t reduction_state_d, reduction_state_q;
`FF(reduction_state_q, reduction_state_d, Reduction_NormalExecution)

// Reduction intralane
logic result_buf_valid_d, result_buf_valid_q;
vrf_data_t result_buf_d, result_buf_q;
logic [idx_width(ELEN*N_FPU)-1 : 0] shift_amnt_d, shift_amnt_q;

`FF(result_buf_valid_q, result_buf_valid_d, 1'b0)
`FF(result_buf_q, result_buf_d, '0)
`FF(shift_amnt_q, shift_amnt_d, ELEN)

// Is the reduction done?
logic reduction_done;

Expand Down Expand Up @@ -284,7 +296,7 @@ module spatz_vfu
//////////////

// Reduction registers
elen_t [1:0] reduction_q, reduction_d;
vrf_data_t [1:0] reduction_q, reduction_d;
`FFL(reduction_q, reduction_d, reduction_operand_ready_d, '0)

// IPU results
Expand Down Expand Up @@ -369,7 +381,11 @@ module spatz_vfu

// Only request when initializing the reduction register
reduction_operand_request[0] = (reduction_state_q == Reduction_Init) || !spatz_req.op_arith.is_reduction;
reduction_operand_request[1] = (reduction_state_q inside {Reduction_Init, Reduction_Reduce}) || !spatz_req.op_arith.is_reduction;
reduction_operand_request[1] = (reduction_state_q inside {Reduction_Init, Reduction_Fill, Reduction_Reduce}) || !spatz_req.op_arith.is_reduction;

result_buf_d = result_buf_q;
result_buf_valid_d = result_buf_valid_q;
shift_amnt_d = shift_amnt_q;

unique case (reduction_state_q)
Reduction_NormalExecution: begin
Expand Down Expand Up @@ -401,77 +417,74 @@ module spatz_vfu
unique case (spatz_req.vtype.vsew)
EW_8 : begin
reduction_d[0] = $unsigned(vrf_rdata_i[0][7:0]);
reduction_d[1] = $unsigned(vrf_rdata_i[1][8*reduction_pointer_q[idx_width(N_FU*ELENB)-1:0] +: 8]);
reduction_d[1] = $unsigned(vrf_rdata_i[1]);
end
EW_16: begin
reduction_d[0] = $unsigned(vrf_rdata_i[0][15:0]);
reduction_d[1] = $unsigned(vrf_rdata_i[1][16*reduction_pointer_q[idx_width(N_FU*ELENB)-2:0] +: 16]);
reduction_d[1] = $unsigned(vrf_rdata_i[1]);
end
EW_32: begin
reduction_d[0] = $unsigned(vrf_rdata_i[0][31:0]);
reduction_d[1] = $unsigned(vrf_rdata_i[1][32*reduction_pointer_q[idx_width(N_FU*ELENB)-3:0] +: 32]);
reduction_d[1] = $unsigned(vrf_rdata_i[1]);
end
default: begin
`ifdef MEMPOOL_SPATZ
reduction_d = '0;
`else
if (MAXEW == EW_64) begin
reduction_d[0] = $unsigned(vrf_rdata_i[0][63:0]);
reduction_d[1] = $unsigned(vrf_rdata_i[1][64*reduction_pointer_q[idx_width(N_FU*ELENB)-4:0] +: 64]);
reduction_d[0] = {{((N_FPU-1)*ELEN){1'b0}}, $unsigned(vrf_rdata_i[0][63:0])};
reduction_d[1] = $unsigned(vrf_rdata_i[1]);
end
`endif
end
endcase
// verilator lint_on SELRANGE

if (vrf_rvalid_i[0] && vrf_rvalid_i[1]) begin
automatic logic [idx_width(N_FU*ELENB)-1:0] pnt;

reduction_operand_ready_d = 1'b1;
reduction_pointer_d = reduction_pointer_q + 1;
reduction_state_d = Reduction_Reduce;
reduction_state_d = Reduction_Fill;

// Request next word
pnt = reduction_pointer_d << int'(spatz_req.vtype.vsew);
if (!(|pnt))
word_issued = 1'b1;
word_issued = 1'b1;
end
end

Reduction_Fill: begin
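// Fill phase: keep loading new operand words into reduction_d[1]; reduction_d[0]
// takes the functional-unit result once result_valid is set and '0 until then.
// Move on to Reduction_Reduce when an operand and a valid result coincide.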
// verilator lint_off SELRANGE
`ifdef MEMPOOL_SPATZ
reduction_d = '0;
`else
reduction_d[0] = result_valid ? result : '0;
reduction_d[1] = $unsigned(vrf_rdata_i[1]);
`endif

// verilator lint_on SELRANGE
if (vrf_rvalid_i[1]) begin
reduction_operand_ready_d = 1'b1;
reduction_pointer_d = reduction_pointer_q + 1;
word_issued = 1'b1;
if (result_valid) begin
reduction_state_d = Reduction_Reduce;
result_ready = 1'b1;
end
end
end

Reduction_Reduce: begin
// Forward result
// verilator lint_off SELRANGE
unique case (spatz_req.vtype.vsew)
EW_8 : begin
reduction_d[0] = $unsigned(result[7:0]);
reduction_d[1] = $unsigned(vrf_rdata_i[1][8*reduction_pointer_q[idx_width(N_FU*ELENB)-1:0] +: 8]);
end
EW_16: begin
reduction_d[0] = $unsigned(result[15:0]);
reduction_d[1] = $unsigned(vrf_rdata_i[1][16*reduction_pointer_q[idx_width(N_FU*ELENB)-2:0] +: 16]);
end
EW_32: begin
reduction_d[0] = $unsigned(result[31:0]);
reduction_d[1] = $unsigned(vrf_rdata_i[1][32*reduction_pointer_q[idx_width(N_FU*ELENB)-3:0] +: 32]);
end
default: begin
`ifdef MEMPOOL_SPATZ
reduction_d = '0;
`else
if (MAXEW == EW_64) begin
reduction_d[0] = $unsigned(result[63:0]);
reduction_d[1] = $unsigned(vrf_rdata_i[1][64*reduction_pointer_q[idx_width(N_FU*ELENB)-4:0] +: 64]);
end
`endif
end
endcase
// verilator lint_on SELRANGE
`ifdef MEMPOOL_SPATZ
reduction_d = '0;
`else
reduction_d[0] = $unsigned(result);
reduction_d[1] = $unsigned(vrf_rdata_i[1]);
`endif

// Got a result!
// verilator lint_on SELRANGE
if (result_valid[0]) begin
// Did we get an operand?
if (vrf_rvalid_i[1]) begin
automatic logic [idx_width(N_FU*ELENB)-1:0] pnt;

// Bump pointer
reduction_pointer_d = reduction_pointer_q + 1;
Expand All @@ -483,17 +496,81 @@ module spatz_vfu
reduction_operand_ready_d = 1'b1;

// Request next word
pnt = reduction_pointer_d << int'(spatz_req.vtype.vsew);
if (!(|pnt))
word_issued = 1'b1;
word_issued = 1'b1;
end
end

// Are we done?
if (reduction_pointer_q == spatz_req.vl) begin
if ((reduction_pointer_q == ((spatz_req.vl >> $clog2(N_FPU))-1)) && (reduction_operand_ready_d==1'b1)) begin
reduction_state_d = Reduction_IntraLane;
end
end

Reduction_IntraLane: begin
// verilator lint_off SELRANGE
`ifdef MEMPOOL_SPATZ
reduction_d = '0;
`else
reduction_d[0] = $unsigned(result);
reduction_d[1] = $unsigned(result_buf_d);
`endif

// verilator lint_on SELRANGE
if (~result_buf_valid_q & result_valid[0]) begin
result_buf_valid_d = 1'b1;
result_buf_d = result;
result_ready = 1'b1;
end

// verilator lint_on SELRANGE
if (result_buf_valid_q & result_valid[0]) begin
result_buf_valid_d = 1'b0;

// Trigger a request
reduction_operand_ready_d = 1'b1;

// Bump pointer
reduction_pointer_d = reduction_pointer_q + 1;

// Acknowledge result
result_ready = 1'b1;
end

// Are we done?
if ((reduction_pointer_q == ((spatz_req.vl >> $clog2(N_FPU)) + 2)) && (reduction_operand_ready_d==1'b1)) begin
reduction_state_d = Reduction_InterLane;
end
end

Reduction_InterLane: begin
// verilator lint_off SELRANGE
`ifdef MEMPOOL_SPATZ
reduction_d = '0;
`else
reduction_d[0] = $unsigned(result) >> shift_amnt_q;
reduction_d[1] = $unsigned(result);
`endif

// verilator lint_on SELRANGE
if (result_valid[0]) begin

// Trigger a request
reduction_operand_ready_d = 1'b1;

// Bump pointer
reduction_pointer_d = reduction_pointer_q + 1;

// Acknowledge result
result_ready = 1'b1;

// Update shift amnt
shift_amnt_d = shift_amnt_q << 1;
end

// Are we done?
if ((reduction_pointer_q == ((spatz_req.vl >> $clog2(N_FPU)) + 2 + $clog2(N_FPU))) && (reduction_operand_ready_d == 1'b1)) begin
reduction_state_d = Reduction_WriteBack;
result_ready = 1'b0;
reduction_operand_ready_d = 1'b0;
shift_amnt_d = ELEN;
end
end

Expand Down

0 comments on commit 628afd2

Please sign in to comment.