Skip to content

Commit

Permalink
Try removing memory evidence of Vec
Browse files Browse the repository at this point in the history
Signed-off-by: Schuyler Eldridge <[email protected]>
  • Loading branch information
seldridge committed Aug 1, 2023
1 parent 28e235e commit d8e5bfb
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 83 deletions.
50 changes: 21 additions & 29 deletions core/src/main/scala/chisel3/Mem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,14 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf
idx: UInt,
writeData: T,
mask: Seq[Bool]
)(
implicit evidence: T <:< Vec[_]
): Unit = macro SourceInfoTransform.idxDataMaskArg

def do_write(
idx: UInt,
data: T,
mask: Seq[Bool]
)(
implicit evidence: T <:< Vec[_],
sourceInfo: SourceInfo
implicit sourceInfo: SourceInfo
): Unit =
masked_write_impl(idx, data, mask, Builder.forcedClock, true)

Expand All @@ -223,8 +220,6 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf
writeData: T,
mask: Seq[Bool],
clock: Clock
)(
implicit evidence: T <:< Vec[_]
): Unit = macro SourceInfoTransform.idxDataMaskClockArg

def do_write(
Expand All @@ -233,8 +228,7 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf
mask: Seq[Bool],
clock: Clock
)(
implicit evidence: T <:< Vec[_],
sourceInfo: SourceInfo
implicit sourceInfo: SourceInfo
): Unit =
masked_write_impl(idx, data, mask, clock, false)

Expand All @@ -245,21 +239,20 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf
clock: Clock,
warn: Boolean
)(
implicit evidence: T <:< Vec[_],
sourceInfo: SourceInfo
implicit sourceInfo: SourceInfo
): Unit = {
if (warn && clockInst.isDefined && clock != clockInst.get) {
clockWarning(None, MemPortDirection.WRITE)
}
val accessor = makePort(sourceInfo, idx, MemPortDirection.WRITE, clock).asInstanceOf[Vec[Data]]
val dataVec = data.asInstanceOf[Vec[Data]]
if (accessor.length != dataVec.length) {
Builder.error(s"Mem write data must contain ${accessor.length} elements (found ${dataVec.length})")
val accessorVec: Seq[Data] = seqFromData(makePort(sourceInfo, idx, MemPortDirection.WRITE, clock))
val dataVec: Seq[Data] = seqFromData(data)
if (accessorVec.length != dataVec.length) {
Builder.error(s"Mem write data must contain ${accessorVec.length} elements (found ${dataVec.length})")
}
if (accessor.length != mask.length) {
Builder.error(s"Mem write mask must contain ${accessor.length} elements (found ${mask.length})")
if (accessorVec.length != mask.length) {
Builder.error(s"Mem write mask must contain ${accessorVec.length} elements (found ${mask.length})")
}
for (((cond, port), datum) <- mask.zip(accessor).zip(dataVec))
for (((cond, port), datum) <- mask.zip(accessorVec).zip(dataVec))
when(cond) { port := datum }
}

Expand All @@ -284,6 +277,12 @@ sealed abstract class MemBase[T <: Data](val t: T, val length: BigInt, sourceInf
port.bind(MemoryPortBinding(Builder.forcedUserModule, Builder.currentWhen))
port
}

private[chisel3] def seqFromData(data: Data): Seq[Data] = data match {
case a: Record => a.getElements.flatMap(seqFromData)
case a: Vec[_] => a.allElements.flatMap(seqFromData)
case _ => Seq(data)
}
}

/** A combinational/asynchronous-read, sequential/synchronous-write memory.
Expand Down Expand Up @@ -551,8 +550,6 @@ sealed class SyncReadMem[T <: Data] private[chisel3] (
mask: Seq[Bool],
en: Bool,
isWrite: Bool
)(
implicit evidence: T <:< Vec[_]
): T = macro SourceInfoTransform.idxDataMaskEnIswArg

def do_readWrite(
Expand All @@ -562,8 +559,7 @@ sealed class SyncReadMem[T <: Data] private[chisel3] (
en: Bool,
isWrite: Bool
)(
implicit evidence: T <:< Vec[_],
sourceInfo: SourceInfo
implicit sourceInfo: SourceInfo
): T = masked_readWrite_impl(idx, writeData, mask, en, isWrite, Builder.forcedClock, true)

/** Generates an explicit read-write port for this SyncReadMem, with a bytemask for
Expand Down Expand Up @@ -591,8 +587,6 @@ sealed class SyncReadMem[T <: Data] private[chisel3] (
en: Bool,
isWrite: Bool,
clock: Clock
)(
implicit evidence: T <:< Vec[_]
): T = macro SourceInfoTransform.idxDataMaskEnIswClockArg

def do_readWrite(
Expand All @@ -603,8 +597,7 @@ sealed class SyncReadMem[T <: Data] private[chisel3] (
isWrite: Bool,
clock: Clock
)(
implicit evidence: T <:< Vec[_],
sourceInfo: SourceInfo
implicit sourceInfo: SourceInfo
) = masked_readWrite_impl(idx, writeData, mask, en, isWrite, clock, false)

private def masked_readWrite_impl(
Expand All @@ -616,18 +609,17 @@ sealed class SyncReadMem[T <: Data] private[chisel3] (
clock: Clock,
warn: Boolean
)(
implicit evidence: T <:< Vec[_],
sourceInfo: SourceInfo
implicit sourceInfo: SourceInfo
): T = {
var _port: Option[T] = None
val _a = WireDefault(chiselTypeOf(addr), DontCare)
when(enable) {
_a := addr
_port = Some(super.do_apply_impl(_a, clock, MemPortDirection.RDWR, warn))
val accessor = _port.get.asInstanceOf[Vec[Data]]
val accessor = seqFromData(_port.get)

when(isWrite) {
val dataVec = data.asInstanceOf[Vec[Data]]
val dataVec = seqFromData(data)
if (accessor.length != dataVec.length) {
Builder.error(s"Mem write data must contain ${accessor.length} elements (found ${dataVec.length})")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,12 @@ class SourceInfoTransform(val c: Context) extends AutoSourceTransform {
q"$thisObj.$doFuncTerm($idx, $writeData, $en, $isWrite)($implicitSourceInfo)"
}

def idxDataMaskArg(idx: c.Tree, writeData: c.Tree, mask: c.Tree)(evidence: c.Tree): c.Tree = {
q"$thisObj.$doFuncTerm($idx, $writeData, $mask)($evidence, $implicitSourceInfo)"
def idxDataMaskArg(idx: c.Tree, writeData: c.Tree, mask: c.Tree): c.Tree = {
q"$thisObj.$doFuncTerm($idx, $writeData, $mask)($implicitSourceInfo)"
}

def idxDataMaskClockArg(idx: c.Tree, writeData: c.Tree, mask: c.Tree, clock: c.Tree)(evidence: c.Tree): c.Tree = {
q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $clock)($evidence, $implicitSourceInfo)"
def idxDataMaskClockArg(idx: c.Tree, writeData: c.Tree, mask: c.Tree, clock: c.Tree): c.Tree = {
q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $clock)($implicitSourceInfo)"
}

def idxDataEnIswClockArg(idx: c.Tree, writeData: c.Tree, en: c.Tree, isWrite: c.Tree, clock: c.Tree): c.Tree = {
Expand All @@ -265,9 +265,8 @@ class SourceInfoTransform(val c: Context) extends AutoSourceTransform {
mask: c.Tree,
en: c.Tree,
isWrite: c.Tree
)(evidence: c.Tree
): c.Tree = {
q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $en, $isWrite)($evidence, $implicitSourceInfo)"
q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $en, $isWrite)($implicitSourceInfo)"
}

def idxDataMaskEnIswClockArg(
Expand All @@ -277,9 +276,8 @@ class SourceInfoTransform(val c: Context) extends AutoSourceTransform {
en: c.Tree,
isWrite: c.Tree,
clock: c.Tree
)(evidence: c.Tree
): c.Tree = {
q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $en, $isWrite, $clock)($evidence, $implicitSourceInfo)"
q"$thisObj.$doFuncTerm($idx, $writeData, $mask, $en, $isWrite, $clock)($implicitSourceInfo)"
}

def xEnArg(x: c.Tree, en: c.Tree): c.Tree = {
Expand Down
64 changes: 18 additions & 46 deletions src/main/scala/chisel3/util/SRAM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ object SRAM {
Seq.fill(numWritePorts)(clock),
Seq.fill(numReadwritePorts)(clock),
None,
None,
sourceInfo
)
}
Expand Down Expand Up @@ -203,7 +202,6 @@ object SRAM {
Seq.fill(numWritePorts)(clock),
Seq.fill(numReadwritePorts)(clock),
Some(memoryFile),
None,
sourceInfo
)
}
Expand Down Expand Up @@ -239,7 +237,6 @@ object SRAM {
writePortClocks,
readwritePortClocks,
None,
None,
sourceInfo
)

Expand Down Expand Up @@ -276,7 +273,6 @@ object SRAM {
writePortClocks,
readwritePortClocks,
Some(memoryFile),
None,
sourceInfo
)

Expand All @@ -302,8 +298,7 @@ object SRAM {
numWritePorts: Int,
numReadwritePorts: Int
)(
implicit evidence: T <:< Vec[_],
sourceInfo: SourceInfo
implicit sourceInfo: SourceInfo
): SRAMInterface[T] = {
val clock = Builder.forcedClock
memInterface_impl(
Expand All @@ -313,7 +308,6 @@ object SRAM {
Seq.fill(numWritePorts)(clock),
Seq.fill(numReadwritePorts)(clock),
None,
Some(evidence),
sourceInfo
)
}
Expand Down Expand Up @@ -342,8 +336,7 @@ object SRAM {
numReadwritePorts: Int,
memoryFile: MemoryFile
)(
implicit evidence: T <:< Vec[_],
sourceInfo: SourceInfo
implicit sourceInfo: SourceInfo
): SRAMInterface[T] = {
val clock = Builder.forcedClock
memInterface_impl(
Expand All @@ -353,7 +346,6 @@ object SRAM {
Seq.fill(numWritePorts)(clock),
Seq.fill(numReadwritePorts)(clock),
Some(memoryFile),
Some(evidence),
sourceInfo
)
}
Expand All @@ -380,8 +372,7 @@ object SRAM {
writePortClocks: Seq[Clock],
readwritePortClocks: Seq[Clock]
)(
implicit evidence: T <:< Vec[_],
sourceInfo: SourceInfo
implicit sourceInfo: SourceInfo
): SRAMInterface[T] =
memInterface_impl(
size,
Expand All @@ -390,7 +381,6 @@ object SRAM {
writePortClocks,
readwritePortClocks,
None,
Some(evidence),
sourceInfo
)

Expand Down Expand Up @@ -418,8 +408,7 @@ object SRAM {
readwritePortClocks: Seq[Clock],
memoryFile: MemoryFile
)(
implicit evidence: T <:< Vec[_],
sourceInfo: SourceInfo
implicit sourceInfo: SourceInfo
): SRAMInterface[T] =
memInterface_impl(
size,
Expand All @@ -428,7 +417,6 @@ object SRAM {
writePortClocks,
readwritePortClocks,
Some(memoryFile),
Some(evidence),
sourceInfo
)

Expand All @@ -439,13 +427,11 @@ object SRAM {
writePortClocks: Seq[Clock],
readwritePortClocks: Seq[Clock],
memoryFile: Option[MemoryFile],
evidenceOpt: Option[T <:< Vec[_]],
sourceInfo: SourceInfo
): SRAMInterface[T] = {
val numReadPorts = readPortClocks.size
val numWritePorts = writePortClocks.size
val numReadwritePorts = readwritePortClocks.size
val isVecMem = evidenceOpt.isDefined
val isValidSRAM = ((numReadPorts + numReadwritePorts) > 0) && ((numWritePorts + numReadwritePorts) > 0)

if (!isValidSRAM) {
Expand All @@ -459,7 +445,7 @@ object SRAM {
)
}

val _out = Wire(new SRAMInterface(size, tpe, numReadPorts, numWritePorts, numReadwritePorts, isVecMem))
val _out = Wire(new SRAMInterface(size, tpe, numReadPorts, numWritePorts, numReadwritePorts, true))
val mem = SyncReadMem(size, tpe)

for ((clock, port) <- readPortClocks.zip(_out.readPorts)) {
Expand All @@ -468,40 +454,26 @@ object SRAM {

for ((clock, port) <- writePortClocks.zip(_out.writePorts)) {
when(port.enable) {
if (isVecMem) {
mem.write(
port.address,
port.data,
port.mask.get,
clock
)(evidenceOpt.get)
} else {
mem.write(port.address, port.data, clock)
}
}
}

for ((clock, port) <- readwritePortClocks.zip(_out.readwritePorts)) {
if (isVecMem) {
port.readData := mem.readWrite(
mem.write(
port.address,
port.writeData,
port.data,
port.mask.get,
port.enable,
port.isWrite,
clock
)(evidenceOpt.get)
} else {
port.readData := mem.readWrite(
port.address,
port.writeData,
port.enable,
port.isWrite,
clock
)
}
}

for ((clock, port) <- readwritePortClocks.zip(_out.readwritePorts)) {
port.readData := mem.readWrite(
port.address,
port.writeData,
port.mask.get,
port.enable,
port.isWrite,
clock
)
}

// Emit Verilog for preloading the memory from a file if requested
memoryFile.foreach { file: MemoryFile => loadMemoryFromFileInline(mem, file.path, file.fileType) }

Expand Down
Loading

0 comments on commit d8e5bfb

Please sign in to comment.