Skip to content

Commit

Permalink
[LLVM][CodeGen] Add lowering for scalable vector bfloat operations. (l…
Browse files Browse the repository at this point in the history
…lvm#109803)

Specifically:
  fabs, fadd, fceil, fdiv, ffloor, fma, fmax, fmaxnm, fmin, fminnm,
  fmul, fnearbyint, fneg, frint, fround, froundeven, fsub, fsqrt &
  ftrunc
  • Loading branch information
paulwalker-arm authored Oct 7, 2024
1 parent 8b6e1dc commit 02dd6b1
Show file tree
Hide file tree
Showing 9 changed files with 1,234 additions and 31 deletions.
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5616,6 +5616,10 @@ class TargetLowering : public TargetLoweringBase {
return true;
}

// Expand the vector operation \p Node by splitting it into two half-length
// operations with the same opcode and concatenating their results.
//
// SDValue() is returned when expansion did not happen, which is the case when
// the result type is not a vector with an element count known to be a
// multiple of two, when the two halves would have different types, when the
// half-length type is not legal, or when the opcode is not
// legal/custom/promote for the half-length type.
SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;

private:
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
const SDLoc &DL, DAGCombinerInfo &DCI) const;
Expand Down
23 changes: 23 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,24 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::UCMP:
Results.push_back(TLI.expandCMP(Node, DAG));
return;

case ISD::FADD:
case ISD::FMUL:
case ISD::FMA:
case ISD::FDIV:
case ISD::FCEIL:
case ISD::FFLOOR:
case ISD::FNEARBYINT:
case ISD::FRINT:
case ISD::FROUND:
case ISD::FROUNDEVEN:
case ISD::FTRUNC:
case ISD::FSQRT:
// First try to split the operation into two half-length operations. When
// that is not possible (an empty SDValue is returned), break out of the
// switch so the node is handled by the unrolling code that follows it.
if (SDValue Expanded = TLI.expandVectorNaryOpBySplitting(Node, DAG)) {
Results.push_back(Expanded);
return;
}
break;
}

SDValue Unrolled = DAG.UnrollVectorOp(Node);
Expand Down Expand Up @@ -1885,6 +1903,11 @@ void VectorLegalizer::ExpandFSUB(SDNode *Node,
TLI.isOperationLegalOrCustom(ISD::FADD, VT))
return; // Defer to LegalizeDAG

// Prefer splitting the operation into two half-length operations over
// unrolling it.
if (SDValue Expanded = TLI.expandVectorNaryOpBySplitting(Node, DAG)) {
Results.push_back(Expanded);
return;
}

// Splitting did not happen, so fall back to unrolling the vector operation.
SDValue Tmp = DAG.UnrollVectorOp(Node);
Results.push_back(Tmp);
}
Expand Down
46 changes: 42 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8440,15 +8440,18 @@ TargetLowering::createSelectForFMINNUM_FMAXNUM(SDNode *Node,

SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node,
SelectionDAG &DAG) const {
SDLoc dl(Node);
unsigned NewOp = Node->getOpcode() == ISD::FMINNUM ?
ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;
EVT VT = Node->getValueType(0);
if (SDValue Expanded = expandVectorNaryOpBySplitting(Node, DAG))
return Expanded;

EVT VT = Node->getValueType(0);
if (VT.isScalableVector())
report_fatal_error(
"Expanding fminnum/fmaxnum for scalable vectors is undefined.");

SDLoc dl(Node);
unsigned NewOp =
Node->getOpcode() == ISD::FMINNUM ? ISD::FMINNUM_IEEE : ISD::FMAXNUM_IEEE;

if (isOperationLegalOrCustom(NewOp, VT)) {
SDValue Quiet0 = Node->getOperand(0);
SDValue Quiet1 = Node->getOperand(1);
Expand Down Expand Up @@ -8493,6 +8496,9 @@ SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node,

SDValue TargetLowering::expandFMINIMUM_FMAXIMUM(SDNode *N,
SelectionDAG &DAG) const {
if (SDValue Expanded = expandVectorNaryOpBySplitting(N, DAG))
return Expanded;

SDLoc DL(N);
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
Expand Down Expand Up @@ -11920,3 +11926,35 @@ bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
}
return false;
}

/// Expand the vector operation \p Node by splitting it into two half-length
/// operations with the same opcode and concatenating their results.
///
/// Every operand of \p Node is split using the low/high types derived from the
/// result type, so this is intended for N-ary operations whose operands have
/// the same vector type as the result (NOTE(review): callers are expected to
/// guarantee this; it is not checked here).
///
/// The node flags of \p Node (e.g. fast-math flags) are propagated to both
/// half-length operations, since splitting does not change the per-lane
/// semantics of the operation.
///
/// \param Node The vector operation to expand.
/// \param DAG  The SelectionDAG used to create the replacement nodes.
/// \returns The CONCAT_VECTORS node joining both halves, or SDValue() when
///          the expansion did not happen.
SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
                                                      SelectionDAG &DAG) const {
  EVT VT = Node->getValueType(0);
  // Despite its documentation, GetSplitDestVTs will assert if VT cannot be
  // split into two equal parts.
  if (!VT.isVector() || !VT.getVectorElementCount().isKnownMultipleOf(2))
    return SDValue();

  // Restrict expansion to cases where both parts can be concatenated.
  auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VT);
  if (LoVT != HiVT || !isTypeLegal(LoVT))
    return SDValue();

  SDLoc DL(Node);
  unsigned Opcode = Node->getOpcode();

  // Don't expand if the result is likely to be unrolled anyway.
  if (!isOperationLegalOrCustomOrPromote(Opcode, LoVT))
    return SDValue();

  // Split every operand into its low and high half.
  SmallVector<SDValue, 4> LoOps, HiOps;
  for (const SDValue &V : Node->op_values()) {
    auto [Lo, Hi] = DAG.SplitVector(V, DL, LoVT, HiVT);
    LoOps.push_back(Lo);
    HiOps.push_back(Hi);
  }

  // Keep the original node's flags on both halves rather than silently
  // dropping them.
  const SDNodeFlags Flags = Node->getFlags();
  SDValue SplitOpLo = DAG.getNode(Opcode, DL, LoVT, LoOps, Flags);
  SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps, Flags);
  return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
}
30 changes: 30 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1663,12 +1663,42 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// Operation actions shared by all scalable bfloat vector types.
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
// FABS and FNEG are marked Legal for every bfloat vector type, independent
// of the SVE-B16B16 check below.
setOperationAction(ISD::FABS, VT, Legal);
setOperationAction(ISD::FNEG, VT, Legal);
setOperationAction(ISD::FP_EXTEND, VT, Custom);
setOperationAction(ISD::FP_ROUND, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);

// With SVE-B16B16 these arithmetic operations are handled directly (Legal
// or Custom) for bfloat vectors. Without the feature they are promoted or
// expanded by a separate block after this loop.
if (Subtarget->hasSVEB16B16()) {
setOperationAction(ISD::FADD, VT, Legal);
setOperationAction(ISD::FMA, VT, Custom);
setOperationAction(ISD::FMAXIMUM, VT, Custom);
setOperationAction(ISD::FMAXNUM, VT, Custom);
setOperationAction(ISD::FMINIMUM, VT, Custom);
setOperationAction(ISD::FMINNUM, VT, Custom);
setOperationAction(ISD::FMUL, VT, Legal);
setOperationAction(ISD::FSUB, VT, Legal);
}
}

// These operations are configured the same way regardless of SVE-B16B16:
// nxv2bf16 and nxv4bf16 are promoted to the f32 vector type with the same
// element count, while nxv8bf16 is marked Expand (presumably so that vector
// legalization can split it into two nxv4bf16 halves that are then promoted
// -- TODO confirm).
for (auto Opcode :
{ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC}) {
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
setOperationAction(Opcode, MVT::nxv8bf16, Expand);
}

// Without SVE-B16B16 the arithmetic operations that the feature would
// otherwise make Legal/Custom get the promote/expand treatment instead:
// nxv2bf16 and nxv4bf16 are promoted to the matching f32 vector type and
// nxv8bf16 is marked Expand.
if (!Subtarget->hasSVEB16B16()) {
for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
setOperationAction(Opcode, MVT::nxv8bf16, Expand);
}
}

setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,15 @@ let Predicates = [HasSVEorSME] in {
defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", AArch64fabs_mt>;
defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;

// fabs/fneg for the scalable bfloat vector types are selected to bitwise
// immediate instructions. The 64-bit masks repeat a 16-bit pattern, so bit 15
// of every 16-bit chunk of the register is cleared (0x7fff) or inverted
// (0x8000).
foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
// No dedicated instruction, so just clear the sign bit.
def : Pat<(VT (fabs VT:$op)),
(AND_ZI $op, (i64 (logical_imm64_XFORM(i64 0x7fff7fff7fff7fff))))>;
// No dedicated instruction, so just invert the sign bit.
def : Pat<(VT (fneg VT:$op)),
(EOR_ZI $op, (i64 (logical_imm64_XFORM(i64 0x8000800080008000))))>;
}

// zext(cmpeq(x, splat(0))) -> cnot(x)
def : Pat<(nxv16i8 (zext (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)))),
(CNOT_ZPmZ_B $Op2, $Pg, $Op2)>;
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/SVEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -2299,6 +2299,8 @@ multiclass sve_fp_3op_u_zd_bfloat<bits<3> opc, string asm, SDPatternOperator op>
def NAME : sve_fp_3op_u_zd<0b00, opc, asm, ZPR16>;

def : SVE_2_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME)>;
def : SVE_2_Op_Pat<nxv4bf16, op, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME)>;
def : SVE_2_Op_Pat<nxv2bf16, op, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME)>;
}

multiclass sve_fp_3op_u_zd_ftsmul<bits<3> opc, string asm, SDPatternOperator op> {
Expand Down Expand Up @@ -9078,6 +9080,8 @@ multiclass sve_fp_bin_pred_bfloat<SDPatternOperator op> {
def _UNDEF : PredTwoOpPseudo<NAME, ZPR16, FalseLanesUndef>;

def : SVE_3_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, !cast<Pseudo>(NAME # _UNDEF)>;
def : SVE_3_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, !cast<Pseudo>(NAME # _UNDEF)>;
def : SVE_3_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, !cast<Pseudo>(NAME # _UNDEF)>;
}

// Predicated pseudo floating point three operand instructions.
Expand All @@ -9099,6 +9103,8 @@ multiclass sve_fp_3op_pred_bfloat<SDPatternOperator op> {
def _UNDEF : PredThreeOpPseudo<NAME, ZPR16, FalseLanesUndef>;

def : SVE_4_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME # _UNDEF)>;
def : SVE_4_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME # _UNDEF)>;
def : SVE_4_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME # _UNDEF)>;
}

// Predicated pseudo integer two operand instructions.
Expand Down
Loading

0 comments on commit 02dd6b1

Please sign in to comment.