Skip to content

Commit

Permalink
[RISCV] Use vrsub for select of add and sub of the same operands (#12…
Browse files Browse the repository at this point in the history
…3400)

If we have a (vselect c, a+b, a-b), we can combine this to a+(vselect c,
b, -b). That by itself isn't hugely profitable, but if we reverse the
select, we get a form which matches a masked vrsub.vi with zero. The
result is that we can use a masked vrsub *before* the add instead of a
masked add or sub. This doesn't change the critical path (since we
already had the pass through on the masked second op), but does reduce
register pressure since a, b, and (a+b) don't need to all be alive at
once.

In addition to the vselect form, we can also see the same pattern with a
vector_shuffle encoding the vselect. I explored canonicalizing these to
vselects instead, but that exposes several unrelated missing combines.
  • Loading branch information
preames authored Jan 24, 2025
1 parent 7293455 commit a9ad601
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 112 deletions.
89 changes: 83 additions & 6 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1535,7 +1535,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::UDIV, ISD::SREM,
ISD::UREM, ISD::INSERT_VECTOR_ELT,
ISD::ABS, ISD::CTPOP,
ISD::VECTOR_SHUFFLE});
ISD::VECTOR_SHUFFLE, ISD::VSELECT});

if (Subtarget.hasVendorXTHeadMemPair())
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
if (Subtarget.useRVVForFixedLengthVectors())
Expand Down Expand Up @@ -16874,6 +16875,53 @@ static SDValue useInversedSetcc(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

/// Match a pair of select arms that compute (add a, b) and (sub a, b) over
/// the same two operands, with the arms in either order.
///
/// \param TrueVal  Value chosen when the select condition is true.
/// \param FalseVal Value chosen when the select condition is false.
/// \param SwapCC   Set to true when the subtract is the true arm (the arms had
///                 to be exchanged to reach the add-true/sub-false order) and
///                 to false otherwise. Always written, including on failure.
/// \returns true if both arms have exactly one use and form an add and a sub
///          of the same two operands.
static bool matchSelectAddSub(SDValue TrueVal, SDValue FalseVal, bool &SwapCC) {
  // Define the out-parameter before any early return so that callers never
  // observe an indeterminate value, even when the match fails.
  SwapCC = false;

  // Only match when neither arm has any other user.
  if (!TrueVal.hasOneUse() || !FalseVal.hasOneUse())
    return false;

  // Canonicalize to "add in the true arm, sub in the false arm". The swap
  // only affects the local copies; the caller uses SwapCC to recover which of
  // its own values is the subtract.
  if (TrueVal.getOpcode() == ISD::SUB && FalseVal.getOpcode() == ISD::ADD) {
    std::swap(TrueVal, FalseVal);
    SwapCC = true;
  }

  if (TrueVal.getOpcode() != ISD::ADD || FalseVal.getOpcode() != ISD::SUB)
    return false;

  // Operand order is taken from the subtract, which is not commutative.
  SDValue A = FalseVal.getOperand(0);
  SDValue B = FalseVal.getOperand(1);
  // Add is commutative, so check both orders
  return ((TrueVal.getOperand(0) == A && TrueVal.getOperand(1) == B) ||
          (TrueVal.getOperand(1) == A && TrueVal.getOperand(0) == B));
}

/// Rewrite a vselect whose arms are (add a, b) and (sub a, b), in either
/// order, as (add a, (vselect M, (neg b), b)), where M is true in exactly the
/// lanes where the original select produced the subtract. When the add is the
/// true arm, M is the logical NOT of the original condition; when the sub is
/// the true arm, M is the original condition unchanged.
/// This allows us to match a vadd.vv fed by a masked vrsub, which reduces
/// register pressure over the add followed by masked vsub sequence.
/// Returns an empty SDValue when the arms do not form the add/sub pattern.
static SDValue performVSELECTCombine(SDNode *N, SelectionDAG &DAG) {
  SDValue TrueVal = N->getOperand(1);
  SDValue FalseVal = N->getOperand(2);

  bool SwapCC;
  if (!matchSelectAddSub(TrueVal, FalseVal, SwapCC))
    return SDValue();

  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  SDValue CC = N->getOperand(0);

  // SwapCC tells us which of the two arms is the subtract; its operands give
  // us a and b in the (non-commutative) order we need.
  const SDValue &SubArm = SwapCC ? TrueVal : FalseVal;
  SDValue A = SubArm.getOperand(0);
  SDValue B = SubArm.getOperand(1);

  // Put the negated operand in the true arm of the new select so that it has
  // the shape of a masked vrsub.vi performing the conditional negate. If the
  // subtract was originally the false arm, the condition must be inverted to
  // keep the negate on the same lanes.
  SDValue NegB = DAG.getNegative(B, DL, VT);
  if (!SwapCC)
    CC = DAG.getLogicalNOT(DL, CC, CC->getValueType(0));
  SDValue CondNegB = DAG.getNode(ISD::VSELECT, DL, VT, CC, NegB, B);
  return DAG.getNode(ISD::ADD, DL, VT, A, CondNegB);
}

static SDValue performSELECTCombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
if (SDValue Folded = foldSelectOfCTTZOrCTLZ(N, DAG))
Expand Down Expand Up @@ -17153,20 +17201,48 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
return DAG.getBitcast(VT.getSimpleVT(), StridedLoad);
}

/// Custom legalize <N x i128> or <N x i256> to <M x ELEN>. This runs
/// during the combine phase before type legalization, and relies on
/// DAGCombine not undoing the transform if isShuffleMaskLegal returns false
/// for the source mask.
static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
const unsigned ElementSize = VT.getScalarSizeInBits();
const unsigned NumElts = VT.getVectorNumElements();
SDValue V1 = N->getOperand(0);
SDValue V2 = N->getOperand(1);
ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask();
MVT XLenVT = Subtarget.getXLenVT();

// Recognize a disguised select of add/sub.
bool SwapCC;
if (ShuffleVectorInst::isSelectMask(Mask, NumElts) &&
matchSelectAddSub(V1, V2, SwapCC)) {
SDValue Sub = SwapCC ? V1 : V2;
SDValue A = Sub.getOperand(0);
SDValue B = Sub.getOperand(1);

SmallVector<SDValue> MaskVals;
for (int MaskIndex : Mask) {
bool SelectMaskVal = (MaskIndex < (int)NumElts);
MaskVals.push_back(DAG.getConstant(SelectMaskVal, DL, XLenVT));
}
assert(MaskVals.size() == NumElts && "Unexpected select-like shuffle");
EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElts);
SDValue CC = DAG.getBuildVector(MaskVT, DL, MaskVals);

// Arrange the select such that we can match a masked
// vrsub.vi to perform the conditional negate
SDValue NegB = DAG.getNegative(B, DL, VT);
if (!SwapCC)
CC = DAG.getLogicalNOT(DL, CC, CC->getValueType(0));
SDValue NewB = DAG.getNode(ISD::VSELECT, DL, VT, CC, NegB, B);
return DAG.getNode(ISD::ADD, DL, VT, A, NewB);
}

// Custom legalize <N x i128> or <N x i256> to <M x ELEN>. This runs
// during the combine phase before type legalization, and relies on
// DAGCombine not undoing the transform if isShuffleMaskLegal returns false
// for the source mask.
if (TLI.isTypeLegal(VT) || ElementSize <= Subtarget.getELen() ||
!isPowerOf2_64(ElementSize) || VT.getVectorNumElements() % 2 != 0 ||
VT.isFloatingPoint() || TLI.isShuffleMaskLegal(Mask, VT))
Expand All @@ -17183,7 +17259,6 @@ static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
return DAG.getBitcast(VT, Res);
}


static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {

Expand Down Expand Up @@ -17857,6 +17932,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
return performSELECTCombine(N, DAG, Subtarget);
case ISD::VSELECT:
return performVSELECTCombine(N, DAG);
case RISCVISD::CZERO_EQZ:
case RISCVISD::CZERO_NEZ: {
SDValue Val = N->getOperand(0);
Expand Down
Loading

0 comments on commit a9ad601

Please sign in to comment.