Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make LoopMatmul FSM prefetch ahead #75

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions src/main/scala/gemmini/Controller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,15 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
val rob = Module(new ROB(outer.config, new RoCCCommand))

val raw_cmd = Queue(io.cmd)
val max_lds = rob_entries * 1 / 4
val max_exs = rob_entries * 3 / 4
val max_sts = rob_entries * 1 / 8

// TODO replace 4,12,2 with parameters based on ROB size
val (conv_cmd, loop_conv_unroller_busy) = LoopConv(raw_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization,
meshRows*tileRows, coreMaxAddrBits, rob_entries, 4, 12, 2, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
meshRows*tileRows, coreMaxAddrBits, rob_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes)

// val (compressed_cmd, compressor_busy) = InstCompressor(unrolled_cmd)
// compressed_cmd.ready := false.B

// val (unrolled_cmd, loop_matmul_unroller_busy) = LoopMatmul(unrolled_cmd_after_conv, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization,
val max_lds = rob_entries * 1 / 4
val max_exs = rob_entries * 3 / 4
val max_sts = rob_entries * 1 / 8
val (loop_cmd, loop_matmul_unroller_busy) = LoopMatmul(conv_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization,
val (loop_cmd, loop_matmul_unroller_busy, prefetch) = LoopMatmul(conv_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization,
meshRows*tileRows, coreMaxAddrBits, rob_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes)

Expand Down Expand Up @@ -228,6 +223,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
ex_controller.io.acc.read_resp <> spad.module.io.acc.read_resp
ex_controller.io.acc.write <> spad.module.io.acc.write

spad.module.io.prefetch <> prefetch
prefetch.ready := spad.module.io.prefetch.ready

// Im2Col unit
val im2col = Module(new Im2Col(outer.config))

Expand Down
76 changes: 51 additions & 25 deletions src/main/scala/gemmini/DMA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import chisel3.experimental.DataMirror

import freechips.rocketchip.config.Parameters
import freechips.rocketchip.diplomacy.{IdRange, LazyModule, LazyModuleImp}
import freechips.rocketchip.tile.{CoreBundle, HasCoreParameters}
import freechips.rocketchip.tilelink.TLBundleA
import freechips.rocketchip.tile.{CoreBundle, HasCoreParameters, RoCCCommand}
import freechips.rocketchip.tilelink.{TLBundleA, TLHints, TLMessages}
import testchipip.TLHelper
import freechips.rocketchip.rocket.MStatus
import freechips.rocketchip.rocket.constants.MemoryOpConstants
Expand Down Expand Up @@ -58,6 +58,7 @@ class StreamReader[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T
val io = IO(new Bundle {
val req = Flipped(Decoupled(new StreamReadRequest(spad_rows, acc_rows, config.mvin_scale_t_bits)))
val resp = Decoupled(new StreamReadResponse(spadWidth, accWidth, spad_rows, acc_rows, aligned_to, config.mvin_scale_t_bits))
val prefetch = Flipped(Decoupled(new RoCCCommand))
val tlb = new FrontendTLBIO
val busy = Output(Bool())
val flush = Input(Bool())
Expand All @@ -70,16 +71,17 @@ class StreamReader[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T
val beatPacker = Module(new BeatMerger(beatBits, maxBytes, spadWidth, accWidth, spad_rows, acc_rows, maxBytes, aligned_to, meshRows, config.mvin_scale_t_bits, nCmds))

core.module.io.req <> io.req
core.module.io.prefetch <> io.prefetch
io.tlb <> core.module.io.tlb
io.busy := xactTracker.io.busy
core.module.io.flush := io.flush

xactTracker.io.alloc <> core.module.io.reserve
xactTracker.io.peek.xactid := RegEnableThru(core.module.io.beatData.bits.xactid, beatPacker.io.req.fire())
xactTracker.io.peek.pop := beatPacker.io.in.fire() && core.module.io.beatData.bits.last
xactTracker.io.peek.xactid := RegEnableThru(core.module.io.beatData.bits.xactid, core.module.io.beatData.fire())
xactTracker.io.peek.pop := core.module.io.beatData.fire() && core.module.io.beatData.bits.last

core.module.io.beatData.ready := beatPacker.io.in.ready
beatPacker.io.req.valid := core.module.io.beatData.valid
core.module.io.beatData.ready := beatPacker.io.in.ready || core.module.io.beatData.bits.is_hintack
beatPacker.io.req.valid := core.module.io.beatData.valid && !core.module.io.beatData.bits.is_hintack
beatPacker.io.req.bits := xactTracker.io.peek.entry
beatPacker.io.req.bits.lg_len_req := core.module.io.beatData.bits.lg_len_req
beatPacker.io.in.valid := core.module.io.beatData.valid
Expand All @@ -106,6 +108,7 @@ class StreamReadBeat (val nXacts: Int, val beatBits: Int, val maxReqBytes: Int)
val data = UInt(beatBits.W)
val lg_len_req = UInt(log2Up(log2Up(maxReqBytes+1)+1).W)
val last = Bool()
val is_hintack = Bool()
}

// TODO StreamReaderCore and StreamWriter are actually very alike. Is there some parent class they could both inherit from?
Expand All @@ -131,6 +134,7 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf

val io = IO(new Bundle {
val req = Flipped(Decoupled(new StreamReadRequest(spad_rows, acc_rows, config.mvin_scale_t_bits)))
val prefetch = Flipped(Decoupled(new RoCCCommand))
val reserve = new XactTrackerAllocIO(nXacts, maxBytes, spadWidth, accWidth, spad_rows, acc_rows, maxBytes, config.mvin_scale_t_bits, nCmds)
val beatData = Decoupled(new StreamReadBeat(nXacts, beatBits, maxBytes))
val tlb = new FrontendTLBIO
Expand Down Expand Up @@ -193,17 +197,53 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf
lgSize = read_lg_size
)._2

val prefetch = edge.Hint(
fromSource = io.reserve.xactid,
toAddress = 0.U,
lgSize = 1.U,
param = TLHints.PREFETCH_READ
)._2

class TLBundleAWithInfo extends Bundle {
val tl_a = DataMirror.internal.chiselTypeClone[TLBundleA](tl.a.bits)
val vaddr = Output(UInt(vaddrBits.W))
val status = Output(new MStatus)
}

val untranslated_a = Wire(Decoupled(new TLBundleAWithInfo))
untranslated_a.valid := state === s_req_new_block && io.reserve.ready
untranslated_a.bits.tl_a := get
untranslated_a.bits.vaddr := read_vaddr
untranslated_a.bits.status := req.status
untranslated_a.valid := false.B
untranslated_a.bits := DontCare
io.prefetch.ready := false.B
io.reserve.valid := false.B
when (state === s_req_new_block) {
io.reserve.valid := untranslated_a.ready
untranslated_a.valid := io.reserve.ready
untranslated_a.bits.tl_a := get
untranslated_a.bits.vaddr := read_vaddr
untranslated_a.bits.status := req.status

when (untranslated_a.fire()) {
val next_vaddr = req.vaddr + read_bytes_read // send_size
val new_page = next_vaddr(pgIdxBits-1, 0) === 0.U
req.vaddr := next_vaddr

bytesRequested := bytesRequested + read_bytes_read // send_size

// when (send_size >= bytesLeft) {
when (read_bytes_read >= bytesLeft) {
// We're done with this request at this point
state_machine_ready_for_req := true.B
state := s_idle
}
}
} .elsewhen (io.prefetch.valid) {
io.reserve.valid := untranslated_a.ready
untranslated_a.valid := io.reserve.ready
io.prefetch.ready := untranslated_a.ready && io.reserve.ready
untranslated_a.bits.tl_a := prefetch
untranslated_a.bits.vaddr := io.prefetch.bits.rs1
untranslated_a.bits.status := io.prefetch.bits.status
}

// 0 goes to retries, 1 goes to state machine
val retry_a = Wire(Decoupled(new TLBundleAWithInfo))
Expand Down Expand Up @@ -233,7 +273,6 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf
tl.a.bits := translate_q.io.deq.bits.tl_a
tl.a.bits.address := io.tlb.resp.paddr

io.reserve.valid := state === s_req_new_block && untranslated_a.ready // TODO decouple "reserve.valid" from "tl.a.ready"
io.reserve.entry.shift := read_shift
io.reserve.entry.is_acc := req.is_acc
io.reserve.entry.accumulate := req.accumulate
Expand All @@ -253,20 +292,6 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf
if (bytesRequested.getWidth >= log2Up(spadWidthBytes+1)) bytesRequested / spadWidthBytes.U else 0.U)
io.reserve.entry.spad_row_offset := Mux(req.has_acc_bitwidth, bytesRequested % accWidthBytes.U, bytesRequested % spadWidthBytes.U)

when (untranslated_a.fire()) {
val next_vaddr = req.vaddr + read_bytes_read // send_size
val new_page = next_vaddr(pgIdxBits-1, 0) === 0.U
req.vaddr := next_vaddr

bytesRequested := bytesRequested + read_bytes_read // send_size

// when (send_size >= bytesLeft) {
when (read_bytes_read >= bytesLeft) {
// We're done with this request at this point
state_machine_ready_for_req := true.B
state := s_idle
}
}

// Forward TileLink read responses to the reservation buffer
tl.d.ready := io.beatData.ready
Expand All @@ -275,6 +300,7 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf
io.beatData.bits.data := tl.d.bits.data
io.beatData.bits.lg_len_req := tl.d.bits.size
io.beatData.bits.last := edge.last(tl.d)
io.beatData.bits.is_hintack := tl.d.bits.opcode === TLMessages.HintAck
// TODO the size data is already returned from TileLink, so there's no need for us to store it in the XactTracker ourselves

// Accepting requests to kick-start the state machine
Expand Down
88 changes: 68 additions & 20 deletions src/main/scala/gemmini/LoopMatmul.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
val io = IO(new Bundle {
val req = Flipped(Decoupled(new LoopMatmulLdAReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops)))
val cmd = Decoupled(Output(new RoCCCommand))
val prefetch = Output(Valid(new RoCCCommand))
val i = Output(UInt(iterator_bitwidth.W))
val k = Output(UInt(iterator_bitwidth.W))
val idle = Output(Bool())
Expand Down Expand Up @@ -72,17 +73,31 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
mvin_cmd.rs1 := dram_addr
mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr

io.req.ready := state === idle
io.i := i
io.k := k
io.idle := state === idle

io.cmd.valid := state =/= idle && !io.rob_overloaded
io.cmd.bits := mvin_cmd

class CmdQEntry extends Bundle {
val cmd = new RoCCCommand
val i = UInt()
val k = UInt()
}
val cmd_q = Module(new Queue(new CmdQEntry, 8))
cmd_q.io.enq.valid := state =/= idle && !io.rob_overloaded
cmd_q.io.enq.bits.cmd := mvin_cmd
cmd_q.io.enq.bits.i := i
cmd_q.io.enq.bits.k := k

io.cmd.valid := cmd_q.io.deq.valid
cmd_q.io.deq.ready := io.cmd.ready
io.cmd.bits := cmd_q.io.deq.bits.cmd

io.prefetch.valid := cmd_q.io.enq.fire()
io.prefetch.bits := cmd_q.io.enq.bits.cmd

io.req.ready := state === idle && !cmd_q.io.deq.valid
io.i := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.i, 0.U)
io.k := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.k, 0.U)
io.loop_id := req.loop_id
io.idle := state === idle && !cmd_q.io.deq.valid

when (io.cmd.fire()) {
when (cmd_q.io.enq.fire()) {
// The order here is k, j, i
val i_blocks = Mux(req.transpose, max_blocks, 1.U)
val k_blocks = Mux(req.transpose, 1.U, max_blocks)
Expand Down Expand Up @@ -126,6 +141,7 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
val io = IO(new Bundle {
val req = Flipped(Decoupled(new LoopMatmulLdBReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops)))
val cmd = Decoupled(Output(new RoCCCommand))
val prefetch = Output(Valid(new RoCCCommand))

val k = Output(UInt(iterator_bitwidth.W))
val j = Output(UInt(iterator_bitwidth.W))
Expand Down Expand Up @@ -173,17 +189,31 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
mvin_cmd.rs1 := dram_addr
mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr

io.req.ready := state === idle
io.k := k
io.j := j
io.idle := state === idle

io.cmd.valid := state =/= idle && !io.rob_overloaded
io.cmd.bits := mvin_cmd

class CmdQEntry extends Bundle {
val cmd = new RoCCCommand
val k = UInt()
val j = UInt()
}
val cmd_q = Module(new Queue(new CmdQEntry, 8))
cmd_q.io.enq.valid := state =/= idle && !io.rob_overloaded
cmd_q.io.enq.bits.cmd := mvin_cmd
cmd_q.io.enq.bits.k := k
cmd_q.io.enq.bits.j := j

io.cmd.valid := cmd_q.io.deq.valid
cmd_q.io.deq.ready := io.cmd.ready
io.cmd.bits := cmd_q.io.deq.bits.cmd

io.prefetch.valid := cmd_q.io.enq.fire()
io.prefetch.bits := cmd_q.io.enq.bits.cmd

io.req.ready := state === idle && !cmd_q.io.deq.valid
io.k := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.k, 0.U)
io.j := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.j, 0.U)
io.loop_id := req.loop_id
io.idle := state === idle && !cmd_q.io.deq.valid

when (io.cmd.fire()) {
when (cmd_q.io.enq.fire()) {
// The order here is k, j, i
val j_blocks = Mux(req.transpose, 1.U, max_blocks)
val k_blocks = Mux(req.transpose, max_blocks, 1.U)
Expand Down Expand Up @@ -613,6 +643,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds:
val io = IO(new Bundle {
val in = Flipped(Decoupled(new RoCCCommand))
val out = Decoupled(new RoCCCommand)
val prefetch = Decoupled(new RoCCCommand)
val ld_utilization = Input(UInt(log2Up(rob_size+1).W))
val st_utilization = Input(UInt(log2Up(rob_size+1).W))
val ex_utilization = Input(UInt(log2Up(rob_size+1).W))
Expand Down Expand Up @@ -652,6 +683,23 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds:
ldab_arb.io.forceA := !ab_loads_on_same_loop && ldA.io.loop_id === head_loop_id
ldab_arb.io.forceB := !ab_loads_on_same_loop && ldB.io.loop_id === head_loop_id

val prefetch_arb = Module(new Arbiter(new RoCCCommand, 2))
prefetch_arb.io.in(0).valid := ldA.io.prefetch.fire()
prefetch_arb.io.in(0).bits := ldA.io.prefetch.bits
prefetch_arb.io.in(1).valid := ldB.io.prefetch.fire()
prefetch_arb.io.in(1).bits := ldB.io.prefetch.bits
val prefetch_q_size = 4
val prefetch_q = Module(new Queue(new RoCCCommand, prefetch_q_size, pipe=true))
io.prefetch <> prefetch_q.io.deq
io.prefetch.bits.status := cmd.bits.status // TODO This is not guaranteed to be the correct fix! We must fix this
prefetch_q.io.enq <> prefetch_arb.io.out
when (prefetch_q.io.enq.valid && prefetch_q.io.count === prefetch_q_size.U) {
prefetch_q.io.deq.ready := true.B
}

io.busy := cmd.valid || loop_configured


// Create global arbiter
val arb = Module(new Arbiter(new RoCCCommand(), 4))
arb.io.in(0) <> stC.io.cmd
Expand Down Expand Up @@ -902,13 +950,13 @@ object LoopMatmul {
def apply(in: DecoupledIO[RoCCCommand], ld_utilization: UInt, st_utilization: UInt, ex_utilization: UInt,
block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: Int, max_exs: Int, max_sts: Int,
max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int)
(implicit p: Parameters): Tuple2[DecoupledIO[RoCCCommand], Bool] = {
(implicit p: Parameters): (DecoupledIO[RoCCCommand], Bool, DecoupledIO[RoCCCommand]) = {
val mod = Module(new LoopMatmul(block_size, coreMaxAddrBits, rob_size, max_lds, max_exs, max_sts,
max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes))
mod.io.in <> in
mod.io.ld_utilization := ld_utilization
mod.io.st_utilization := st_utilization
mod.io.ex_utilization := ex_utilization
(mod.io.out, mod.io.busy)
(mod.io.out, mod.io.busy, mod.io.prefetch)
}
}
3 changes: 3 additions & 0 deletions src/main/scala/gemmini/Scratchpad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
// Misc. ports
val busy = Output(Bool())
val flush = Input(Bool())

val prefetch = Flipped(Decoupled(new RoCCCommand))
})

val write_dispatch_q = Queue(io.dma.write.req)
Expand Down Expand Up @@ -297,6 +299,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
reader.module.io.req.bits.block_stride := read_issue_q.io.deq.bits.block_stride
reader.module.io.req.bits.status := read_issue_q.io.deq.bits.status
reader.module.io.req.bits.cmd_id := read_issue_q.io.deq.bits.cmd_id
reader.module.io.prefetch <> io.prefetch

val (mvin_scale_in, mvin_scale_out) = VectorScalarMultiplier(
config.mvin_scale_args,
Expand Down