Skip to content

Commit

Permalink
[RISCV][VLOPT] Add support for checkUsers when UserMI is a Single-Wid…
Browse files Browse the repository at this point in the history
…th Integer Reduction (llvm#120345)

Reductions are weird because for some operands, they are vector
registers but only read the first lane. For these operands, we do not
need to check to make sure the EEW and EMUL ratios match. The EEWs,
however, do need to match.
  • Loading branch information
michaelmaitland authored Jan 7, 2025
1 parent b225513 commit 142787d
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 53 deletions.
149 changes: 99 additions & 50 deletions llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ class RISCVVLOptimizer : public MachineFunctionPass {
StringRef getPassName() const override { return PASS_NAME; }

private:
bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI);
std::optional<MachineOperand> getMinimumVLForUser(MachineOperand &UserOp);
/// Returns the largest common VL MachineOperand that may be used to optimize
/// MI. Returns std::nullopt if it failed to find a suitable VL.
std::optional<MachineOperand> checkUsers(MachineInstr &MI);
bool tryReduceVL(MachineInstr &MI);
bool isCandidate(const MachineInstr &MI) const;
};
Expand Down Expand Up @@ -95,6 +98,8 @@ struct OperandInfo {
OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
: S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {}

OperandInfo(unsigned Log2EEW) : S(State::Known), Log2EEW(Log2EEW) {}

OperandInfo() : S(State::Unknown) {}

bool isUnknown() const { return S == State::Unknown; }
Expand All @@ -107,6 +112,11 @@ struct OperandInfo {
A.EMUL->second == B.EMUL->second;
}

static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
assert(A.isKnown() && B.isKnown() && "Both operands must be known");
return A.Log2EEW == B.Log2EEW;
}

void print(raw_ostream &OS) const {
if (isUnknown()) {
OS << "Unknown";
Expand Down Expand Up @@ -724,6 +734,23 @@ static OperandInfo getOperandInfo(const MachineOperand &MO,
return OperandInfo(MIVLMul, MILog2SEW);
}

// Vector Reduction Operations
// Vector Single-Width Integer Reduction Instructions
// The Dest and VS1 only read element 0 of the vector register. Return just
// the EEW for these. VS2 has EEW=SEW and EMUL=LMUL.
case RISCV::VREDAND_VS:
case RISCV::VREDMAX_VS:
case RISCV::VREDMAXU_VS:
case RISCV::VREDMIN_VS:
case RISCV::VREDMINU_VS:
case RISCV::VREDOR_VS:
case RISCV::VREDSUM_VS:
case RISCV::VREDXOR_VS: {
if (MO.getOperandNo() == 2)
return OperandInfo(MIVLMul, MILog2SEW);
return OperandInfo(MILog2SEW);
}

default:
return {};
}
Expand Down Expand Up @@ -1061,79 +1088,102 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
return true;
}

bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL,
MachineInstr &MI) {
std::optional<MachineOperand>
RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
const MachineInstr &UserMI = *UserOp.getParent();
const MCInstrDesc &Desc = UserMI.getDesc();

if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
" use VLMAX\n");
return std::nullopt;
}

// Instructions like reductions may use a vector register as a scalar
// register. In this case, we should treat it as only reading the first lane.
if (isVectorOpUsedAsScalarOp(UserOp)) {
[[maybe_unused]] Register R = UserOp.getReg();
[[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R);
assert(RISCV::VRRegClass.hasSubClassEq(RC) &&
"Expect LMUL 1 register class for vector as scalar operands!");
LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n");

return MachineOperand::CreateImm(1);
}

unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
// Looking for an immediate or a register VL that isn't X0.
assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
"Did not expect X0 VL");
return VLOp;
}

std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
// FIXME: Avoid visiting each user for each time we visit something on the
// worklist, combined with an extra visit from the outer loop. Restructure
// along lines of an instcombine style worklist which integrates the outer
// pass.
bool CanReduceVL = true;
std::optional<MachineOperand> CommonVL;
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) {
const MachineInstr &UserMI = *UserOp.getParent();
LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");

// Instructions like reductions may use a vector register as a scalar
// register. In this case, we should treat it like a scalar register which
// does not impact the decision on whether to optimize VL.
// TODO: Treat it like a scalar register instead of bailing out.
if (isVectorOpUsedAsScalarOp(UserOp)) {
CanReduceVL = false;
break;
}

if (mayReadPastVL(UserMI)) {
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
CanReduceVL = false;
break;
return std::nullopt;
}

// Tied operands might pass through.
if (UserOp.isTied()) {
LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n");
CanReduceVL = false;
break;
return std::nullopt;
}

const MCInstrDesc &Desc = UserMI.getDesc();
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
LLVM_DEBUG(dbgs() << " Abort due to lack of VL or SEW, assume that"
" use VLMAX\n");
CanReduceVL = false;
break;
}

unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);

// Looking for an immediate or a register VL that isn't X0.
assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
"Did not expect X0 VL");
auto VLOp = getMinimumVLForUser(UserOp);
if (!VLOp)
return std::nullopt;

// Use the largest VL among all the users. If we cannot determine this
// statically, then we cannot optimize the VL.
if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, VLOp)) {
CommonVL = &VLOp;
if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
CommonVL = *VLOp;
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
} else if (!RISCV::isVLKnownLE(VLOp, *CommonVL)) {
} else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n");
CanReduceVL = false;
break;
return std::nullopt;
}

if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
return std::nullopt;
}

// The SEW and LMUL of destination and source registers need to match.
OperandInfo ConsumerInfo = getOperandInfo(UserOp, MRI);
OperandInfo ProducerInfo = getOperandInfo(MI.getOperand(0), MRI);
if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown() ||
!OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo)) {
LLVM_DEBUG(dbgs() << " Abort due to incompatible or unknown "
"information for EMUL or EEW.\n");
if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown()) {
LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
CanReduceVL = false;
break;
return std::nullopt;
}

// If the operand is used as a scalar operand, then the EEW must be
// compatible. Otherwise, the EMUL *and* EEW must be compatible.
bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp);
if ((IsVectorOpUsedAsScalarOp &&
!OperandInfo::EEWAreEqual(ConsumerInfo, ProducerInfo)) ||
(!IsVectorOpUsedAsScalarOp &&
!OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo))) {
LLVM_DEBUG(
dbgs()
<< " Abort due to incompatible information for EMUL or EEW.\n");
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
return std::nullopt;
}
}
return CanReduceVL;

return CommonVL;
}

bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
Expand All @@ -1145,12 +1195,11 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
MachineInstr &MI = *Worklist.pop_back_val();
LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");

const MachineOperand *CommonVL = nullptr;
bool CanReduceVL = true;
if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
CanReduceVL = checkUsers(CommonVL, MI);
if (!isVectorRegClass(MI.getOperand(0).getReg(), MRI))
continue;

if (!CanReduceVL || !CommonVL)
auto CommonVL = checkUsers(MI);
if (!CommonVL)
continue;

assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ define i32 @reduce_and_16xi32_prefix5(ptr %p) {
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 5, e32, m2, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: vsetivli zero, 5, e32, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; CHECK-NEXT: vmv.v.i v10, -1
; CHECK-NEXT: vsetivli zero, 5, e32, m2, ta, ma
; CHECK-NEXT: vredand.vs v8, v8, v10
Expand Down Expand Up @@ -725,7 +725,7 @@ define i32 @reduce_umin_16xi32_prefix5(ptr %p) {
; RV32: # %bb.0:
; RV32-NEXT: vsetivli zero, 5, e32, m2, ta, ma
; RV32-NEXT: vle32.v v8, (a0)
; RV32-NEXT: vsetivli zero, 5, e32, m1, ta, ma
; RV32-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; RV32-NEXT: vmv.v.i v10, -1
; RV32-NEXT: vsetivli zero, 5, e32, m2, ta, ma
; RV32-NEXT: vredminu.vs v8, v8, v10
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ entry:
define i64 @reduce_add2(<4 x i64> %v) {
; CHECK-LABEL: reduce_add2:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetivli zero, 4, e64, m1, ta, ma
; CHECK-NEXT: vsetivli zero, 1, e64, m1, ta, ma
; CHECK-NEXT: vmv.v.i v10, 8
; CHECK-NEXT: vsetivli zero, 4, e64, m2, ta, ma
; CHECK-NEXT: vredsum.vs v8, v8, v10
Expand Down
113 changes: 113 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.mir
Original file line number Diff line number Diff line change
Expand Up @@ -1174,3 +1174,116 @@ body: |
%x:vr = PseudoVMAND_MM_B1 $noreg, $noreg, -1, 0
%y:vr = PseudoVIOTA_M_MF2 $noreg, %x, 1, 3 /* e8 */, 0
...
name: vred_vs2
body: |
bb.0:
; CHECK-LABEL: name: vred_vs2
; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
%x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0
%y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, $noreg, 1, 3 /* e8 */, 0
...
---
name: vred_vs1
body: |
bb.0:
; CHECK-LABEL: name: vred_vs1
; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0 /* tu, mu */
%x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0
%y:vr = PseudoVREDAND_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0
...
---
name: vred_vs1_vs2
body: |
bb.0:
; CHECK-LABEL: name: vred_vs1_vs2
; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, %x, 1, 3 /* e8 */, 0 /* tu, mu */
%x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0
%y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, %x, 1, 3 /* e8 */, 0
...
---
name: vred_vs1_vs2_incompatible_eew
body: |
bb.0:
; CHECK-LABEL: name: vred_vs1_vs2_incompatible_eew
; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, %x, 1, 4 /* e16 */, 0 /* tu, mu */
%x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0
%y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, %x, 1, 4 /* e16 */, 0
...
---
name: vred_vs1_vs2_incompatible_emul
body: |
bb.0:
; CHECK-LABEL: name: vred_vs1_vs2_incompatible_emul
; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %y:vr = PseudoVREDAND_VS_MF2_E8 $noreg, %x, %x, 1, 3 /* e8 */, 0 /* tu, mu */
%x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0
%y:vr = PseudoVREDAND_VS_MF2_E8 $noreg, %x, %x, 1, 3 /* e8 */, 0
...
---
name: vred_other_user_is_vl0
body: |
bb.0:
; CHECK-LABEL: name: vred_other_user_is_vl0
; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0 /* tu, mu */
%x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0
%y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0
%z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0
...
---
name: vred_both_vl0
body: |
bb.0:
; CHECK-LABEL: name: vred_both_vl0
; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 0, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0 /* tu, mu */
%x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0
%y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 0, 3 /* e8 */, 0
%z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0
...
---
name: vred_vl0_and_vlreg
body: |
bb.0:
; CHECK-LABEL: name: vred_vl0_and_vlreg
; CHECK: %vl:gprnox0 = COPY $x1
; CHECK-NEXT: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, %vl, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0 /* tu, mu */
%vl:gprnox0 = COPY $x1
%x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0
%y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, %vl, 3 /* e8 */, 0
%z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0
...
---
name: vred_vlreg_and_vl0
body: |
bb.0:
; CHECK-LABEL: name: vred_vlreg_and_vl0
; CHECK: %vl:gprnox0 = COPY $x1
; CHECK-NEXT: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 0, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, %vl, 3 /* e8 */, 0 /* tu, mu */
%vl:gprnox0 = COPY $x1
%x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0
%y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 0, 3 /* e8 */, 0
%z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, %vl, 3 /* e8 */, 0
...
---
name: vred_other_user_is_vl2
body: |
bb.0:
; CHECK-LABEL: name: vred_other_user_is_vl2
; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 2, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0 /* tu, mu */
; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 2, 3 /* e8 */, 0 /* tu, mu */
%x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0
%y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0
%z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 2, 3 /* e8 */, 0
...

0 comments on commit 142787d

Please sign in to comment.