diff --git a/src/main/scala/gemmini/ExecuteController.scala b/src/main/scala/gemmini/ExecuteController.scala index 300dab16..e6ddd028 100644 --- a/src/main/scala/gemmini/ExecuteController.scala +++ b/src/main/scala/gemmini/ExecuteController.scala @@ -23,7 +23,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val srams = new Bundle { val read = Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width)) - val write = Vec(sp_banks, new ScratchpadWriteIO(sp_bank_entries, sp_width, (sp_width / (aligned_to * 8)) max 1)) + val write = Vec(sp_banks, Valid(new ScratchpadWriteBundle(sp_bank_entries, sp_width, (sp_width / (aligned_to * 8)) max 1))) } val acc = new Bundle { @@ -882,15 +882,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In }))) if (ex_write_to_spad) { - io.srams.write(i).en := start_array_outputting && w_bank === i.U && !write_to_acc && !is_garbage_addr && write_this_row - io.srams.write(i).addr := w_row - io.srams.write(i).data := activated_wdata.asUInt() - io.srams.write(i).mask := w_mask.flatMap(b => Seq.fill(inputType.getWidth / (aligned_to * 8))(b)) + io.srams.write(i).valid := start_array_outputting && w_bank === i.U && !write_to_acc && !is_garbage_addr && write_this_row + io.srams.write(i).bits.addr := w_row + io.srams.write(i).bits.data := activated_wdata.asUInt() + io.srams.write(i).bits.mask := w_mask.flatMap(b => Seq.fill(inputType.getWidth / (aligned_to * 8))(b)) } else { - io.srams.write(i).en := false.B - io.srams.write(i).addr := DontCare - io.srams.write(i).data := DontCare - io.srams.write(i).mask := DontCare + io.srams.write(i).valid := false.B + io.srams.write(i).bits := DontCare } } diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala index f4812cc6..467df24d 100644 --- a/src/main/scala/gemmini/Scratchpad.scala +++ b/src/main/scala/gemmini/Scratchpad.scala @@ -86,11 +86,10 @@ class ScratchpadReadIO(val n: Int, val w: Int) extends Bundle { val resp = Flipped(Decoupled(new ScratchpadReadResp(w))) } -class ScratchpadWriteIO(val n: Int, val w: Int, val mask_len: Int) extends Bundle { - val en = Output(Bool()) - val addr = Output(UInt(log2Ceil(n).W)) - val mask = Output(Vec(mask_len, Bool())) - val data = Output(UInt(w.W)) +class ScratchpadWriteBundle(val n: Int, val w: Int, val mask_len: Int) extends Bundle { + val addr = UInt(log2Ceil(n).W) + val mask = Vec(mask_len, Bool()) + val data = UInt(w.W) } class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int, single_ported: Boolean) extends Module { @@ -101,27 +100,39 @@ class ScratchpadBank(n: Int, w: Int, mem_pipeline: Int, aligned_to: Int, single_ val mask_elem = UInt((w min (aligned_to * 8)).W) // What datatype does each mask bit correspond to? val io = IO(new Bundle { + // Priority for single port is valid_write > read > decoupled_write val read = Flipped(new ScratchpadReadIO(n, w)) - val write = Flipped(new ScratchpadWriteIO(n, w, mask_len)) + val decoupled_write = Flipped(Decoupled(new ScratchpadWriteBundle(n, w, mask_len))) + val valid_write = Flipped(Valid(new ScratchpadWriteBundle(n, w, mask_len))) }) val mem = SyncReadMem(n, Vec(mask_len, mask_elem)) // When the scratchpad is single-ported, the writes take precedence - val singleport_busy_with_write = single_ported.B && io.write.en + val singleport_busy_with_write = single_ported.B && io.valid_write.valid + + val wen = io.valid_write.valid || (io.decoupled_write.valid && (!single_ported.B || !io.read.req.valid)) + val wr = Mux(io.valid_write.valid, io.valid_write.bits, io.decoupled_write.bits) + io.decoupled_write.ready := !io.valid_write.valid && (!single_ported.B || !io.read.req.valid) - when (io.write.en) { - if (aligned_to >= w) - mem.write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem))) - else - mem.write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem)), io.write.mask) + when (wen) { + val mask = if (aligned_to >= w) { + VecInit((~(0.U(mask_len.W))).asBools) + } else { + wr.mask + } + mem.write( + wr.addr, + wr.data.asTypeOf(Vec(mask_len, mask_elem)), + mask + ) } val raddr = io.read.req.bits.addr val ren = io.read.req.fire() val rdata = if (single_ported) { - assert(!(ren && io.write.en)) - mem.read(raddr, ren && !io.write.en).asUInt() + assert(!(ren && io.valid_write.valid)) + mem.read(raddr, ren && !wen).asUInt() } else { mem.read(raddr, ren).asUInt() } @@ -184,7 +195,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // SRAM ports val srams = new Bundle { val read = Flipped(Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, spad_w))) - val write = Flipped(Vec(sp_banks, new ScratchpadWriteIO(sp_bank_entries, spad_w, (spad_w / (aligned_to * 8)) max 1))) + val write = Flipped(Vec(sp_banks, Valid(new ScratchpadWriteBundle(sp_bank_entries, spad_w, (spad_w / (aligned_to * 8)) max 1)))) } // Accumulator ports @@ -382,9 +393,9 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val exread = ex_read_req.valid // TODO we tie the write dispatch queue's, and write issue queue's, ready and valid signals together here - val dmawrite = write_dispatch_q.valid && write_scale_q.io.enq.ready && + val dmawrite = write_dispatch_q.valid && write_issue_q.io.enq.ready && + bio.read.req.ready && !write_dispatch_q.bits.laddr.is_garbage() && - !(bio.write.en && config.sp_singleported.B) && !write_dispatch_q.bits.laddr.is_acc_addr && write_dispatch_q.bits.laddr.sp_bank() === i.U bio.read.req.valid := exread || dmawrite @@ -420,7 +431,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // Writing to the SRAM banks bank_ios.zipWithIndex.foreach { case (bio, i) => - val exwrite = io.srams.write(i).en + val exwrite = io.srams.write(i) val laddr = mvin_scale_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_out.bits.row @@ -433,33 +444,27 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, zero_writer.io.resp.bits.laddr.sp_bank() === i.U && !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) - bio.write.en := exwrite || dmaread || zerowrite + bio.valid_write := exwrite - when (exwrite) { - bio.write.addr := io.srams.write(i).addr - bio.write.data := io.srams.write(i).data - bio.write.mask := io.srams.write(i).mask - }.elsewhen (dmaread) { - bio.write.addr := laddr.sp_row() - bio.write.data := mvin_scale_out.bits.out.asUInt() - bio.write.mask := mvin_scale_out.bits.tag.mask take ((spad_w / (aligned_to * 8)) max 1) - - mvin_scale_out.ready := true.B // TODO we combinationally couple valid and ready signals + bio.decoupled_write.valid := dmaread || zerowrite + bio.decoupled_write.bits := DontCare + when (dmaread) { + bio.decoupled_write.bits.addr := laddr.sp_row() + bio.decoupled_write.bits.data := mvin_scale_out.bits.out.asUInt() + bio.decoupled_write.bits.mask := mvin_scale_out.bits.tag.mask take ((spad_w / (aligned_to * 8)) max 1) + + mvin_scale_out.ready := bio.decoupled_write.ready }.elsewhen (zerowrite) { - bio.write.addr := zero_writer.io.resp.bits.laddr.sp_row() - bio.write.data := 0.U - bio.write.mask := { + bio.decoupled_write.bits.addr := zero_writer.io.resp.bits.laddr.sp_row() + bio.decoupled_write.bits.data := 0.U + bio.decoupled_write.bits.mask := { val n = inputType.getWidth / 8 val mask = zero_writer.io.resp.bits.mask val expanded = VecInit(mask.flatMap(e => Seq.fill(n)(e))) expanded } - zero_writer.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals - }.otherwise { - bio.write.addr := DontCare - bio.write.data := DontCare - bio.write.mask := DontCare + zero_writer.io.resp.ready := bio.decoupled_write.ready } } }