diff --git a/core/src/main/scala/chisel3/Mem.scala b/core/src/main/scala/chisel3/Mem.scala index caf6fb31f1e..7001997a266 100644 --- a/core/src/main/scala/chisel3/Mem.scala +++ b/core/src/main/scala/chisel3/Mem.scala @@ -193,8 +193,6 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf idx: UInt, writeData: T, mask: Seq[Bool] - )( - implicit evidence: T <:< Vec[_] ): Unit = macro SourceInfoTransform.idxDataMaskArg def do_write( @@ -202,8 +200,7 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf data: T, mask: Seq[Bool] )( - implicit evidence: T <:< Vec[_], - sourceInfo: SourceInfo + implicit sourceInfo: SourceInfo ): Unit = masked_write_impl(idx, data, mask, Builder.forcedClock, true) @@ -223,8 +220,6 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf writeData: T, mask: Seq[Bool], clock: Clock - )( - implicit evidence: T <:< Vec[_] ): Unit = macro SourceInfoTransform.idxDataMaskClockArg def do_write( @@ -233,8 +228,7 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf mask: Seq[Bool], clock: Clock )( - implicit evidence: T <:< Vec[_], - sourceInfo: SourceInfo + implicit sourceInfo: SourceInfo ): Unit = masked_write_impl(idx, data, mask, clock, false) @@ -245,21 +239,20 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf clock: Clock, warn: Boolean )( - implicit evidence: T <:< Vec[_], - sourceInfo: SourceInfo + implicit sourceInfo: SourceInfo ): Unit = { if (warn && clockInst.isDefined && clock != clockInst.get) { clockWarning(None, MemPortDirection.WRITE) } - val accessor = makePort(sourceInfo, idx, MemPortDirection.WRITE, clock).asInstanceOf[Vec[Data]] - val dataVec = data.asInstanceOf[Vec[Data]] - if (accessor.length != dataVec.length) { - Builder.error(s"Mem write data must contain ${accessor.length} elements (found ${dataVec.length})") + val accessorVec: Seq[Data] = seqFromData(makePort(sourceInfo, idx, MemPortDirection.WRITE, clock)) + val dataVec: Seq[Data] = seqFromData(data) + if (accessorVec.length != dataVec.length) { + Builder.error(s"Mem write data must contain ${accessorVec.length} elements (found ${dataVec.length})") } - if (accessor.length != mask.length) { - Builder.error(s"Mem write mask must contain ${accessor.length} elements (found ${mask.length})") + if (accessorVec.length != mask.length) { + Builder.error(s"Mem write mask must contain ${accessorVec.length} elements (found ${mask.length})") } - for (((cond, port), datum) <- mask.zip(accessor).zip(dataVec)) + for (((cond, port), datum) <- mask.zip(accessorVec).zip(dataVec)) when(cond) { port := datum } } @@ -284,6 +277,12 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf port.bind(MemoryPortBinding(Builder.forcedUserModule, Builder.currentWhen)) port } + + private[chisel3] def seqFromData(data: Data): Seq[Data] = data match { + case a: Record => a.getElements.flatMap(seqFromData) + case a: Vec[_] => a.allElements.flatMap(seqFromData) + case _ => Seq(data) + } } /** A combinational/asynchronous-read, sequential/synchronous-write memory. @@ -551,8 +550,6 @@ sealed class SyncReadMem[T <: Data] private[chisel3] ( mask: Seq[Bool], en: Bool, isWrite: Bool - )( - implicit evidence: T <:< Vec[_] ): T = macro SourceInfoTransform.idxDataMaskEnIswArg def do_readWrite( @@ -562,8 +559,7 @@ sealed class SyncReadMem[T <: Data] private[chisel3] ( en: Bool, isWrite: Bool )( - implicit evidence: T <:< Vec[_], - sourceInfo: SourceInfo + implicit sourceInfo: SourceInfo ): T = masked_readWrite_impl(idx, writeData, mask, en, isWrite, Builder.forcedClock, true) /** Generates an explicit read-write port for this SyncReadMem, with a bytemask for @@ -591,8 +587,6 @@ sealed class SyncReadMem[T <: Data] private[chisel3] ( en: Bool, isWrite: Bool, clock: Clock - )( - implicit evidence: T <:< Vec[_] ): T = macro SourceInfoTransform.idxDataMaskEnIswClockArg def do_readWrite( @@ -603,8 +597,7 @@ sealed class SyncReadMem[T <: Data] private[chisel3] ( isWrite: Bool, clock: Clock )( - implicit evidence: T <:< Vec[_], - sourceInfo: SourceInfo + implicit sourceInfo: SourceInfo ) = masked_readWrite_impl(idx, writeData, mask, en, isWrite, clock, false) private def masked_readWrite_impl( @@ -616,18 +609,19 @@ sealed class SyncReadMem[T <: Data] private[chisel3] ( clock: Clock, warn: Boolean )( - implicit evidence: T <:< Vec[_], - sourceInfo: SourceInfo + implicit sourceInfo: SourceInfo ): T = { var _port: Option[T] = None val _a = WireDefault(chiselTypeOf(addr), DontCare) when(enable) { _a := addr _port = Some(super.do_apply_impl(_a, clock, MemPortDirection.RDWR, warn)) - val accessor = _port.get.asInstanceOf[Vec[Data]] + val accessor = seqFromData(_port.get) when(isWrite) { - val dataVec = data.asInstanceOf[Vec[Data]] + val dataVec = seqFromData(data) + println(dataVec) + println(dataVec.size) if (accessor.length != dataVec.length) { Builder.error(s"Mem write data must contain ${accessor.length} elements (found ${dataVec.length})") } diff --git a/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala b/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala index f282b26c464..c50675c7d37 100644 --- a/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala +++ b/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala @@ -247,12 +247,12 @@ class SourceInfoTransform(val c: Context) extends AutoSourceTransform { q"$thisObj.$doFuncTerm($idx, $writeData, $en, $isWrite)($implicitSourceInfo)" } - def idxDataMaskArg(idx: c.Tree, writeData: c.Tree, mask: c.Tree)(evidence: c.Tree): c.Tree = { - q"$thisObj.$doFuncTerm($idx, $writeData, $mask)($evidence, $implicitSourceInfo)" + def idxDataMaskArg(idx: c.Tree, writeData: c.Tree, mask: c.Tree): c.Tree = { + q"$thisObj.$doFuncTerm($idx, $writeData, $mask)($implicitSourceInfo)" } - def idxDataMaskClockArg(idx: c.Tree, writeData: c.Tree, mask: c.Tree, clock: c.Tree)(evidence: c.Tree): c.Tree = { - q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $clock)($evidence, $implicitSourceInfo)" + def idxDataMaskClockArg(idx: c.Tree, writeData: c.Tree, mask: c.Tree, clock: c.Tree): c.Tree = { + q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $clock)($implicitSourceInfo)" } def idxDataEnIswClockArg(idx: c.Tree, writeData: c.Tree, en: c.Tree, isWrite: c.Tree, clock: c.Tree): c.Tree = { @@ -265,9 +265,8 @@ class SourceInfoTransform(val c: Context) extends AutoSourceTransform { mask: c.Tree, en: c.Tree, isWrite: c.Tree - )(evidence: c.Tree ): c.Tree = { - q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $en, $isWrite)($evidence, $implicitSourceInfo)" + q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $en, $isWrite)($implicitSourceInfo)" } def idxDataMaskEnIswClockArg( @@ -277,9 +276,8 @@ class SourceInfoTransform(val c: Context) extends AutoSourceTransform { en: c.Tree, isWrite: c.Tree, clock: c.Tree - )(evidence: c.Tree ): c.Tree = { - q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $en, $isWrite, $clock)($evidence, $implicitSourceInfo)" + q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $en, $isWrite, $clock)($implicitSourceInfo)" } def xEnArg(x: c.Tree, en: c.Tree): c.Tree = { diff --git a/src/main/scala/chisel3/util/SRAM.scala b/src/main/scala/chisel3/util/SRAM.scala index edcfb679971..e2a9fc73895 100644 --- a/src/main/scala/chisel3/util/SRAM.scala +++ b/src/main/scala/chisel3/util/SRAM.scala @@ -164,7 +164,6 @@ object SRAM { Seq.fill(numWritePorts)(clock), Seq.fill(numReadwritePorts)(clock), None, - None, sourceInfo ) } @@ -203,7 +202,6 @@ object SRAM { Seq.fill(numWritePorts)(clock), Seq.fill(numReadwritePorts)(clock), Some(memoryFile), - None, sourceInfo ) } @@ -239,7 +237,6 @@ object SRAM { writePortClocks, readwritePortClocks, None, - None, sourceInfo ) @@ -276,7 +273,6 @@ object SRAM { writePortClocks, readwritePortClocks, Some(memoryFile), - None, sourceInfo ) @@ -302,8 +298,7 @@ object SRAM { numWritePorts: Int, numReadwritePorts: Int )( - implicit evidence: T <:< Vec[_], - sourceInfo: SourceInfo + implicit sourceInfo: SourceInfo ): SRAMInterface[T] = { val clock = Builder.forcedClock memInterface_impl( @@ -313,7 +308,6 @@ object SRAM { Seq.fill(numWritePorts)(clock), Seq.fill(numReadwritePorts)(clock), None, - Some(evidence), sourceInfo ) } @@ -342,8 +336,7 @@ object SRAM { numReadwritePorts: Int, memoryFile: MemoryFile )( - implicit evidence: T <:< Vec[_], - sourceInfo: SourceInfo + implicit sourceInfo: SourceInfo ): SRAMInterface[T] = { val clock = Builder.forcedClock memInterface_impl( @@ -353,7 +346,6 @@ object SRAM { Seq.fill(numWritePorts)(clock), Seq.fill(numReadwritePorts)(clock), Some(memoryFile), - Some(evidence), sourceInfo ) } @@ -380,8 +372,7 @@ object SRAM { writePortClocks: Seq[Clock], readwritePortClocks: Seq[Clock] )( - implicit evidence: T <:< Vec[_], - sourceInfo: SourceInfo + implicit sourceInfo: SourceInfo ): SRAMInterface[T] = memInterface_impl( size, @@ -390,7 +381,6 @@ object SRAM { writePortClocks, readwritePortClocks, None, - Some(evidence), sourceInfo ) @@ -418,8 +408,7 @@ object SRAM { readwritePortClocks: Seq[Clock], memoryFile: MemoryFile )( - implicit evidence: T <:< Vec[_], - sourceInfo: SourceInfo + implicit sourceInfo: SourceInfo ): SRAMInterface[T] = memInterface_impl( size, @@ -428,7 +417,6 @@ object SRAM { writePortClocks, readwritePortClocks, Some(memoryFile), - Some(evidence), sourceInfo ) @@ -439,13 +427,11 @@ object SRAM { writePortClocks: Seq[Clock], readwritePortClocks: Seq[Clock], memoryFile: Option[MemoryFile], - evidenceOpt: Option[T <:< Vec[_]], sourceInfo: SourceInfo ): SRAMInterface[T] = { val numReadPorts = readPortClocks.size val numWritePorts = writePortClocks.size val numReadwritePorts = readwritePortClocks.size - val isVecMem = evidenceOpt.isDefined val isValidSRAM = ((numReadPorts + numReadwritePorts) > 0) && ((numWritePorts + numReadwritePorts) > 0) if (!isValidSRAM) { @@ -459,7 +445,7 @@ object SRAM { ) } - val _out = Wire(new SRAMInterface(size, tpe, numReadPorts, numWritePorts, numReadwritePorts, isVecMem)) + val _out = Wire(new SRAMInterface(size, tpe, numReadPorts, numWritePorts, numReadwritePorts, true)) val mem = SyncReadMem(size, tpe) for ((clock, port) <- readPortClocks.zip(_out.readPorts)) { @@ -468,40 +454,26 @@ object SRAM { for ((clock, port) <- writePortClocks.zip(_out.writePorts)) { when(port.enable) { - if (isVecMem) { - mem.write( - port.address, - port.data, - port.mask.get, - clock - )(evidenceOpt.get) - } else { - mem.write(port.address, port.data, clock) - } - } - } - - for ((clock, port) <- readwritePortClocks.zip(_out.readwritePorts)) { - if (isVecMem) { - port.readData := mem.readWrite( + mem.write( port.address, - port.writeData, + port.data, port.mask.get, - port.enable, - port.isWrite, - clock - )(evidenceOpt.get) - } else { - port.readData := mem.readWrite( - port.address, - port.writeData, - port.enable, - port.isWrite, clock ) } } + for ((clock, port) <- readwritePortClocks.zip(_out.readwritePorts)) { + port.readData := mem.readWrite( + port.address, + port.writeData, + port.mask.get, + port.enable, + port.isWrite, + clock + ) + } + // Emit Verilog for preloading the memory from a file if requested memoryFile.foreach { file: MemoryFile => loadMemoryFromFileInline(mem, file.path, file.fileType) } diff --git a/src/test/scala/chiselTests/Mem.scala b/src/test/scala/chiselTests/Mem.scala index a854a0195c7..a11251c5a84 100644 --- a/src/test/scala/chiselTests/Mem.scala +++ b/src/test/scala/chiselTests/Mem.scala @@ -446,6 +446,36 @@ class MemorySpec extends ChiselPropSpec { } ChiselStage.emitSystemVerilog(new TestModule) } + + property("Bundle-typed memory with masked writes should compile") { + class MyMemoryType extends Bundle { + val data = UInt(64.W) + val pNext = UInt(6.W) + val gNext = UInt(6.W) + val gPrev = UInt(6.W) + } + + class Foo extends Module { + val width: Int = 8 + val io = IO(new Bundle { + val enable = Input(Bool()) + val write = Input(Bool()) + val addr = Input(UInt(10.W)) + val mask = Input(Vec(4, Bool())) + val dataIn = Input(new MyMemoryType) + val dataOut = Output(new MyMemoryType) + }) + + // Create a 32-bit wide memory that is byte-masked + val mem = SyncReadMem(1024, new MyMemoryType) + // Write with mask + mem.write(io.addr, io.dataIn, io.mask) + io.dataOut := mem.read(io.addr, io.enable) + } + + ChiselStage.emitSystemVerilog(new Foo, firtoolOpts = Array("-disable-all-randomization", "-strip-debug-info")) + } + } class SRAMSpec extends ChiselFunSpec { @@ -609,4 +639,28 @@ class SRAMSpec extends ChiselFunSpec { } } } + + describe("SRAM with bundle ports and masked writes") { + it(s"should compile to Verilog") { + class Type extends Bundle { + val a = UInt(8.W) + val b = Vec(2, UInt(4.W)) + } + class TestModule extends Module { + val mem = SyncReadMem(32, new Type) + val writeData = Wire(new Type) + writeData.a := DontCare + writeData.b(0) := DontCare + writeData.b(1) := 15.U + mem.readWrite(4.U, writeData, Seq(false.B, false.B, true.B), true.B, true.B, Module.clock) + } + println( + ChiselStage.emitSystemVerilog( + new TestModule, + Array("--full-stacktrace", "--throw-on-first-error"), + firtoolOpts = Array("-disable-all-randomization", "-strip-debug-info", "-disable-opt") + ) + ) + } + } }