diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp index d7ac3afe7b76b2..9ecb0453fb11cc 100644 --- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp @@ -50,7 +50,10 @@ class RISCVVLOptimizer : public MachineFunctionPass { StringRef getPassName() const override { return PASS_NAME; } private: - bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI); + std::optional getMinimumVLForUser(MachineOperand &UserOp); + /// Returns the largest common VL MachineOperand that may be used to optimize + /// MI. Returns std::nullopt if it failed to find a suitable VL. + std::optional checkUsers(MachineInstr &MI); bool tryReduceVL(MachineInstr &MI); bool isCandidate(const MachineInstr &MI) const; }; @@ -95,6 +98,8 @@ struct OperandInfo { OperandInfo(std::pair EMUL, unsigned Log2EEW) : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {} + OperandInfo(unsigned Log2EEW) : S(State::Known), Log2EEW(Log2EEW) {} + OperandInfo() : S(State::Unknown) {} bool isUnknown() const { return S == State::Unknown; } @@ -107,6 +112,11 @@ struct OperandInfo { A.EMUL->second == B.EMUL->second; } + static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) { + assert(A.isKnown() && B.isKnown() && "Both operands must be known"); + return A.Log2EEW == B.Log2EEW; + } + void print(raw_ostream &OS) const { if (isUnknown()) { OS << "Unknown"; @@ -724,6 +734,23 @@ static OperandInfo getOperandInfo(const MachineOperand &MO, return OperandInfo(MIVLMul, MILog2SEW); } + // Vector Reduction Operations + // Vector Single-Width Integer Reduction Instructions + // The Dest and VS1 only read element 0 of the vector register. Return just + // the EEW for these. VS2 has EEW=SEW and EMUL=LMUL. + case RISCV::VREDAND_VS: + case RISCV::VREDMAX_VS: + case RISCV::VREDMAXU_VS: + case RISCV::VREDMIN_VS: + case RISCV::VREDMINU_VS: + case RISCV::VREDOR_VS: + case RISCV::VREDSUM_VS: + case RISCV::VREDXOR_VS: { + if (MO.getOperandNo() == 2) + return OperandInfo(MIVLMul, MILog2SEW); + return OperandInfo(MILog2SEW); + } + default: return {}; } @@ -1061,79 +1088,102 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { return true; } -bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL, - MachineInstr &MI) { +std::optional +RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) { + const MachineInstr &UserMI = *UserOp.getParent(); + const MCInstrDesc &Desc = UserMI.getDesc(); + + if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { + LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that" + " use VLMAX\n"); + return std::nullopt; + } + + // Instructions like reductions may use a vector register as a scalar + // register. In this case, we should treat it as only reading the first lane. + if (isVectorOpUsedAsScalarOp(UserOp)) { + [[maybe_unused]] Register R = UserOp.getReg(); + [[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R); + assert(RISCV::VRRegClass.hasSubClassEq(RC) && + "Expect LMUL 1 register class for vector as scalar operands!"); + LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n"); + + return MachineOperand::CreateImm(1); + } + + unsigned VLOpNum = RISCVII::getVLOpNum(Desc); + const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); + // Looking for an immediate or a register VL that isn't X0. + assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && + "Did not expect X0 VL"); + return VLOp; +} + +std::optional RISCVVLOptimizer::checkUsers(MachineInstr &MI) { // FIXME: Avoid visiting each user for each time we visit something on the // worklist, combined with an extra visit from the outer loop. Restructure // along lines of an instcombine style worklist which integrates the outer // pass. - bool CanReduceVL = true; + std::optional CommonVL; for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) { const MachineInstr &UserMI = *UserOp.getParent(); LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n"); - - // Instructions like reductions may use a vector register as a scalar - // register. In this case, we should treat it like a scalar register which - // does not impact the decision on whether to optimize VL. - // TODO: Treat it like a scalar register instead of bailing out. - if (isVectorOpUsedAsScalarOp(UserOp)) { - CanReduceVL = false; - break; - } - if (mayReadPastVL(UserMI)) { LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); - CanReduceVL = false; - break; + return std::nullopt; } // Tied operands might pass through. if (UserOp.isTied()) { LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n"); - CanReduceVL = false; - break; + return std::nullopt; } - const MCInstrDesc &Desc = UserMI.getDesc(); - if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { - LLVM_DEBUG(dbgs() << " Abort due to lack of VL or SEW, assume that" - " use VLMAX\n"); - CanReduceVL = false; - break; - } - - unsigned VLOpNum = RISCVII::getVLOpNum(Desc); - const MachineOperand &VLOp = UserMI.getOperand(VLOpNum); - - // Looking for an immediate or a register VL that isn't X0. - assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) && - "Did not expect X0 VL"); + auto VLOp = getMinimumVLForUser(UserOp); + if (!VLOp) + return std::nullopt; // Use the largest VL among all the users. If we cannot determine this // statically, then we cannot optimize the VL. - if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, VLOp)) { - CommonVL = &VLOp; + if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) { + CommonVL = *VLOp; LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); - } else if (!RISCV::isVLKnownLE(VLOp, *CommonVL)) { + } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) { LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n"); - CanReduceVL = false; - break; + return std::nullopt; + } + + if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) { + LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n"); + return std::nullopt; } - // The SEW and LMUL of destination and source registers need to match. OperandInfo ConsumerInfo = getOperandInfo(UserOp, MRI); OperandInfo ProducerInfo = getOperandInfo(MI.getOperand(0), MRI); - if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown() || - !OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo)) { - LLVM_DEBUG(dbgs() << " Abort due to incompatible or unknown " - "information for EMUL or EEW.\n"); + if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown()) { + LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n"); LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); - CanReduceVL = false; - break; + return std::nullopt; + } + + // If the operand is used as a scalar operand, then the EEW must be + // compatible. Otherwise, the EMUL *and* EEW must be compatible. + bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp); + if ((IsVectorOpUsedAsScalarOp && + !OperandInfo::EEWAreEqual(ConsumerInfo, ProducerInfo)) || + (!IsVectorOpUsedAsScalarOp && + !OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo))) { + LLVM_DEBUG( + dbgs() + << " Abort due to incompatible information for EMUL or EEW.\n"); + LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); + LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); + return std::nullopt; } } - return CanReduceVL; + + return CommonVL; } bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { @@ -1145,12 +1195,11 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { MachineInstr &MI = *Worklist.pop_back_val(); LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); - const MachineOperand *CommonVL = nullptr; - bool CanReduceVL = true; - if (isVectorRegClass(MI.getOperand(0).getReg(), MRI)) - CanReduceVL = checkUsers(CommonVL, MI); + if (!isVectorRegClass(MI.getOperand(0).getReg(), MRI)) + continue; - if (!CanReduceVL || !CommonVL) + auto CommonVL = checkUsers(MI); + if (!CommonVL) continue; assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll index 4f0f5dd78c94b6..bf8baafc4a25db 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll @@ -530,7 +530,7 @@ define i32 @reduce_and_16xi32_prefix5(ptr %p) { ; CHECK: # %bb.0: ; CHECK-NEXT: vsetivli zero, 5, e32, m2, ta, ma ; CHECK-NEXT: vle32.v v8, (a0) -; CHECK-NEXT: vsetivli zero, 5, e32, m1, ta, ma +; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma ; CHECK-NEXT: vmv.v.i v10, -1 ; CHECK-NEXT: vsetivli zero, 5, e32, m2, ta, ma ; CHECK-NEXT: vredand.vs v8, v8, v10 @@ -725,7 +725,7 @@ define i32 @reduce_umin_16xi32_prefix5(ptr %p) { ; RV32: # %bb.0: ; RV32-NEXT: vsetivli zero, 5, e32, m2, ta, ma ; RV32-NEXT: vle32.v v8, (a0) -; RV32-NEXT: vsetivli zero, 5, e32, m1, ta, ma +; RV32-NEXT: vsetivli zero, 1, e32, m1, ta, ma ; RV32-NEXT: vmv.v.i v10, -1 ; RV32-NEXT: vsetivli zero, 5, e32, m2, ta, ma ; RV32-NEXT: vredminu.vs v8, v8, v10 diff --git a/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll b/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll index 2fda344690bfc6..6787c8c24c87ef 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll @@ -18,7 +18,7 @@ entry: define i64 @reduce_add2(<4 x i64> %v) { ; CHECK-LABEL: reduce_add2: ; CHECK: # %bb.0: # %entry -; CHECK-NEXT: vsetivli zero, 4, e64, m1, ta, ma +; CHECK-NEXT: vsetivli zero, 1, e64, m1, ta, ma ; CHECK-NEXT: vmv.v.i v10, 8 ; CHECK-NEXT: vsetivli zero, 4, e64, m2, ta, ma ; CHECK-NEXT: vredsum.vs v8, v8, v10 diff --git a/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.mir b/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.mir index a1bbfc8a7d3514..1618f0aa854e55 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.mir +++ b/llvm/test/CodeGen/RISCV/rvv/vl-opt-op-info.mir @@ -1174,3 +1174,116 @@ body: | %x:vr = PseudoVMAND_MM_B1 $noreg, $noreg, -1, 0 %y:vr = PseudoVIOTA_M_MF2 $noreg, %x, 1, 3 /* e8 */, 0 ... +name: vred_vs2 +body: | + bb.0: + ; CHECK-LABEL: name: vred_vs2 + ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */ + %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 + %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, $noreg, 1, 3 /* e8 */, 0 +... +--- +name: vred_vs1 +body: | + bb.0: + ; CHECK-LABEL: name: vred_vs1 + ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0 /* tu, mu */ + %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 + %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0 +... +--- +name: vred_vs1_vs2 +body: | + bb.0: + ; CHECK-LABEL: name: vred_vs1_vs2 + ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, %x, 1, 3 /* e8 */, 0 /* tu, mu */ + %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 + %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, %x, 1, 3 /* e8 */, 0 +... +--- +name: vred_vs1_vs2_incompatible_eew +body: | + bb.0: + ; CHECK-LABEL: name: vred_vs1_vs2_incompatible_eew + ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, %x, 1, 4 /* e16 */, 0 /* tu, mu */ + %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 + %y:vr = PseudoVREDAND_VS_M1_E8 $noreg, %x, %x, 1, 4 /* e16 */, 0 +... +--- +name: vred_vs1_vs2_incompatible_emul +body: | + bb.0: + ; CHECK-LABEL: name: vred_vs1_vs2_incompatible_emul + ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %y:vr = PseudoVREDAND_VS_MF2_E8 $noreg, %x, %x, 1, 3 /* e8 */, 0 /* tu, mu */ + %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 + %y:vr = PseudoVREDAND_VS_MF2_E8 $noreg, %x, %x, 1, 3 /* e8 */, 0 +... +--- +name: vred_other_user_is_vl0 +body: | + bb.0: + ; CHECK-LABEL: name: vred_other_user_is_vl0 + ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0 /* tu, mu */ + %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 + %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0 + %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0 +... +--- +name: vred_both_vl0 +body: | + bb.0: + ; CHECK-LABEL: name: vred_both_vl0 + ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 0, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0 /* tu, mu */ + %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 + %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 0, 3 /* e8 */, 0 + %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0 +... +--- +name: vred_vl0_and_vlreg +body: | + bb.0: + ; CHECK-LABEL: name: vred_vl0_and_vlreg + ; CHECK: %vl:gprnox0 = COPY $x1 + ; CHECK-NEXT: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, %vl, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0 /* tu, mu */ + %vl:gprnox0 = COPY $x1 + %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 + %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, %vl, 3 /* e8 */, 0 + %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 0, 3 /* e8 */, 0 +... +--- +name: vred_vlreg_and_vl0 +body: | + bb.0: + ; CHECK-LABEL: name: vred_vlreg_and_vl0 + ; CHECK: %vl:gprnox0 = COPY $x1 + ; CHECK-NEXT: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 0, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, %vl, 3 /* e8 */, 0 /* tu, mu */ + %vl:gprnox0 = COPY $x1 + %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 + %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 0, 3 /* e8 */, 0 + %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, %vl, 3 /* e8 */, 0 +... +--- +name: vred_other_user_is_vl2 +body: | + bb.0: + ; CHECK-LABEL: name: vred_other_user_is_vl2 + ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 2, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0 /* tu, mu */ + ; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 2, 3 /* e8 */, 0 /* tu, mu */ + %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 + %y:vr = PseudoVREDSUM_VS_M1_E8 $noreg, $noreg, %x, 1, 3 /* e8 */, 0 + %z:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 2, 3 /* e8 */, 0 +...