From 55c28b0a7845a3b9e367966734143bfa81d80946 Mon Sep 17 00:00:00 2001 From: Jerry Zhao Date: Tue, 9 Mar 2021 00:26:04 -0800 Subject: [PATCH] Make LoopMatmul FSM prefetch ahead --- src/main/scala/gemmini/Controller.scala | 5 +- src/main/scala/gemmini/DMA.scala | 76 ++++++++++++++------- src/main/scala/gemmini/LoopMatmul.scala | 88 +++++++++++++++++++------ src/main/scala/gemmini/Scratchpad.scala | 3 + 4 files changed, 126 insertions(+), 46 deletions(-) diff --git a/src/main/scala/gemmini/Controller.scala b/src/main/scala/gemmini/Controller.scala index 83d781d2..29d9e39e 100644 --- a/src/main/scala/gemmini/Controller.scala +++ b/src/main/scala/gemmini/Controller.scala @@ -198,7 +198,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] val max_lds = rob_entries * 1 / 4 val max_exs = rob_entries * 3 / 4 val max_sts = rob_entries * 1 / 8 - val (loop_cmd, loop_matmul_unroller_busy) = LoopMatmul(raw_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, + val (loop_cmd, loop_matmul_unroller_busy, prefetch) = LoopMatmul(raw_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization, meshRows*tileRows, coreMaxAddrBits, rob_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, inputType.getWidth, accType.getWidth, dma_maxbytes) val unrolled_cmd = Queue(loop_cmd) @@ -303,6 +303,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data] ex_controller.io.acc.read <> spad.module.io.acc.read ex_controller.io.acc.write <> spad.module.io.acc.write + spad.module.io.prefetch <> prefetch + prefetch.ready := spad.module.io.prefetch.ready + // Im2Col unit val im2col = Module(new Im2Col(outer.config)) diff --git a/src/main/scala/gemmini/DMA.scala b/src/main/scala/gemmini/DMA.scala index c12b28ac..8eb424ab 100644 --- a/src/main/scala/gemmini/DMA.scala +++ b/src/main/scala/gemmini/DMA.scala @@ -6,8 +6,8 @@ import chisel3.experimental.DataMirror import freechips.rocketchip.config.Parameters import freechips.rocketchip.diplomacy.{IdRange, LazyModule, LazyModuleImp} -import freechips.rocketchip.tile.{CoreBundle, HasCoreParameters} -import freechips.rocketchip.tilelink.TLBundleA +import freechips.rocketchip.tile.{CoreBundle, HasCoreParameters, RoCCCommand} +import freechips.rocketchip.tilelink.{TLBundleA, TLHints, TLMessages} import testchipip.TLHelper import freechips.rocketchip.rocket.MStatus import freechips.rocketchip.rocket.constants.MemoryOpConstants @@ -58,6 +58,7 @@ class StreamReader[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T val io = IO(new Bundle { val req = Flipped(Decoupled(new StreamReadRequest(spad_rows, acc_rows, config.mvin_scale_t_bits))) val resp = Decoupled(new StreamReadResponse(spadWidth, accWidth, spad_rows, acc_rows, aligned_to, config.mvin_scale_t_bits)) + val prefetch = Flipped(Decoupled(new RoCCCommand)) val tlb = new FrontendTLBIO val busy = Output(Bool()) val flush = Input(Bool()) @@ -70,16 +71,17 @@ class StreamReader[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T val beatPacker = Module(new BeatMerger(beatBits, maxBytes, spadWidth, accWidth, spad_rows, acc_rows, maxBytes, aligned_to, meshRows, config.mvin_scale_t_bits, nCmds)) core.module.io.req <> io.req + core.module.io.prefetch <> io.prefetch io.tlb <> core.module.io.tlb io.busy := xactTracker.io.busy core.module.io.flush := io.flush xactTracker.io.alloc <> core.module.io.reserve - xactTracker.io.peek.xactid := RegEnableThru(core.module.io.beatData.bits.xactid, beatPacker.io.req.fire()) - xactTracker.io.peek.pop := beatPacker.io.in.fire() && core.module.io.beatData.bits.last + xactTracker.io.peek.xactid := RegEnableThru(core.module.io.beatData.bits.xactid, core.module.io.beatData.fire()) + xactTracker.io.peek.pop := core.module.io.beatData.fire() && core.module.io.beatData.bits.last - core.module.io.beatData.ready := beatPacker.io.in.ready - beatPacker.io.req.valid := core.module.io.beatData.valid + core.module.io.beatData.ready := beatPacker.io.in.ready || core.module.io.beatData.bits.is_hintack + beatPacker.io.req.valid := core.module.io.beatData.valid && !core.module.io.beatData.bits.is_hintack beatPacker.io.req.bits := xactTracker.io.peek.entry beatPacker.io.req.bits.lg_len_req := core.module.io.beatData.bits.lg_len_req beatPacker.io.in.valid := core.module.io.beatData.valid @@ -106,6 +108,7 @@ class StreamReadBeat (val nXacts: Int, val beatBits: Int, val maxReqBytes: Int) val data = UInt(beatBits.W) val lg_len_req = UInt(log2Up(log2Up(maxReqBytes+1)+1).W) val last = Bool() + val is_hintack = Bool() } // TODO StreamReaderCore and StreamWriter are actually very alike. Is there some parent class they could both inherit from? @@ -131,6 +134,7 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf val io = IO(new Bundle { val req = Flipped(Decoupled(new StreamReadRequest(spad_rows, acc_rows, config.mvin_scale_t_bits))) + val prefetch = Flipped(Decoupled(new RoCCCommand)) val reserve = new XactTrackerAllocIO(nXacts, maxBytes, spadWidth, accWidth, spad_rows, acc_rows, maxBytes, config.mvin_scale_t_bits, nCmds) val beatData = Decoupled(new StreamReadBeat(nXacts, beatBits, maxBytes)) val tlb = new FrontendTLBIO @@ -193,6 +197,13 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf lgSize = read_lg_size )._2 + val prefetch = edge.Hint( + fromSource = io.reserve.xactid, + toAddress = 0.U, + lgSize = 1.U, + param = TLHints.PREFETCH_READ + )._2 + class TLBundleAWithInfo extends Bundle { val tl_a = DataMirror.internal.chiselTypeClone[TLBundleA](tl.a.bits) val vaddr = Output(UInt(vaddrBits.W)) @@ -200,10 +211,39 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf } val untranslated_a = Wire(Decoupled(new TLBundleAWithInfo)) - untranslated_a.valid := state === s_req_new_block && io.reserve.ready - untranslated_a.bits.tl_a := get - untranslated_a.bits.vaddr := read_vaddr - untranslated_a.bits.status := req.status + untranslated_a.valid := false.B + untranslated_a.bits := DontCare + io.prefetch.ready := false.B + io.reserve.valid := false.B + when (state === s_req_new_block) { + io.reserve.valid := untranslated_a.ready + untranslated_a.valid := io.reserve.ready + untranslated_a.bits.tl_a := get + untranslated_a.bits.vaddr := read_vaddr + untranslated_a.bits.status := req.status + + when (untranslated_a.fire()) { + val next_vaddr = req.vaddr + read_bytes_read // send_size + val new_page = next_vaddr(pgIdxBits-1, 0) === 0.U + req.vaddr := next_vaddr + + bytesRequested := bytesRequested + read_bytes_read // send_size + + // when (send_size >= bytesLeft) { + when (read_bytes_read >= bytesLeft) { + // We're done with this request at this point + state_machine_ready_for_req := true.B + state := s_idle + } + } + } .elsewhen (io.prefetch.valid) { + io.reserve.valid := untranslated_a.ready + untranslated_a.valid := io.reserve.ready + io.prefetch.ready := untranslated_a.ready && io.reserve.ready + untranslated_a.bits.tl_a := prefetch + untranslated_a.bits.vaddr := io.prefetch.bits.rs1 + untranslated_a.bits.status := io.prefetch.bits.status + } // 0 goes to retries, 1 goes to state machine val retry_a = Wire(Decoupled(new TLBundleAWithInfo)) @@ -233,7 +273,6 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf tl.a.bits := translate_q.io.deq.bits.tl_a tl.a.bits.address := io.tlb.resp.paddr - io.reserve.valid := state === s_req_new_block && untranslated_a.ready // TODO decouple "reserve.valid" from "tl.a.ready" io.reserve.entry.shift := read_shift io.reserve.entry.is_acc := req.is_acc io.reserve.entry.accumulate := req.accumulate @@ -253,20 +292,6 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf if (bytesRequested.getWidth >= log2Up(spadWidthBytes+1)) bytesRequested / spadWidthBytes.U else 0.U) io.reserve.entry.spad_row_offset := Mux(req.has_acc_bitwidth, bytesRequested % accWidthBytes.U, bytesRequested % spadWidthBytes.U) - when (untranslated_a.fire()) { - val next_vaddr = req.vaddr + read_bytes_read // send_size - val new_page = next_vaddr(pgIdxBits-1, 0) === 0.U - req.vaddr := next_vaddr - - bytesRequested := bytesRequested + read_bytes_read // send_size - - // when (send_size >= bytesLeft) { - when (read_bytes_read >= bytesLeft) { - // We're done with this request at this point - state_machine_ready_for_req := true.B - state := s_idle - } - } // Forward TileLink read responses to the reservation buffer tl.d.ready := io.beatData.ready @@ -275,6 +300,7 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf io.beatData.bits.data := tl.d.bits.data io.beatData.bits.lg_len_req := tl.d.bits.size io.beatData.bits.last := edge.last(tl.d) + io.beatData.bits.is_hintack := tl.d.bits.opcode === TLMessages.HintAck // TODO the size data is already returned from TileLink, so there's no need for us to store it in the XactTracker ourselves // Accepting requests to kick-start the state machine diff --git a/src/main/scala/gemmini/LoopMatmul.scala b/src/main/scala/gemmini/LoopMatmul.scala index 5932a51a..f1d62679 100644 --- a/src/main/scala/gemmini/LoopMatmul.scala +++ b/src/main/scala/gemmini/LoopMatmul.scala @@ -28,6 +28,7 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulLdAReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops))) val cmd = Decoupled(Output(new RoCCCommand)) + val prefetch = Output(Valid(new RoCCCommand)) val i = Output(UInt(iterator_bitwidth.W)) val k = Output(UInt(iterator_bitwidth.W)) val idle = Output(Bool()) @@ -72,17 +73,31 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In mvin_cmd.rs1 := dram_addr mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr - io.req.ready := state === idle - io.i := i - io.k := k - io.idle := state === idle - - io.cmd.valid := state =/= idle && !io.rob_overloaded - io.cmd.bits := mvin_cmd - + class CmdQEntry extends Bundle { + val cmd = new RoCCCommand + val i = UInt() + val k = UInt() + } + val cmd_q = Module(new Queue(new CmdQEntry, 8)) + cmd_q.io.enq.valid := state =/= idle && !io.rob_overloaded + cmd_q.io.enq.bits.cmd := mvin_cmd + cmd_q.io.enq.bits.i := i + cmd_q.io.enq.bits.k := k + + io.cmd.valid := cmd_q.io.deq.valid + cmd_q.io.deq.ready := io.cmd.ready + io.cmd.bits := cmd_q.io.deq.bits.cmd + + io.prefetch.valid := cmd_q.io.enq.fire() + io.prefetch.bits := cmd_q.io.enq.bits.cmd + + io.req.ready := state === idle && !cmd_q.io.deq.valid + io.i := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.i, 0.U) + io.k := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.k, 0.U) io.loop_id := req.loop_id + io.idle := state === idle && !cmd_q.io.deq.valid - when (io.cmd.fire()) { + when (cmd_q.io.enq.fire()) { // The order here is k, j, i val i_blocks = Mux(req.transpose, max_blocks, 1.U) val k_blocks = Mux(req.transpose, 1.U, max_blocks) @@ -126,6 +141,7 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In val io = IO(new Bundle { val req = Flipped(Decoupled(new LoopMatmulLdBReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops))) val cmd = Decoupled(Output(new RoCCCommand)) + val prefetch = Output(Valid(new RoCCCommand)) val k = Output(UInt(iterator_bitwidth.W)) val j = Output(UInt(iterator_bitwidth.W)) @@ -173,17 +189,31 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In mvin_cmd.rs1 := dram_addr mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr - io.req.ready := state === idle - io.k := k - io.j := j - io.idle := state === idle - - io.cmd.valid := state =/= idle && !io.rob_overloaded - io.cmd.bits := mvin_cmd - + class CmdQEntry extends Bundle { + val cmd = new RoCCCommand + val k = UInt() + val j = UInt() + } + val cmd_q = Module(new Queue(new CmdQEntry, 8)) + cmd_q.io.enq.valid := state =/= idle && !io.rob_overloaded + cmd_q.io.enq.bits.cmd := mvin_cmd + cmd_q.io.enq.bits.k := k + cmd_q.io.enq.bits.j := j + + io.cmd.valid := cmd_q.io.deq.valid + cmd_q.io.deq.ready := io.cmd.ready + io.cmd.bits := cmd_q.io.deq.bits.cmd + + io.prefetch.valid := cmd_q.io.enq.fire() + io.prefetch.bits := cmd_q.io.enq.bits.cmd + + io.req.ready := state === idle && !cmd_q.io.deq.valid + io.k := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.k, 0.U) + io.j := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.j, 0.U) io.loop_id := req.loop_id + io.idle := state === idle && !cmd_q.io.deq.valid - when (io.cmd.fire()) { + when (cmd_q.io.enq.fire()) { // The order here is k, j, i val j_blocks = Mux(req.transpose, 1.U, max_blocks) val k_blocks = Mux(req.transpose, max_blocks, 1.U) @@ -606,6 +636,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: val io = IO(new Bundle { val in = Flipped(Decoupled(new RoCCCommand)) val out = Decoupled(new RoCCCommand) + val prefetch = Decoupled(new RoCCCommand) val ld_utilization = Input(UInt(log2Up(rob_size+1).W)) val st_utilization = Input(UInt(log2Up(rob_size+1).W)) val ex_utilization = Input(UInt(log2Up(rob_size+1).W)) @@ -645,6 +676,23 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: ldab_arb.io.forceA := !ab_loads_on_same_loop && ldA.io.loop_id === head_loop_id ldab_arb.io.forceB := !ab_loads_on_same_loop && ldB.io.loop_id === head_loop_id + val prefetch_arb = Module(new Arbiter(new RoCCCommand, 2)) + prefetch_arb.io.in(0).valid := ldA.io.prefetch.fire() + prefetch_arb.io.in(0).bits := ldA.io.prefetch.bits + prefetch_arb.io.in(1).valid := ldB.io.prefetch.fire() + prefetch_arb.io.in(1).bits := ldB.io.prefetch.bits + val prefetch_q_size = 4 + val prefetch_q = Module(new Queue(new RoCCCommand, prefetch_q_size, pipe=true)) + io.prefetch <> prefetch_q.io.deq + io.prefetch.bits.status := cmd.bits.status // TODO This is not guaranteed to be the correct fix! We must fix this + prefetch_q.io.enq <> prefetch_arb.io.out + when (prefetch_q.io.enq.valid && prefetch_q.io.count === prefetch_q_size.U) { + prefetch_q.io.deq.ready := true.B + } + + io.busy := cmd.valid || loop_configured + + // Create global arbiter val arb = Module(new Arbiter(new RoCCCommand(), 4)) arb.io.in(0) <> stC.io.cmd @@ -890,13 +938,13 @@ object LoopMatmul { def apply(in: DecoupledIO[RoCCCommand], ld_utilization: UInt, st_utilization: UInt, ex_utilization: UInt, block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: Int, max_exs: Int, max_sts: Int, max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int) - (implicit p: Parameters): Tuple2[DecoupledIO[RoCCCommand], Bool] = { + (implicit p: Parameters): (DecoupledIO[RoCCCommand], Bool, DecoupledIO[RoCCCommand]) = { val mod = Module(new LoopMatmul(block_size, coreMaxAddrBits, rob_size, max_lds, max_exs, max_sts, max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes)) mod.io.in <> in mod.io.ld_utilization := ld_utilization mod.io.st_utilization := st_utilization mod.io.ex_utilization := ex_utilization - (mod.io.out, mod.io.busy) + (mod.io.out, mod.io.busy, mod.io.prefetch) } } diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala index 2dbaf6c7..7c50e38c 100644 --- a/src/main/scala/gemmini/Scratchpad.scala +++ b/src/main/scala/gemmini/Scratchpad.scala @@ -196,6 +196,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // Misc. ports val busy = Output(Bool()) val flush = Input(Bool()) + + val prefetch = Flipped(Decoupled(new RoCCCommand)) }) val write_dispatch_q = Queue(io.dma.write.req) @@ -269,6 +271,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val (mvin_scale_in, mvin_scale_out) = VectorScalarMultiplier(config.mvin_scale_args, config.inputType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), is_acc = false) val (mvin_scale_acc_in, mvin_scale_acc_out) = if (mvin_scale_shared) (mvin_scale_in, mvin_scale_out) else VectorScalarMultiplier(config.mvin_scale_acc_args, config.accType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), is_acc = true) + reader.module.io.prefetch <> io.prefetch mvin_scale_in.valid := reader.module.io.resp.valid && (mvin_scale_shared.B || !reader.module.io.resp.bits.is_acc || (reader.module.io.resp.bits.is_acc && !reader.module.io.resp.bits.has_acc_bitwidth))