Skip to content

Commit

Permalink
[rtl] add mask control.
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li committed Dec 11, 2024
1 parent 659f37a commit 0be86ac
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 34 deletions.
115 changes: 85 additions & 30 deletions t1/src/Lane.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ case class LaneParameter(
crossLaneVRFWriteEscapeQueueSize: Int,
fpuEnable: Boolean,
portFactor: Int,
maskRequestLatency: Int,
vrfRamType: RamType,
decoderParam: DecoderParam,
vfuInstantiateParameter: VFUInstantiateParameter)
Expand Down Expand Up @@ -325,17 +326,6 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
val vrf: Instance[VRF] = Instantiate(new VRF(parameter.vrfParam))
omInstance.vrfIn := Property(vrf.om.asAnyClassType)

/** TODO: review later
*/
val maskGroupedOrR: UInt = VecInit(
maskInput.asBools
.grouped(parameter.dataPathByteWidth)
.toSeq
.map(
VecInit(_).asUInt.orR
)
).asUInt

val fullMask: UInt = (-1.S(parameter.datapathWidth.W)).asUInt

/** the slot is occupied by instruction */
Expand Down Expand Up @@ -378,7 +368,67 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
val afterCheckValid: Seq[Bool] = Seq.tabulate(parameter.chainingSize + 3) { _ => RegInit(false.B) }
val afterCheckDequeueReady: Vec[Bool] = Wire(Vec(parameter.chainingSize + 3, Bool()))
val afterCheckDequeueFire: Seq[Bool] = afterCheckValid.zip(afterCheckDequeueReady).map { case (v, r) => v && r }
0fdfbce5

// todo: mv to bundle.scala
class MaskControl(parameter: LaneParameter) extends Bundle {
val index: UInt = UInt(parameter.instructionIndexBits.W)
val sew: UInt = UInt(2.W)
val maskData: UInt = UInt(parameter.datapathWidth.W)
val group: UInt = UInt(parameter.maskGroupSizeBits.W)
val dataValid: Bool = Bool()
val waiteResponse: Bool = Bool()
val controlValid: Bool = Bool()
}

val maskControlRelease: Vec[ValidIO[UInt]] =
Wire(Vec(parameter.chainingSize, Valid(UInt(parameter.instructionIndexBits.W))))

val maskControlEnq: UInt = Wire(UInt(parameter.chainingSize.W))
val maskControlDataDeq: UInt = Wire(UInt(parameter.chainingSize.W))
val maskControlReq: Vec[Bool] = Wire(Vec(parameter.chainingSize, Bool()))
val maskControlReqSelect: UInt = ffo(maskControlReq.asUInt)
// mask request & response handle
val maskControlVec: Seq[MaskControl] = Seq.tabulate(parameter.chainingSize) { index =>
val state = RegInit(0.U.asTypeOf(new MaskControl(parameter)))
val releaseHit: Bool = maskControlRelease.map(r => r.valid && (r.bits === state.index)).reduce(_ || _)
val responseFire =
Pipe(maskControlReqSelect(index), 0.U.asTypeOf(new EmptyBundle), parameter.maskRequestLatency).valid

when(maskControlEnq(index)) {
state := 0.U.asTypeOf(state)
state.index := laneRequest.bits.instructionIndex
state.sew := laneRequest.bits.csrInterface.vSew
state.controlValid := true.B
}

when(state.controlValid) {
when(releaseHit) {
state.controlValid := false.B
}
}

maskControlReq(index) := state.controlValid && !state.dataValid && !state.waiteResponse
when(maskControlReqSelect(index)) {
state.waiteResponse := true.B
state.group := state.group + 1.U
}

when(responseFire) {
state.dataValid := true.B
state.waiteResponse := false.B
state.maskData := maskInput
}

when(maskControlDataDeq(index)) {
state.dataValid := false.B
}

state
}
val maskControlFree: Seq[Bool] = maskControlVec.map(s => !s.controlValid && !s.waiteResponse)
val freeSelect: UInt = ffo(VecInit(maskControlFree).asUInt)
maskControlEnq := maskAnd(laneRequest.fire && laneRequest.bits.mask, freeSelect)

/** for each slot, assert when it is asking [[T1]] to change mask */
val slotMaskRequestVec: Vec[ValidIO[UInt]] = Wire(
Vec(
Expand All @@ -388,7 +438,8 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
)

/** which slot wins the arbitration for requesting mask. */
val maskRequestFireOH: UInt = Wire(UInt(parameter.chainingSize.W))
val maskRequestFireOH: Vec[Bool] = Wire(Vec(parameter.chainingSize, Bool()))
val maskDataVec: Vec[UInt] = Wire(Vec(parameter.chainingSize, UInt(parameter.maskGroupWidth.W)))

/** FSM control for each slot. if index == 0,
* - slot can support write v0 in mask type, see [[Decoder.maskDestination]] [[Decoder.maskSource]]
Expand Down Expand Up @@ -611,13 +662,16 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
laneState.elements.get(k).foreach(stateData => d := stateData)
}

maskControlRelease(index).valid := false.B
maskControlRelease(index).bits := record.laneRequest.instructionIndex
// update lane state
when(stage0.enqueue.fire) {
maskGroupCountVec(index) := stage0.updateLaneState.maskGroupCount
// todo: handle all elements in first group are masked
maskIndexVec(index) := stage0.updateLaneState.maskIndex
when(stage0.updateLaneState.outOfExecutionRange) {
slotOccupied(index) := false.B
maskControlRelease(index).valid := true.B
}
}

Expand All @@ -632,7 +686,7 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
val maskFailure: Bool = stage0.updateLaneState.maskExhausted && stage0.enqueue.fire
// update mask register
when(maskUpdateFire) {
record.mask.bits := maskInput
record.mask.bits := DontCare
}
when(maskUpdateFire ^ maskFailure) {
record.mask.valid := maskUpdateFire
Expand Down Expand Up @@ -908,20 +962,21 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
}

{
// 处理mask的请求
val maskSelectArbitrator = ffo(
VecInit(slotMaskRequestVec.map(_.valid)).asUInt ##
(laneRequest.valid && (laneRequest.bits.mask || laneRequest.bits.decodeResult(Decoder.maskSource)))
)
maskRequestFireOH := maskSelectArbitrator(parameter.chainingSize, 1)
maskSelect := Mux1H(
maskSelectArbitrator,
0.U.asTypeOf(slotMaskRequestVec.head.bits) +: slotMaskRequestVec.map(_.bits)
)
maskSelectSew := Mux1H(
maskSelectArbitrator,
csrInterface.vSew +: slotControl.map(_.laneRequest.csrInterface.vSew)
)
maskSelect := Mux1H(maskControlReqSelect, maskControlVec.map(_.group))
maskSelectSew := Mux1H(maskControlReqSelect, maskControlVec.map(_.sew))
maskControlDataDeq := slotMaskRequestVec.zipWithIndex.map { case (req, index) =>
val slotIndex = slotControl(index).laneRequest.instructionIndex
val hitMaskControl = VecInit(maskControlVec.map(_.index === slotIndex)).asUInt
val dataValid = Mux1H(hitMaskControl, maskControlVec.map(_.dataValid))
val data = Mux1H(hitMaskControl, maskControlVec.map(_.maskData))
val group = Mux1H(hitMaskControl, maskControlVec.map(_.group))
val sameGroup = group === req.bits
dontTouch(sameGroup)
val maskRequestFire = req.valid && dataValid
maskRequestFireOH(index) := maskRequestFire
maskDataVec(index) := data
maskAnd(maskRequestFire, hitMaskControl).asUInt
}.reduce(_ | _)
}

// package a control logic for incoming instruction.
Expand All @@ -944,9 +999,9 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
// for 'nr' type instructions, they will need another complete signal.
!(laneRequest.bits.decodeResult(Decoder.nr) || laneRequest.bits.lsWholeReg)
// indicate if this is the mask type.
entranceControl.mask.valid := laneRequest.bits.mask
entranceControl.mask.valid := false.B
// assign mask from [[V]]
entranceControl.mask.bits := maskInput
entranceControl.mask.bits := DontCare
// mask used for VRF write in this group.
entranceControl.vrfWriteMask := 0.U

Expand Down
11 changes: 7 additions & 4 deletions t1/src/T1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ case class T1Parameter(
val lsuReadTokenSize: Seq[Int] = Seq.tabulate(laneNumber)(_ => 4)
val lsuReadShifterSize: Seq[Int] = Seq.tabulate(laneNumber)(_ => 1)

val maskRequestLatency = 2

val decoderParam: DecoderParam = DecoderParam(fpuEnable, zvbbEnable, allInstructions)

/** paraemter for AXI4. */
Expand Down Expand Up @@ -299,6 +301,7 @@ case class T1Parameter(
crossLaneVRFWriteEscapeQueueSize = vrfWriteQueueSize,
fpuEnable = fpuEnable,
portFactor = vrfBankSize,
maskRequestLatency = 2 * maskRequestLatency,
vrfRamType = vrfRamType,
decoderParam = decoderParam,
vfuInstantiateParameter = vfuInstantiateParameter
Expand Down Expand Up @@ -760,7 +763,7 @@ class T1(val parameter: T1Parameter)
lane.vrfWriteChannel,
0
)
lane.writeFromMask := maskUnit.exeResp(index).fire
lane.writeFromMask := maskUnit.io.exeResp(index).fire

lsu.offsetReadResult(index).valid := lane.maskUnitRequest.valid && lane.maskRequestToLSU
lsu.offsetReadResult(index).bits := lane.maskUnitRequest.bits.source2
Expand All @@ -770,9 +773,9 @@ class T1(val parameter: T1Parameter)
d := ohCheck(lane.instructionFinished, f, parameter.chainingSize)
}
vxsatReportVec(index) := lane.vxsatReport
lane.maskInput := maskUnit.io.laneMaskInput(index)
maskUnit.io.laneMaskSelect(index) := lane.maskSelect
maskUnit.io.laneMaskSewSelect(index) := lane.maskSelectSew
lane.maskInput := Pipe(true.B, maskUnit.io.laneMaskInput(index), parameter.maskRequestLatency).bits
maskUnit.io.laneMaskSelect(index) := Pipe(true.B, lane.maskSelect, parameter.maskRequestLatency).bits
maskUnit.io.laneMaskSewSelect(index) := Pipe(true.B, lane.maskSelectSew, parameter.maskRequestLatency).bits
maskUnit.io.v0UpdateVec(index) <> lane.v0Update

lane.lsuLastReport := lsu.lastReport | maskUnit.io.lastReport
Expand Down

0 comments on commit 0be86ac

Please sign in to comment.