diff --git a/llvm/include/llvm/IR/MatcherCast.h b/llvm/include/llvm/IR/MatcherCast.h
new file mode 100644
index 000000000000..cd7efc994a9f
--- /dev/null
+++ b/llvm/include/llvm/IR/MatcherCast.h
@@ -0,0 +1,67 @@
+#ifndef LLVM_IR_MATCHERCAST_H
+#define LLVM_IR_MATCHERCAST_H
+
+//===- MatcherCast.h - Match on the LLVM IR --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Parameterized class hierarchy for templatized pattern matching.
+//
+//===----------------------------------------------------------------------===//
+
+namespace llvm {
+namespace PatternMatch {
+
+// Type modification: by default a class matches as itself.
+template <typename MatcherContext, typename DestClass>
+struct MatcherCast {
+  using ActualCastType = DestClass;
+};
+
+// Whether the Value \p Obj behaves like a \p Class in the given context.
+template <typename MatcherContext, typename Class>
+bool match_isa(const Value* Obj) {
+  using UnconstClass = typename std::remove_cv<Class>::type;
+  using DestClass = typename MatcherCast<MatcherContext, UnconstClass>::ActualCastType;
+  return isa<DestClass>(Obj);
+}
+
+template <typename MatcherContext, typename Class>
+auto match_cast(const Value* Obj) {
+  using UnconstClass = typename std::remove_cv<Class>::type;
+  using DestClass = typename MatcherCast<MatcherContext, UnconstClass>::ActualCastType;
+  return cast<DestClass>(Obj);
+}
+template <typename MatcherContext, typename Class>
+auto match_dyn_cast(const Value* Obj) {
+  using UnconstClass = typename std::remove_cv<Class>::type;
+  using DestClass = typename MatcherCast<MatcherContext, UnconstClass>::ActualCastType;
+  return dyn_cast<DestClass>(Obj);
+}
+
+template <typename MatcherContext, typename Class>
+auto match_cast(Value* Obj) {
+  using UnconstClass = typename std::remove_cv<Class>::type;
+  using DestClass = typename MatcherCast<MatcherContext, UnconstClass>::ActualCastType;
+  return cast<DestClass>(Obj);
+}
+template <typename MatcherContext, typename Class>
+auto match_dyn_cast(Value* Obj) {
+  using UnconstClass = typename std::remove_cv<Class>::type;
+  using DestClass = typename MatcherCast<MatcherContext, UnconstClass>::ActualCastType;
+  return dyn_cast<DestClass>(Obj);
+}
+
+} // namespace PatternMatch
+
+} // namespace llvm
+
+#endif // LLVM_IR_MATCHERCAST_H
+
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index f9f4f1603861..2e8ec174a7dd 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -41,13 +41,68 @@
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/IR/MatcherCast.h"
+
 #include <cstdint>
 
 namespace llvm {
 namespace PatternMatch {
 
+// Use verbatim types in the default (empty) context.
+struct EmptyContext {
+  static constexpr bool IsEmpty = true;
+
+  EmptyContext() {}
+
+  EmptyContext(const Value *) {}
+
+  EmptyContext(const EmptyContext &E) {}
+
+  // Reset this match context to be rooted at \p V.
+  void reset(Value *V) {}
+
+  // Accept a match where \p Val is in a non-leaf position in a match pattern.
+  bool acceptInnerNode(const Value *Val) const { return true; }
+
+  // Accept a match where \p Val is bound to a free variable.
+  bool acceptBoundNode(const Value *Val) const { return true; }
+
+  // Whether this context is compatible with \p E.
+  bool acceptContext(EmptyContext E) const { return true; }
+
+  // Merge the context \p E into this context and return whether the resulting
+  // context is valid.
+  bool mergeContext(EmptyContext E) { return true; }
+
+  // Reset this context to be rooted at \p V, then match the pattern \p P.
+ template bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + // match in the current context + template bool try_match(Val *V, const Pattern &P) { + return const_cast(P).match_context(V, *this); + } +}; + +template +struct MatcherCast { using ActualCastType = DestClass; }; + + + + + + +// match without (== empty) context template bool match(Val *V, const Pattern &P) { - return const_cast(P).match(V); + EmptyContext ECtx; + return const_cast(P).match_context(V, ECtx); +} + +// match pattern in a given context +template bool match(Val *V, const Pattern &P, MatchContext & MContext) { + return const_cast(P).match_context(V, MContext); } template bool match(ArrayRef Mask, const Pattern &P) { @@ -60,7 +115,11 @@ template struct OneUse_match { OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} template bool match(OpTy *V) { - return V->hasOneUse() && SubPattern.match(V); + EmptyContext EContext; return match_context(V, EContext); + } + + template bool match_context(OpTy *V, MatchContext & MContext) { + return V->hasOneUse() && SubPattern.match_context(V, MContext); } }; @@ -69,7 +128,11 @@ template inline OneUse_match m_OneUse(const T &SubPattern) { } template struct class_match { - template bool match(ITy *V) { return isa(V); } + template bool match(ITy *V) { + EmptyContext EContext; return match_context(V, EContext); + } + template + bool match_context(ITy *V, MatchContext & MContext) { return match_isa(V); } }; /// Match an arbitrary value and ignore it. @@ -128,6 +191,10 @@ struct undef_match { return true; } template bool match(ITy *V) { return check(V); } + template + bool match_context(ITy *V, MatcherContext &MC) { + return check(V); + } }; /// Match an arbitrary undef constant. This matches poison as well. @@ -167,7 +234,14 @@ template struct match_unless { match_unless(const Ty &Matcher) : M(Matcher) {} - template bool match(ITy *V) { return !M.match(V); } + template + bool match_context(ITy *V, MatcherContext &MC) { + return !M.match_context(V, MC); + } + template bool match(ITy *V) { + EmptyContext EC; + return match_context(V, EC); + } }; /// Match if the inner matcher does *NOT* match. 
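Every matcher in this header now follows the same two-level scheme: the pre-existing match(V) entry point is kept and simply forwards to match_context(V, Ctx) with a freshly constructed EmptyContext, while context-aware callers thread their own context through the new three-argument match() overload. A minimal usage sketch (the value VPVal and the surrounding control flow are illustrative only, not part of this patch):

  // Context-free matching is unchanged:
  Value *X;
  if (match(V, m_FNeg(m_Value(X)))) {
    // ...
  }

  // Threading an explicit context through a whole pattern, e.g. a
  // PredicatedContext (defined in PredicatedInst.h later in this patch)
  // rooted at the VP intrinsic call VPVal:
  PredicatedContext PC(VPVal);
  if (match(VPVal, m_FSub(m_AnyZeroFP(), m_Value(X)), PC)) {
    // ...
  }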
@@ -182,11 +256,17 @@ template struct match_combine_or { match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + MatchContext SubContext; + + if (L.match_context(V, SubContext) && MContext.acceptContext(SubContext)) { + MContext.mergeContext(SubContext); return true; - if (R.match(V)) + } + if (R.match_context(V, MContext)) { return true; + } return false; } }; @@ -197,9 +277,10 @@ template struct match_combine_and { match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) - if (R.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + if (L.match_context(V, MContext)) + if (R.match_context(V, MContext)) return true; return false; } @@ -224,7 +305,8 @@ struct apint_match { apint_match(const APInt *&Res, bool AllowUndef) : Res(Res), AllowUndef(AllowUndef) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValue(); return true; @@ -249,7 +331,8 @@ struct apfloat_match { apfloat_match(const APFloat *&Res, bool AllowUndef) : Res(Res), AllowUndef(AllowUndef) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValueAPF(); return true; @@ -300,7 +383,8 @@ inline apfloat_match m_APFloatForbidUndef(const APFloat *&Res) { } template struct constantint_match { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) { const APInt &CIV = CI->getValue(); if (Val >= 0) @@ -319,14 +403,15 @@ template inline constantint_match m_ConstantInt() { return constantint_match(); } -/// This helper class is used to match constant scalars, vector splats, -/// and fixed width vectors that satisfy a specified predicate. -/// For fixed width vector constants, undefined elements are ignored. +/// This helper class is used to match scalar and fixed width vector integer +/// constants that satisfy a specified predicate. +/// For vector constants, undefined elements are ignored. 
template struct cstval_pred_ty : public Predicate { - template bool match(ITy *V) { - if (const auto *CV = dyn_cast(V)) - return this->isValue(CV->getValue()); + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + if (const auto *CI = dyn_cast(V)) + return this->isValue(CI->getValue()); if (const auto *VTy = dyn_cast(V->getType())) { if (const auto *C = dyn_cast(V)) { if (const auto *CV = dyn_cast_or_null(C->getSplatValue())) @@ -374,7 +459,8 @@ template struct api_pred_ty : public Predicate { api_pred_ty(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -401,6 +487,12 @@ template struct apf_pred_ty : public Predicate { apf_pred_ty(const APFloat *&R) : Res(R) {} template bool match(ITy *V) { + EmptyContext Empty; + return match_context(V, Empty); + } + + template + bool match_context(ITy *V, MatchContext &MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -524,7 +616,8 @@ inline cst_pred_ty m_ZeroInt() { } struct is_zero { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { auto *C = dyn_cast(V); // FIXME: this should be able to do something for scalable vectors return C && (C->isNullValue() || cst_pred_ty().match(C)); @@ -709,8 +802,11 @@ template struct bind_ty { bind_ty(Class *&V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CV = dyn_cast(V)) { + if (!MContext.acceptBoundNode(V)) return false; + VR = CV; return true; } @@ -773,7 +869,8 @@ struct specificval_ty { specificval_ty(const Value *V) : Val(V) {} - template bool match(ITy *V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { return V == Val; } }; /// Match if we have a specific specified value. 
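The hooks exercised in the hunks above (acceptInnerNode for interior instructions, acceptBoundNode for values captured by m_Value and friends, and acceptContext/mergeContext for folding speculative sub-matches back in) are the entire surface a match context has to provide. As an illustration only, and not part of this patch, a context that confines a whole pattern to a single basic block could look like this:

  // Hypothetical example: reject patterns whose interior nodes span blocks.
  struct SameBlockContext {
    static constexpr bool IsEmpty = false;
    const BasicBlock *BB = nullptr;

    SameBlockContext(Value *Root) { reset(Root); }

    void reset(Value *Root) {
      auto *I = dyn_cast<Instruction>(Root);
      BB = I ? I->getParent() : nullptr;
    }

    // Interior nodes must live in the root's block.
    bool acceptInnerNode(const Value *Val) const {
      auto *I = dyn_cast<Instruction>(Val);
      return !I || I->getParent() == BB;
    }
    // Leaves may be bound from anywhere.
    bool acceptBoundNode(const Value *) const { return true; }

    bool acceptContext(SameBlockContext C) const { return C.BB == BB; }
    bool mergeContext(SameBlockContext C) { return acceptContext(C); }

    template <typename Val, typename Pattern>
    bool try_match(Val *V, const Pattern &P) {
      return const_cast<Pattern &>(P).match_context(V, *this);
    }
  };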
@@ -786,7 +883,8 @@ template struct deferredval_ty { deferredval_ty(Class *const &V) : Val(V) {} - template bool match(ITy *const V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *const V, MatchContext & MContext) { return V == Val; } }; /// Like m_Specific(), but works if the specific value to match is determined @@ -807,7 +905,8 @@ struct specific_fpval { specific_fpval(double V) : Val(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CFP = dyn_cast(V)) return CFP->isExactlyValue(Val); if (V->getType()->isVectorTy()) @@ -830,7 +929,8 @@ struct bind_const_intval_ty { bind_const_intval_ty(uint64_t &V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CV = dyn_cast(V)) if (CV->getValue().ule(UINT64_MAX)) { VR = CV->getZExtValue(); @@ -848,7 +948,8 @@ struct specific_intval { specific_intval(APInt V) : Val(std::move(V)) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { const auto *CI = dyn_cast(V); if (!CI && V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) @@ -886,7 +987,8 @@ struct specific_bbval { specific_bbval(BasicBlock *Val) : Val(Val) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EC; return match_context(V, EC); } + template bool match_context(ITy *V, MatchContext & MContext) { const auto *BB = dyn_cast(V); return BB && BB == Val; } @@ -918,11 +1020,16 @@ struct AnyBinaryOp_match { // The LHS is always matched first. AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (!I) return false; + + if (!MContext.acceptInnerNode(I)) return false; + + MatchContext LRContext(MContext); + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; return false; } }; @@ -941,9 +1048,15 @@ template struct AnyUnaryOp_match { AnyUnaryOp_match(const OP_t &X) : X(X) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return X.match(I->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (!I) return false; + + if (!MContext.acceptInnerNode(I)) return false; + + MatchContext XContext(MContext); + if (X.match_context(I->getOperand(0), XContext) && MContext.mergeContext(XContext)) return true; return false; } }; @@ -966,12 +1079,29 @@ struct BinaryOp_match { // The LHS is always matched first. 
BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template inline bool match(unsigned Opc, OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opc) { - auto *I = cast(V); - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(unsigned Opc, OpTy *V) { + EmptyContext EContext; + return match_context(Opc, V, EContext); + } + template + bool match_context(OpTy *V, MatchContext &MContext) { + return match_context<>(Opcode, V, MContext); + } + template + bool match_context(unsigned Opc, OpTy *V, MatchContext &MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opc) { + MatchContext LRContext(MContext); + if (!MContext.acceptInnerNode(I)) + return false; + if (L.match_context(I->getOperand(0), LRContext) && + R.match_context(I->getOperand(1), LRContext) && + MContext.mergeContext(LRContext)) + return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && + R.match_context(I->getOperand(0), MContext))) + return true; + return false; } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opc && @@ -981,7 +1111,10 @@ struct BinaryOp_match { return false; } - template bool match(OpTy *V) { return match(Opcode, V); } + template bool match(OpTy *V) { + EmptyContext EC; + return match_context<>(Opcode, V, EC); + } }; template @@ -1012,25 +1145,26 @@ template struct FNeg_match { Op_t X; FNeg_match(const Op_t &Op) : X(Op) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { auto *FPMO = dyn_cast(V); if (!FPMO) return false; - if (FPMO->getOpcode() == Instruction::FNeg) + if (match_cast(V)->getOpcode() == Instruction::FNeg) return X.match(FPMO->getOperand(0)); - if (FPMO->getOpcode() == Instruction::FSub) { + if (match_cast(V)->getOpcode() == Instruction::FSub) { if (FPMO->hasNoSignedZeros()) { // With 'nsz', any zero goes. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } else { // Without 'nsz', we need fsub -0.0, X exactly. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } - return X.match(FPMO->getOperand(1)); + return X.match_context(FPMO->getOperand(1), MContext); } return false; @@ -1144,7 +1278,8 @@ struct OverflowingBinaryOp_match { OverflowingBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *Op = dyn_cast(V)) { if (Op->getOpcode() != Opcode) return false; @@ -1154,7 +1289,7 @@ struct OverflowingBinaryOp_match { if ((WrapFlags & OverflowingBinaryOperator::NoSignedWrap) && !Op->hasNoSignedWrap()) return false; - return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); + return L.match_context(Op->getOperand(0), MContext) && R.match_context(Op->getOperand(1), MContext); } return false; } @@ -1237,6 +1372,12 @@ struct SpecificBinaryOp_match template bool match(OpTy *V) { return BinaryOp_match::match(Opcode, V); } + + template + bool match_context(OpTy *V, MatchContext &MContext) { + return BinaryOp_match::match_context(Opcode, V, + MContext); + } }; /// Matches a specific opcode. 
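All commutable matchers touched here (AnyBinaryOp_match and BinaryOp_match above, and CmpClass_match and MaxMin_match further below) share one protocol: the first operand order is tried against a copy of the incoming context, so a failed attempt cannot alter the caller's context, and the copy is only folded back via mergeContext once both operands matched; only then is the swapped operand order tried. Schematically:

  // Sketch of the common commutable-matching sequence used above.
  MatchContext Tmp(MContext);                    // speculative copy
  if (L.match_context(Op0, Tmp) && R.match_context(Op1, Tmp) &&
      MContext.mergeContext(Tmp))                // commit only on success
    return true;
  if (Commutable && L.match_context(Op1, MContext) &&
      R.match_context(Op0, MContext))            // swapped operand order
    return true;
  return false;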
@@ -1256,10 +1397,11 @@ struct BinOpPred_match : Predicate { BinOpPred_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return this->isOpType(I->getOpcode()) && L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) + return this->isOpType(I->getOpcode()) && L.match_context(I->getOperand(0), MContext) && + R.match_context(I->getOperand(1), MContext); if (auto *CE = dyn_cast(V)) return this->isOpType(CE->getOpcode()) && L.match(CE->getOperand(0)) && R.match(CE->getOperand(1)); @@ -1351,9 +1493,10 @@ template struct Exact_match { Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *PEO = dyn_cast(V)) - return PEO->isExact() && SubPattern.match(V); + return PEO->isExact() && SubPattern.match_context(V, MContext); return false; } }; @@ -1378,13 +1521,27 @@ struct CmpClass_match { CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) { + if (!MContext.acceptInnerNode(I)) + return false; + + MatchContext LRContext(MContext); + if (L.match_context(I->getOperand(0), LRContext) && + R.match_context(I->getOperand(1), LRContext) && + MContext.mergeContext(LRContext)) { Predicate = I->getPredicate(); return true; - } else if (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))) { + } + + if (!Commutable) + return false; + + MatchContext RLContext(MContext); + if (L.match_context(I->getOperand(1), RLContext) && + R.match_context(I->getOperand(0), RLContext) && + MContext.mergeContext(RLContext)) { Predicate = I->getSwappedPredicate(); return true; } @@ -1421,10 +1578,11 @@ template struct OneOps_match { OneOps_match(const T0 &Op1) : Op1(Op1) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext); } return false; } @@ -1437,10 +1595,12 @@ template struct TwoOps_match { TwoOps_match(const T0 &Op1, const T1 &Op2) : Op1(Op1), Op2(Op2) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + 
Op2.match_context(I->getOperand(1), MContext); } return false; } @@ -1456,11 +1616,13 @@ struct ThreeOps_match { ThreeOps_match(const T0 &Op1, const T1 &Op2, const T2 &Op3) : Op1(Op1), Op2(Op2), Op3(Op3) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Op3.match(I->getOperand(2)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext) && + Op3.match_context(I->getOperand(2), MContext); } return false; } @@ -1512,9 +1674,12 @@ template struct Shuffle_match { Shuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask) : Op1(Op1), Op2(Op2), Mask(Mask) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EC; return match_context(V, EC); } + template + bool match_context(OpTy *V, MatchContext &MC) { if (auto *I = dyn_cast(V)) { - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && + return Op1.match_context(I->getOperand(0), MC) && + Op2.match_context(I->getOperand(1), MC) && Mask.match(I->getShuffleMask()); } return false; @@ -1591,9 +1756,10 @@ template struct CastClass_match { CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} - template bool match(OpTy *V) { - if (auto *O = dyn_cast(V)) - return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto O = match_dyn_cast(V)) + return O->getOpcode() == Opcode && MContext.acceptInnerNode(O) && Op.match_context(O->getOperand(0), MContext); return false; } }; @@ -1707,8 +1873,9 @@ struct br_match { br_match(BasicBlock *&Succ) : Succ(Succ) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) if (BI->isUnconditional()) { Succ = BI->getSuccessor(0); return true; @@ -1728,10 +1895,12 @@ struct brc_match { brc_match(const Cond_t &C, const TrueBlock_t &t, const FalseBlock_t &f) : Cond(C), T(t), F(f) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) - if (BI->isConditional() && Cond.match(BI->getCondition())) - return T.match(BI->getSuccessor(0)) && F.match(BI->getSuccessor(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) + if (BI->isConditional() && Cond.match(BI->getCondition())) { + return T.match_context(BI->getSuccessor(0), MContext) && F.match_context(BI->getSuccessor(1), MContext); + } return false; } }; @@ -1764,7 +1933,8 @@ struct MaxMin_match { // The LHS is always matched first. 
MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *II = dyn_cast(V)) { Intrinsic::ID IID = II->getIntrinsicID(); if ((IID == Intrinsic::smax && Pred_t::match(ICmpInst::ICMP_SGT)) || @@ -1772,16 +1942,18 @@ struct MaxMin_match { (IID == Intrinsic::umax && Pred_t::match(ICmpInst::ICMP_UGT)) || (IID == Intrinsic::umin && Pred_t::match(ICmpInst::ICMP_ULT))) { Value *LHS = II->getOperand(0), *RHS = II->getOperand(1); + if (!MContext.acceptInnerNode(LHS) || !MContext.acceptInnerNode(RHS)) + return false; return (L.match(LHS) && R.match(RHS)) || (Commutable && L.match(RHS) && R.match(LHS)); } } // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". - auto *SI = dyn_cast(V); - if (!SI) + auto *SI = match_dyn_cast(V); + if (!SI || !MContext.acceptInnerNode(SI)) return false; - auto *Cmp = dyn_cast(SI->getCondition()); - if (!Cmp) + auto *Cmp = match_dyn_cast(SI->getCondition()); + if (!Cmp || !MContext.acceptInnerNode(Cmp)) return false; // At this point we have a select conditioned on a comparison. Check that // it is the values returned by the select that are being compared. @@ -1797,9 +1969,12 @@ struct MaxMin_match { // Does "(x pred y) ? x : y" represent the desired max/min operation? if (!Pred_t::match(Pred)) return false; + // It does! Bind the operands. - return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + MatchContext LRContext(MContext); + if (L.match_context(LHS, LRContext) && R.match_context(RHS, LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(RHS, MContext) && R.match_context(LHS, MContext))) return true; + return false; } }; @@ -1968,7 +2143,8 @@ struct UAddWithOverflow_match { UAddWithOverflow_match(const LHS_t &L, const RHS_t &R, const Sum_t &S) : L(L), R(R), S(S) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { Value *ICmpLHS, *ICmpRHS; ICmpInst::Predicate Pred; if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) @@ -2034,9 +2210,10 @@ template struct Argument_match { Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // FIXME: Should likely be switched to use `CallBase`. 
- if (const auto *CI = dyn_cast(V)) + if (const auto *CI = match_dyn_cast(V)) return Val.match(CI->getArgOperand(OpI)); return false; } @@ -2054,8 +2231,9 @@ struct IntrinsicID_match { IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} - template bool match(OpTy *V) { - if (const auto *CI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (const auto *CI = match_dyn_cast(V)) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; @@ -2292,16 +2470,17 @@ template struct NotForbidUndef_match { ValTy Val; NotForbidUndef_match(const ValTy &V) : Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // We do not use m_c_Xor because that could match an arbitrary APInt that is // not -1 as C and then fail to match the other operand if it is -1. // This code should still work even when both operands are constants. Value *X; const APInt *C; - if (m_Xor(m_Value(X), m_APIntForbidUndef(C)).match(V) && C->isAllOnes()) - return Val.match(X); - if (m_Xor(m_APIntForbidUndef(C), m_Value(X)).match(V) && C->isAllOnes()) - return Val.match(X); + if (m_Xor(m_Value(X), m_APIntForbidUndef(C)).match_context(V, MContext) && C->isAllOnes()) + return Val.match_context(X, MContext); + if (m_Xor(m_APIntForbidUndef(C), m_Value(X)).match_context(V, MContext) && C->isAllOnes()) + return Val.match_context(X, MContext); return false; } }; @@ -2367,7 +2546,8 @@ template struct Signum_match { Opnd_t Val; Signum_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { unsigned TypeSize = V->getType()->getScalarSizeInBits(); if (TypeSize == 0) return false; @@ -2407,13 +2587,14 @@ template struct ExtractValue_match { Opnd_t Val; ExtractValue_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EC; return match_context(V, EC); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *I = dyn_cast(V)) { // If Ind is -1, don't inspect indices if (Ind != -1 && !(I->getNumIndices() == 1 && I->getIndices()[0] == (unsigned)Ind)) return false; - return Val.match(I->getAggregateOperand()); + return Val.match_context(I->getAggregateOperand(), MContext); } return false; } @@ -2440,9 +2621,12 @@ template struct InsertValue_match { InsertValue_match(const T0 &Op0, const T1 &Op1) : Op0(Op0), Op1(Op1) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EC; return match_context(V, EC); } + template + bool match_context(OpTy *V, MatchContext &MContext) { if (auto *I = dyn_cast(V)) { - return Op0.match(I->getOperand(0)) && Op1.match(I->getOperand(1)) && + return Op0.match_context(I->getOperand(0), MContext) && + Op1.match_context(I->getOperand(1), MContext) && I->getNumIndices() == 1 && Ind == I->getIndices()[0]; } return false; @@ -2464,7 +2648,8 @@ struct VScaleVal_match { const DataLayout &DL; VScaleVal_match(const DataLayout &DL) : DL(DL) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EC; return match_context(V, EC); } + template bool match_context(ITy *V, MatchContext & MContext) { if (m_Intrinsic().match(V)) return 
true; @@ -2496,7 +2681,10 @@ struct LogicalOp_match { LogicalOp_match(const LHS &L, const RHS &R) : L(L), R(R) {} template bool match(T *V) { - auto *I = dyn_cast(V); + EmptyContext EC; return match_context(V, EC); + } + template bool match_context(ITy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); if (!I || !I->getType()->isIntOrIntVectorTy(1)) return false; @@ -2507,21 +2695,25 @@ struct LogicalOp_match { (Commutable && L.match(Op1) && R.match(Op0)); } - if (auto *Select = dyn_cast(I)) { + if (auto *Select = match_dyn_cast(I)) { auto *Cond = Select->getCondition(); auto *TVal = Select->getTrueValue(); auto *FVal = Select->getFalseValue(); if (Opcode == Instruction::And) { auto *C = dyn_cast(FVal); if (C && C->isNullValue()) - return (L.match(Cond) && R.match(TVal)) || - (Commutable && L.match(TVal) && R.match(Cond)); + return (L.match_context(Cond, MContext) && + R.match_context(TVal, MContext)) || + (Commutable && L.match_context(TVal, MContext) && + R.match_context(Cond, MContext)); } else { assert(Opcode == Instruction::Or); auto *C = dyn_cast(TVal); if (C && C->isOneValue()) - return (L.match(Cond) && R.match(FVal)) || - (Commutable && L.match(FVal) && R.match(Cond)); + return (L.match_context(Cond, MContext) && + R.match_context(FVal, MContext)) || + (Commutable && L.match_context(FVal, MContext) && + R.match_context(Cond, MContext)); } } diff --git a/llvm/include/llvm/IR/PredicatedInst.h b/llvm/include/llvm/IR/PredicatedInst.h new file mode 100644 index 000000000000..9387272fcc38 --- /dev/null +++ b/llvm/include/llvm/IR/PredicatedInst.h @@ -0,0 +1,513 @@ +//===-- llvm/PredicatedInst.h - Predication utility subclass --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines various classes for working with predicated instructions. +// Predicated instructions are either regular instructions or calls to +// Vector Predication (VP) intrinsics that have a mask and an explicit +// vector length argument. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_PREDICATEDINST_H +#define LLVM_IR_PREDICATEDINST_H + +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MatcherCast.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" + +#include + +namespace llvm { + +class BasicBlock; + +class PredicatedInstruction : public User { +public: + // The PredicatedInstruction class is intended to be used as a utility, and is + // never itself instantiated. 
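+  // It overlays either a regular Instruction or a call to a VP intrinsic and
+  // presents a uniform view of both: getMaskParam()/getVectorLengthParam()
+  // return null for unpredicated instructions, and getOpcode() reports the
+  // functional opcode (e.g. Instruction::FSub for llvm.vp.fsub) instead of
+  // Instruction::Call.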
+ PredicatedInstruction() = delete; + ~PredicatedInstruction() = delete; + + void copyIRFlags(const Value *V, bool IncludeWrapFlags) { + cast(this)->copyIRFlags(V, IncludeWrapFlags); + } + + BasicBlock *getParent() { return cast(this)->getParent(); } + const BasicBlock *getParent() const { + return cast(this)->getParent(); + } + + void *operator new(size_t s) = delete; + + Value *getMaskParam() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getMaskParam(); + } + + Value *getVectorLengthParam() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getVectorLengthParam(); + } + + /// \returns True if the passed vector length value has no predicating effect + /// on the op. + bool canIgnoreVectorLengthParam() const; + + /// \return True if the static operator of this instruction has a mask or + /// vector length parameter. + bool isVectorPredicatedOp() const { return isa(this); } + + /// \returns the effective Opcode of this operation (ignoring the mask and + /// vector length param). + unsigned getOpcode() const { + auto *VPInst = dyn_cast(this); + + if (!VPInst) + return cast(this)->getOpcode(); + + auto OC = VPInst->getFunctionalOpcode(); + + return OC ? *OC : (unsigned) Instruction::Call; + } + + bool isVectorReduction() const; + + static bool classof(const Instruction *I) { return isa(I); } + static bool classof(const ConstantExpr *CE) { return false; } + static bool classof(const Value *V) { return isa(V); } + + /// Convenience function for getting all the fast-math flags, which must be an + /// operator which supports these flags. See LangRef.html for the meaning of + /// these flags. + FastMathFlags getFastMathFlags() const; +}; + +class PredicatedOperator : public User { +public: + // The PredicatedOperator class is intended to be used as a utility, and is + // never itself instantiated. + PredicatedOperator() = delete; + ~PredicatedOperator() = delete; + + void *operator new(size_t s) = delete; + + /// Return the opcode for this Instruction or ConstantExpr. + unsigned getOpcode() const { + auto *VPInst = dyn_cast(this); + + // Conceal the fp operation if it has non-default rounding mode or exception + // behavior + if (VPInst && !VPInst->isConstrainedOp()) { + auto OC = VPInst->getFunctionalOpcode(); + return OC ? *OC : (unsigned) Instruction::Call; + } + + if (const Instruction *I = dyn_cast(this)) + return I->getOpcode(); + + return cast(this)->getOpcode(); + } + + Value *getMask() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getMaskParam(); + } + + Value *getVectorLength() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getVectorLengthParam(); + } + + void copyIRFlags(const Value *V, bool IncludeWrapFlags = true); + FastMathFlags getFastMathFlags() const { + auto *I = dyn_cast(this); + if (I) + return I->getFastMathFlags(); + else + return FastMathFlags(); + } + + static bool classof(const Instruction *I) { + return isa(I) || isa(I); + } + static bool classof(const ConstantExpr *CE) { return isa(CE); } + static bool classof(const Value *V) { + return isa(V) || isa(V); + } +}; + +class PredicatedUnaryOperator : public PredicatedOperator { +public: + // The PredicatedUnaryOperator class is intended to be used as a utility, and + // is never itself instantiated. 
+ PredicatedUnaryOperator() = delete; + ~PredicatedUnaryOperator() = delete; + + using UnaryOps = Instruction::UnaryOps; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->isUnaryOp(); + } + static bool classof(const ConstantExpr *CE) { + return isa(CE); + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + auto *CE = dyn_cast(V); + return CE && classof(CE); + } + + /// Construct a predicated binary instruction, given the opcode and the two + /// operands. + static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + Instruction::UnaryOps Opc, Value *V, + const Twine &Name, BasicBlock *InsertAtEnd, + Instruction *InsertBefore); + + static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + UnaryOps Opc, Value *V, + const Twine &Name = Twine(), + Instruction *InsertBefore = nullptr) { + return Create(Mod, Mask, VectorLen, Opc, V, Name, nullptr, + InsertBefore); + } + + static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + UnaryOps Opc, Value *V, + const Twine &Name, BasicBlock *InsertAtEnd) { + return Create(Mod, Mask, VectorLen, Opc, V, Name, InsertAtEnd, + nullptr); + } + + static Instruction *CreateWithCopiedFlags(Module *Mod, Value *Mask, + Value *VectorLen, UnaryOps Opc, + Value *V, + Instruction *CopyBO, + const Twine &Name = "") { + Instruction *BO = + Create(Mod, Mask, VectorLen, Opc, V, Name, nullptr, nullptr); + BO->copyIRFlags(CopyBO); + return BO; + } +}; + +class PredicatedBinaryOperator : public PredicatedOperator { +public: + // The PredicatedBinaryOperator class is intended to be used as a utility, and + // is never itself instantiated. + PredicatedBinaryOperator() = delete; + ~PredicatedBinaryOperator() = delete; + + using BinaryOps = Instruction::BinaryOps; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->isBinaryOp(); + } + static bool classof(const ConstantExpr *CE) { + return isa(CE); + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + auto *CE = dyn_cast(V); + return CE && classof(CE); + } + + /// Construct a predicated binary instruction, given the opcode and the two + /// operands. 
+ static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + Instruction::BinaryOps Opc, Value *V1, Value *V2, + const Twine &Name, BasicBlock *InsertAtEnd, + Instruction *InsertBefore); + + static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + BinaryOps Opc, Value *V1, Value *V2, + const Twine &Name = Twine(), + Instruction *InsertBefore = nullptr) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, + InsertBefore); + } + + static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + BinaryOps Opc, Value *V1, Value *V2, + const Twine &Name, BasicBlock *InsertAtEnd) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, InsertAtEnd, + nullptr); + } + + static Instruction *CreateWithCopiedFlags(Module *Mod, Value *Mask, + Value *VectorLen, BinaryOps Opc, + Value *V1, Value *V2, + Instruction *CopyBO, + const Twine &Name = "") { + Instruction *BO = + Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, nullptr); + BO->copyIRFlags(CopyBO); + return BO; + } +}; + +class PredicatedICmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedICmpInst() = delete; + ~PredicatedICmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + if (!VPInst) + return false; + auto OC = VPInst->getFunctionalOpcode(); + return OC && (*OC == Instruction::ICmp); + } + static bool classof(const ConstantExpr *CE) { + return CE->getOpcode() == Instruction::ICmp; + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + auto *CE = dyn_cast(V); + return CE && classof(CE); + } + + ICmpInst::Predicate getPredicate() const { + auto *ICInst = dyn_cast(this); + if (ICInst) + return ICInst->getPredicate(); + auto *CE = dyn_cast(this); + if (CE) + return static_cast(CE->getPredicate()); + return static_cast( + cast(this)->getCmpPredicate()); + } +}; + +class PredicatedFCmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedFCmpInst() = delete; + ~PredicatedFCmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + if (!VPInst) + return false; + auto OC = VPInst->getFunctionalOpcode(); + return OC && (*OC == Instruction::FCmp); + } + static bool classof(const ConstantExpr *CE) { + return CE->getOpcode() == Instruction::FCmp; + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + return isa(V); + } + + FCmpInst::Predicate getPredicate() const { + auto *FCInst = dyn_cast(this); + if (FCInst) + return FCInst->getPredicate(); + auto *CE = dyn_cast(this); + if (CE) + return static_cast(CE->getPredicate()); + return static_cast( + cast(this)->getCmpPredicate()); + } +}; + +class PredicatedSelectInst : public PredicatedOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. 
+ PredicatedSelectInst() = delete; + ~PredicatedSelectInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + if (!VPInst) + return false; + auto OC = VPInst->getFunctionalOpcode(); + return OC && (*OC == Instruction::Select); + } + static bool classof(const ConstantExpr *CE) { + return CE->getOpcode() == Instruction::Select; + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + auto *CE = dyn_cast(V); + return CE && CE->getOpcode() == Instruction::Select; + } + + const Value *getCondition() const { return getOperand(0); } + const Value *getTrueValue() const { return getOperand(1); } + const Value *getFalseValue() const { return getOperand(2); } + Value *getCondition() { return getOperand(0); } + Value *getTrueValue() { return getOperand(1); } + Value *getFalseValue() { return getOperand(2); } + + void setCondition(Value *V) { setOperand(0, V); } + void setTrueValue(Value *V) { setOperand(1, V); } + void setFalseValue(Value *V) { setOperand(2, V); } +}; + +namespace PatternMatch { + +// PredicatedMatchContext for pattern matching +struct PredicatedContext { + static constexpr bool IsEmpty = false; + + Value *Mask; + Value *VectorLength; + Module *Mod; + + void reset(Value *V) { + auto *PI = dyn_cast(V); + if (!PI) { + VectorLength = nullptr; + Mask = nullptr; + return; + } + VectorLength = PI->getVectorLengthParam(); + Mask = PI->getMaskParam(); + + if (Mod) return; + + // try to get a hold of the Module + auto *BB = PI->getParent(); + if (BB) { + auto *Func = BB->getParent(); + if (Func) { + Mod = Func->getParent(); + } + } + + if (Mod) return; + + // try to infer the module from a call + auto CallI = dyn_cast(V); + if (CallI && CallI->getCalledFunction()) { + Mod = CallI->getCalledFunction()->getParent(); + } + } + + PredicatedContext(Value *Val) + : Mask(nullptr), VectorLength(nullptr), Mod(nullptr) { + reset(Val); + } + + PredicatedContext(const PredicatedContext &PC) + : Mask(PC.Mask), VectorLength(PC.VectorLength), Mod(PC.Mod) {} + + /// accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value *Val) const { + auto PredI = dyn_cast(Val); + if (!PredI) + return VectorLength == nullptr && Mask == nullptr; + return VectorLength == PredI->getVectorLengthParam() && + Mask == PredI->getMaskParam(); + } + + /// accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value *Val) const { return true; } + + /// whether this context is compatiable with \p E. + bool acceptContext(PredicatedContext PC) const { + return std::tie(PC.Mask, PC.VectorLength) == std::tie(Mask, VectorLength); + } + + /// merge the context \p E into this context and return whether the resulting + /// context is valid. + bool mergeContext(PredicatedContext PC) const { return acceptContext(PC); } + + /// match \p P in a new contesx for \p Val. + template + bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + /// match \p P in the current context. 
+ template + bool try_match(Val *V, const Pattern &P) { + PredicatedContext SubContext(*this); + return const_cast(P).match_context(V, SubContext); + } +}; + +struct PredicatedContext; +template <> struct MatcherCast { + using ActualCastType = PredicatedBinaryOperator; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedOperator; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedICmpInst; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedFCmpInst; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedSelectInst; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedInstruction; +}; + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_PREDICATEDINST_H diff --git a/llvm/include/llvm/IR/VPBuilder.h b/llvm/include/llvm/IR/VPBuilder.h index 6de1317f7934..9a8b5fb9ffbe 100644 --- a/llvm/include/llvm/IR/VPBuilder.h +++ b/llvm/include/llvm/IR/VPBuilder.h @@ -5,6 +5,7 @@ #include #include #include +#include #include namespace llvm { @@ -71,6 +72,124 @@ class VPBuilder { Value& CreateGather(Type *ReturnTy, Value & PointerVec, MaybeAlign Alignment); }; + + + + +namespace PatternMatch { + // Factory class to generate instructions in a context + template + class MatchContextBuilder { + public: + // MatchContextBuilder(MatcherContext MC); + }; + + +// Context-free instruction builder +template<> +class MatchContextBuilder { +public: + MatchContextBuilder(EmptyContext & EC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Instruction *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const { \ + auto * Inst = BinaryOperator::Create(Instruction::OPC, V1, V2, Name); \ + Builder.Insert(Inst); return Inst; \ + } \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, BB);\ + } \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, I);\ + } \ + Instruction *Create##OPC##FMF(Value *V1, Value *V2, Instruction *FMFSource, \ + const Twine &Name = "") const {\ + return BinaryOperator::CreateWithCopiedFlags(Instruction::OPC, V1, V2, FMFSource, Name);\ + } \ + template \ + Instruction *Create##OPC##FMF(IRBuilderType& Builder, Value *V1, Value *V2, Instruction *FMFSource, \ + const Twine &Name = "") const {\ + auto * Inst = BinaryOperator::CreateWithCopiedFlags(Instruction::OPC, V1, V2, FMFSource, Name);\ + Builder.Insert(Inst); return Inst; \ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + UnaryOperator *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + return UnaryOperator::CreateFNegFMF(Op, FMFSource, Name); + } + + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + + + +// Context-free instruction builder +template<> +class MatchContextBuilder { + PredicatedContext & PC; +public: + MatchContextBuilder(PredicatedContext & PC) : PC(PC) {} + + 
#define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Instruction *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + auto * PredInst = Create##OPC(V1, V2, Name); \ + Builder.Insert(PredInst); \ + return PredInst; \ + } \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, BB);\ + } \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, I);\ + } \ + Instruction *Create##OPC##FMF(Value *V1, Value *V2, Instruction *FMFSource, \ + const Twine &Name = "") const {\ + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, FMFSource, Name);\ + } \ + template \ + Instruction *Create##OPC##FMF(IRBuilderType& Builder, Value *V1, Value *V2, Instruction *FMFSource, \ + const Twine &Name = "") const {\ + auto * Inst = PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, FMFSource, Name);\ + Builder.Insert(Inst); return Inst; \ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + Instruction *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedUnaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FNeg, Op, FMFSource, Name); + } + + // TODO predicated casts + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + +} + } // namespace llvm #endif // LLVM_IR_VPBUILDER_H diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 77d600ac5883..85f1e876fa31 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -39,6 +39,8 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" +#include "llvm/IR/ValueHandle.h" #include "llvm/Support/KnownBits.h" #include using namespace llvm; @@ -5144,11 +5146,12 @@ SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// Given operands for an FSub, see if we can fold the result. If not, this /// returns null. 
-static Value * -SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse, +template +static Value *SimplifyFSubInstGeneric(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse, MatchContext & MC, fp::ExceptionBehavior ExBehavior = fp::ebIgnore, RoundingMode Rounding = RoundingMode::NearestTiesToEven) { + if (isDefaultFPEnvironment(ExBehavior, Rounding)) if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) return C; @@ -5160,7 +5163,7 @@ SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (canIgnoreSNaN(ExBehavior, FMF) && (!canRoundingModeBe(Rounding, RoundingMode::TowardNegative) || FMF.noSignedZeros())) - if (match(Op1, m_PosZeroFP())) + if (MC.try_match(Op1, m_PosZeroFP())) return Op0; // fsub X, -0 ==> X, when we know X is not -0 @@ -5172,18 +5175,23 @@ SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (!isDefaultFPEnvironment(ExBehavior, Rounding)) return nullptr; + // fsub X, -0 ==> X, when we know X is not -0 + if (MC.try_match(Op1, m_NegZeroFP()) && + (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) + return Op0; + // fsub -0.0, (fsub -0.0, X) ==> X // fsub -0.0, (fneg X) ==> X Value *X; - if (match(Op0, m_NegZeroFP()) && - match(Op1, m_FNeg(m_Value(X)))) + if (MC.try_match(Op0, m_NegZeroFP()) && + MC.try_match(Op1, m_FNeg(m_Value(X)))) return X; // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. // fsub 0.0, (fneg X) ==> X if signed zeros are ignored. if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) && - (match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) || - match(Op1, m_FNeg(m_Value(X))))) + (MC.try_match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) || + MC.try_match(Op1, m_FNeg(m_Value(X))))) return X; // fsub nnan x, x ==> 0.0 @@ -5193,8 +5201,8 @@ SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, // Y - (Y - X) --> X // (X + Y) - Y --> X if (FMF.noSignedZeros() && FMF.allowReassoc() && - (match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || - match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) + (MC.try_match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || + MC.try_match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) return X; return nullptr; @@ -5260,14 +5268,32 @@ Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, Rounding); } +/// Given operands for an FSub, see if we can fold the result. 
+static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) + return C; + + EmptyContext EC; + return SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, RecursionLimit, EC); +} + Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, fp::ExceptionBehavior ExBehavior, RoundingMode Rounding) { - return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + EmptyContext EC; + return SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, RecursionLimit, EC, ExBehavior, Rounding); } +Value *llvm::SimplifyPredicatedFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, PredicatedContext & PC, + fp::ExceptionBehavior ExBehavior, + RoundingMode Rounding) { + return ::SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, RecursionLimit, PC); +} + Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, fp::ExceptionBehavior ExBehavior, @@ -6228,6 +6254,18 @@ static Value *SimplifyLoadInst(LoadInst *LI, Value *PtrOp, return nullptr; } +Value *llvm::SimplifyVPIntrinsic(VPIntrinsic & VPInst, const SimplifyQuery &Q) { + PredicatedContext PC(&VPInst); + + auto & PI = cast(VPInst); + switch (PI.getOpcode()) { + default: + return nullptr; + + case Instruction::FSub: return SimplifyPredicatedFSubInst(VPInst.getOperand(0), VPInst.getOperand(1), VPInst.getFastMathFlags(), Q, PC); + } +} + /// See if we can compute a simplified version of this instruction. /// If not, this returns null. @@ -6358,6 +6396,12 @@ static Value *simplifyInstructionWithOperands(Instruction *I, Result = SimplifyPHINode(cast(I), NewOps, Q); break; case Instruction::Call: { + auto * VPInst = dyn_cast(I); + if (VPInst) { + Result = SimplifyVPIntrinsic(*VPInst, Q); + if (Result) break; + } + // TODO: Use NewOps Result = SimplifyCall(cast(I), Q); break; diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 8e383ce85cb7..6e0d7540c037 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -466,6 +466,7 @@ namespace { SDValue visitFREEZE(SDNode *N); SDValue visitBUILD_PAIR(SDNode *N); SDValue visitFADD(SDNode *N); + SDValue visitFADD_VP(SDNode *N); SDValue visitSTRICT_FADD(SDNode *N); SDValue visitFSUB(SDNode *N); SDValue visitFMUL(SDNode *N); @@ -513,6 +514,7 @@ namespace { SDValue visitVECREDUCE(SDNode *N); SDValue visitVPOp(SDNode *N); + template SDValue visitFADDForFMACombine(SDNode *N); SDValue visitFSUBForFMACombine(SDNode *N); SDValue visitFMULForFMADistributiveCombine(SDNode *N); @@ -819,6 +821,161 @@ class WorklistInserter : public SelectionDAG::DAGUpdateListener { void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); } }; +struct EmptyMatchContext { + SelectionDAG & DAG; + + EmptyMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + {} + + bool match(SDValue OpN, unsigned OpCode) const { return OpCode == OpN->getOpcode(); } + + unsigned getFunctionOpCode(SDValue N) const { + return N->getOpcode(); + } + + bool isCompatible(SDValue OpVal) const { return true; } + + // Specialize based on number of operands. 
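+  // For the empty context these helpers simply forward to DAG.getNode, so a
+  // combine instantiated with EmptyMatchContext behaves exactly as before.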
+ SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + Optional Flags = None) { + if (Flags) + return DAG.getNode(Opcode, DL, VT, Operand, *Flags); + else + return DAG.getNode(Opcode, DL, VT, Operand); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, Optional Flags = None) { + if (Flags) + return DAG.getNode(Opcode, DL, VT, N1, N2, *Flags); + else + return DAG.getNode(Opcode, DL, VT, N1, N2); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + Optional Flags = None) { + if (Flags) + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, *Flags); + else + return DAG.getNode(Opcode, DL, VT, N1, N2, N3); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4, SDValue N5) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4, N5); + } +}; + +struct +VPMatchContext { + SelectionDAG & DAG; + SDNode * Root; + SDValue RootMaskOp; + SDValue RootVectorLenOp; + + VPMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + , Root(Root) + , RootMaskOp() + , RootVectorLenOp() + { + if (Root->isVP()) { + auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()); + if (RootMaskPos) { + RootMaskOp = Root->getOperand(*RootMaskPos); + } + + auto RootVLenPos = ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()); + if (RootVLenPos) { + RootVectorLenOp = Root->getOperand(*RootVLenPos); + } + } + } + + unsigned getFunctionOpCode(SDValue N) const { + unsigned VPOpCode = N->getOpcode(); + auto FuncOpc = ISD::GetFunctionOpCodeForVP(VPOpCode, !N->getFlags().hasNoFPExcept()); + if (!FuncOpc) return VPOpCode; + return *FuncOpc; + } + + bool isCompatible(SDValue OpVal) const { + if (!OpVal->isVP()) { + return !Root->isVP(); + + } else { + unsigned VPOpCode = OpVal->getOpcode(); + auto MaskPos = ISD::getVPMaskIdx(VPOpCode); + if (MaskPos && RootMaskOp != OpVal.getOperand(*MaskPos)) { + return false; + } + + auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpCode); + if (VLenPos && RootVectorLenOp != OpVal.getOperand(*VLenPos)) { + return false; + } + + return true; + } + } + + /// whether \p OpN is a node that is functionally compatible with the NodeType \p OpNodeTy + bool match(SDValue OpVal, unsigned OpNT) const { + return isCompatible(OpVal) && getFunctionOpCode(OpVal) == OpNT; + } + + // Specialize based on number of operands. 
+ // TODO emit VP intrinsics where MaskOp/VectorLenOp != null + // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + Optional Flags = None) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = *ISD::getVPMaskIdx(VPOpcode); + int VLenPos = *ISD::getVPExplicitVectorLengthIdx(VPOpcode); + assert(MaskPos == 1 && VLenPos == 2); + + if (Flags) + return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}, *Flags); + else + return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, Optional Flags = None) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = *ISD::getVPMaskIdx(VPOpcode); + int VLenPos = *ISD::getVPExplicitVectorLengthIdx(VPOpcode); + assert(MaskPos == 2 && VLenPos == 3); + + if (Flags) + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}, *Flags); + else + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + Optional Flags = None) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = *ISD::getVPMaskIdx(VPOpcode); + int VLenPos = *ISD::getVPExplicitVectorLengthIdx(VPOpcode); + assert(MaskPos == 3 && VLenPos == 4); + + if (Flags) + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3, RootMaskOp, RootVectorLenOp}, *Flags); + else + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3, RootMaskOp, RootVectorLenOp}); + } +}; + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -13699,12 +13856,16 @@ static bool hasNoInfs(const TargetOptions &Options, SDValue N) { } /// Try to perform FMA combining on a given FADD node. +template SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); + MatchContextClass matcher(DAG, N); + if (!matcher.isCompatible(N0) || !matcher.isCompatible(N1)) return SDValue(); + const TargetOptions &Options = DAG.getTarget().Options; // Floating-point multiply-add with intermediate rounding. @@ -13735,14 +13896,13 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { bool Aggressive = TLI.enableAggressiveFMAFusion(VT); auto isFusedOp = [&](SDValue N) { - unsigned Opcode = N.getOpcode(); - return Opcode == ISD::FMA || Opcode == ISD::FMAD; + return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD); }; // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. - auto isContractableFMUL = [AllowFusionGlobally](SDValue N) { - if (N.getOpcode() != ISD::FMUL) + auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) { + if (!matcher.match(N, ISD::FMUL)) return false; return AllowFusionGlobally || N->getFlags().hasAllowContract(); }; @@ -13755,15 +13915,15 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), - N0.getOperand(1), N1); + return matcher.getNode(PreferredFusedOpcode, SL, VT, + N0.getOperand(0), N0.getOperand(1), N1); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. 
if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), - N1.getOperand(1), N0); + return matcher.getNode(PreferredFusedOpcode, SL, VT, + N1.getOperand(0), N1.getOperand(1), N0); } // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E) @@ -13793,29 +13953,31 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // Look through FP_EXTEND nodes to do more combining. // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) - if (N0.getOpcode() == ISD::FP_EXTEND) { + if ((N0.getOpcode() == ISD::FP_EXTEND) && matcher.isCompatible(N0.getOperand(0))) { SDValue N00 = N0.getOperand(0); if (isContractableFMUL(N00) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), - N1); + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(0)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(1)), N1); } } // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N1, ISD::FP_EXTEND)) { SDValue N10 = N1.getOperand(0); if (isContractableFMUL(N10) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N10.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), - N0); + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, + N10.getOperand(0)), + matcher.getNode(ISD::FP_EXTEND, SL, VT, + N10.getOperand(1)), N0); } } @@ -13823,17 +13985,17 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { if (Aggressive) { // fold (fadd (fma x, y, (fpext (fmul u, v))), z) // -> (fma x, y, (fma (fpext u), (fpext v), z)) - auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V, - SDValue Z) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + auto FoldFAddFMAFPExtFMul = [&] ( + SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { + return matcher.getNode(PreferredFusedOpcode, SL, VT, X, Y, + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z)); }; if (isFusedOp(N0)) { SDValue N02 = N0.getOperand(2); - if (N02.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N02, ISD::FP_EXTEND)) { SDValue N020 = N02.getOperand(0); if (isContractableFMUL(N020) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, @@ -13850,14 +14012,15 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { // FIXME: This turns two single-precision and one double-precision // operation into two double-precision operations, which might not be // interesting for all targets, especially GPUs. 
- auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V, - SDValue Z) { - return DAG.getNode( - PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X), - DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z)); + auto FoldFAddFPExtFMAFMul = [&] ( + SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, X), + matcher.getNode(ISD::FP_EXTEND, SL, VT, Y), + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), + Z)); }; if (N0.getOpcode() == ISD::FP_EXTEND) { SDValue N00 = N0.getOperand(0); @@ -14320,6 +14483,17 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitFADD_VP(SDNode *N) { + SelectionDAG::FlagInserter FlagsInserter(DAG, N); + + // FADD -> FMA combines: + if (SDValue Fused = visitFADDForFMACombine(N)) { + AddToWorklist(Fused.getNode()); + return Fused; + } + return SDValue(); +} + SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -14494,7 +14668,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { } // enable-unsafe-fp-math // FADD -> FMA combines: - if (SDValue Fused = visitFADDForFMACombine(N)) { + if (SDValue Fused = visitFADDForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; } @@ -22854,8 +23028,16 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) { ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode()); // This is the only generic VP combine we support for now. - if (!AreAllEltsDisabled) + if (!AreAllEltsDisabled) { + // FMA fusion. + switch (N->getOpcode()) { + default: + break; + case ISD::VP_FADD: + return visitFADD_VP(N); + } return SDValue(); + } // Binary operations can be replaced by UNDEF. 
if (ISD::isVPBinaryOp(N->getOpcode())) diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt index 38f520a80926..10b332050191 100644 --- a/llvm/lib/IR/CMakeLists.txt +++ b/llvm/lib/IR/CMakeLists.txt @@ -48,6 +48,7 @@ add_llvm_component_library(LLVMCore PassManager.cpp PassRegistry.cpp PassTimingInfo.cpp + PredicatedInst.cpp PrintPasses.cpp SafepointIRVerifier.cpp ProfileSummary.cpp diff --git a/llvm/lib/IR/PredicatedInst.cpp b/llvm/lib/IR/PredicatedInst.cpp new file mode 100644 index 000000000000..6288c2060bae --- /dev/null +++ b/llvm/lib/IR/PredicatedInst.cpp @@ -0,0 +1,181 @@ +#include +#include +#include +#include +#include + +namespace { +using namespace llvm; +using ShortValueVec = SmallVector; +} // namespace + +namespace llvm { + +bool PredicatedInstruction::canIgnoreVectorLengthParam() const { + auto VPI = dyn_cast(this); + if (!VPI) + return true; + + return VPI->canIgnoreVectorLengthParam(); +} + +FastMathFlags PredicatedInstruction::getFastMathFlags() const { + return cast(this)->getFastMathFlags(); +} + +void PredicatedOperator::copyIRFlags(const Value *V, bool IncludeWrapFlags) { + auto *I = dyn_cast(this); + if (I) + I->copyIRFlags(V, IncludeWrapFlags); +} + +bool +PredicatedInstruction::isVectorReduction() const { + auto VPI = dyn_cast(this); + if (VPI) + return isa(VPI); + auto II = dyn_cast(this); + if (!II) return false; + + switch (II->getIntrinsicID()) { + default: + return false; + + case Intrinsic::vector_reduce_add: + case Intrinsic::vector_reduce_mul: + case Intrinsic::vector_reduce_and: + case Intrinsic::vector_reduce_or: + case Intrinsic::vector_reduce_xor: + case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_smax: + case Intrinsic::vector_reduce_umin: + case Intrinsic::vector_reduce_umax: + case Intrinsic::vector_reduce_fadd: + case Intrinsic::vector_reduce_fmul: + case Intrinsic::vector_reduce_fmin: + case Intrinsic::vector_reduce_fmax: + return true; + } +} + +Instruction *PredicatedUnaryOperator::Create( + Module *Mod, Value *Mask, Value *VectorLen, Instruction::UnaryOps Opc, + Value *V, const Twine &Name, BasicBlock *InsertAtEnd, + Instruction *InsertBefore) { + assert(!(InsertAtEnd && InsertBefore)); + auto VPID = VPIntrinsic::getForOpcode(Opc); + + // Default Code Path + if ((!Mod || (!Mask && !VectorLen)) || VPID == Intrinsic::not_intrinsic) { + if (InsertAtEnd) { + return UnaryOperator::Create(Opc, V, Name, InsertAtEnd); + } else { + return UnaryOperator::Create(Opc, V, Name, InsertBefore); + } + } + + assert(Mod && "Need a module to emit VP Intrinsics"); + + // Fetch the VP intrinsic + auto &VecTy = cast(*V->getType()); + auto *VPFunc = + VPIntrinsic::getDeclarationForParams(Mod, VPID, &VecTy, {V}); + + // Encode default environment fp behavior + +#if 0 + // TODO + LLVMContext &Ctx = V1->getContext(); + SmallVector ConstraintBundles; + if (VPIntrinsic::HasRoundingMode(VPID)) + ConstraintBundles.emplace_back( + "cfp-round", + GetConstrainedFPRounding(Ctx, RoundingMode::NearestTiesToEven)); + if (VPIntrinsic::HasExceptionMode(VPID)) + ConstraintBundles.emplace_back( + "cfp-except", + GetConstrainedFPExcept(Ctx, fp::ExceptionBehavior::ebIgnore)); + + CallInst *CI; + if (InsertAtEnd) { + CI = CallInst::Create(VPFunc, BinOpArgs, ConstraintBundles, Name, InsertAtEnd); + } else { + CI = CallInst::Create(VPFunc, BinOpArgs, ConstraintBundles, Name, InsertBefore); + } +#endif + + CallInst *CI; + SmallVector UnOpArgs({V, Mask, VectorLen}); + if (InsertAtEnd) { + CI = CallInst::Create(VPFunc, UnOpArgs, Name, InsertAtEnd); + } else { + 
CI = CallInst::Create(VPFunc, UnOpArgs, Name, InsertBefore);
+  }
+
+  // The VP intrinsic does not touch memory if the exception behavior is
+  // "fpexcept.ignore".
+  CI->setDoesNotAccessMemory();
+  return CI;
+}
+
+Instruction *PredicatedBinaryOperator::Create(
+    Module *Mod, Value *Mask, Value *VectorLen, Instruction::BinaryOps Opc,
+    Value *V1, Value *V2, const Twine &Name, BasicBlock *InsertAtEnd,
+    Instruction *InsertBefore) {
+  assert(!(InsertAtEnd && InsertBefore));
+  auto VPID = VPIntrinsic::getForOpcode(Opc);
+
+  // Default Code Path
+  if ((!Mod || (!Mask && !VectorLen)) || VPID == Intrinsic::not_intrinsic) {
+    if (InsertAtEnd) {
+      return BinaryOperator::Create(Opc, V1, V2, Name, InsertAtEnd);
+    } else {
+      return BinaryOperator::Create(Opc, V1, V2, Name, InsertBefore);
+    }
+  }
+
+  assert(Mod && "Need a module to emit VP Intrinsics");
+
+  // Fetch the VP intrinsic
+  auto &VecTy = cast<VectorType>(*V1->getType());
+  auto *VPFunc =
+      VPIntrinsic::getDeclarationForParams(Mod, VPID, &VecTy, {V1, V2});
+
+  // Encode default environment fp behavior
+
+#if 0
+  // TODO
+  LLVMContext &Ctx = V1->getContext();
+  SmallVector<OperandBundleDef, 2> ConstraintBundles;
+  if (VPIntrinsic::HasRoundingMode(VPID))
+    ConstraintBundles.emplace_back(
+        "cfp-round",
+        GetConstrainedFPRounding(Ctx, RoundingMode::NearestTiesToEven));
+  if (VPIntrinsic::HasExceptionMode(VPID))
+    ConstraintBundles.emplace_back(
+        "cfp-except",
+        GetConstrainedFPExcept(Ctx, fp::ExceptionBehavior::ebIgnore));
+
+  CallInst *CI;
+  if (InsertAtEnd) {
+    CI = CallInst::Create(VPFunc, BinOpArgs, ConstraintBundles, Name, InsertAtEnd);
+  } else {
+    CI = CallInst::Create(VPFunc, BinOpArgs, ConstraintBundles, Name, InsertBefore);
+  }
+#endif
+
+  CallInst *CI;
+  SmallVector<Value *, 4> BinOpArgs({V1, V2, Mask, VectorLen});
+  if (InsertAtEnd) {
+    CI = CallInst::Create(VPFunc, BinOpArgs, Name, InsertAtEnd);
+  } else {
+    CI = CallInst::Create(VPFunc, BinOpArgs, Name, InsertBefore);
+  }
+
+  // The VP intrinsic does not touch memory if the exception behavior is
+  // "fpexcept.ignore".
+  CI->setDoesNotAccessMemory();
+  return CI;
+}
+
+} // namespace llvm
diff --git a/llvm/lib/IR/VPBuilder.cpp b/llvm/lib/IR/VPBuilder.cpp
index 728132fada9a..bee80fb3ad1f 100644
--- a/llvm/lib/IR/VPBuilder.cpp
+++ b/llvm/lib/IR/VPBuilder.cpp
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include
 #include
 
 namespace {
diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h
index d135d8b97068..9508a8bc6b77 100644
--- a/llvm/lib/Target/VE/VETargetTransformInfo.h
+++ b/llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -22,6 +22,7 @@
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PredicatedInst.h"
 #include "llvm/IR/Type.h"
 
 // Penalty cost factor to make vectorization unappealing (see
@@ -284,20 +285,25 @@ class VETTIImpl : public BasicTTIImplBase<VETTIImpl> {
   TargetTransformInfo::VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
     using VPTransform = TargetTransformInfo::VPLegalization;
+    auto &PI = cast<PredicatedInstruction>(VPI);
     return TargetTransformInfo::VPLegalization(
         /* EVLParamStrategy */ VPTransform::Legal,
-        /* OperatorStrategy */ supportsVPOperation(VPI) ? VPTransform::Legal
+        /* OperatorStrategy */ supportsVPOperation(PI) ? VPTransform::Legal
                                                         : VPTransform::Convert);
   }
 
   /// \returns False if this VP op should be replaced by a non-VP op or an
   /// unpredicated op plus a select.
- bool supportsVPOperation(const VPIntrinsic &VPI) const { + bool supportsVPOperation(const PredicatedInstruction &PredInst) const { if (!enableVPU()) return false; + auto VPI = dyn_cast(&PredInst); + if (!VPI) + return true; + // Cannot be widened into a legal VVP op - auto EC = VPI.getStaticVectorLength(); + auto EC = VPI->getStaticVectorLength(); if (EC.isScalable()) return false; @@ -306,15 +312,12 @@ class VETTIImpl : public BasicTTIImplBase { // Bail on yet-unimplemented reductions if (isa(VPI)) { - auto FPRed = dyn_cast(&VPI); - bool Unordered = FPRed ? VPI.getFastMathFlags().allowReassoc() : true; - return isSupportedReduction(VPI.getIntrinsicID(), Unordered); + auto FPRed = dyn_cast(VPI); + bool Unordered = FPRed ? VPI->getFastMathFlags().allowReassoc() : true; + return isSupportedReduction(VPI->getIntrinsicID(), Unordered); } - Optional OpCodeOpt = VPI.getFunctionalOpcode(); - unsigned OpCode = OpCodeOpt ? *OpCodeOpt : Instruction::Call; - - switch (OpCode) { + switch (PredInst.getOpcode()) { default: break; @@ -326,29 +329,30 @@ class VETTIImpl : public BasicTTIImplBase { // Non-opcode VP ops case Instruction::Call: // vp mask operations unsupported - if (isa(VPI)) - return !VPI.getType()->isIntOrIntVectorTy(1); + if (PredInst.isVectorReduction()) + return !PredInst.getType()->isIntOrIntVectorTy(1); break; // TODO mask scatter&gather // vp mask load/store unsupported (FIXME) case Instruction::Load: - return !IsMaskType(VPI.getType()); + return !IsMaskType(PredInst.getType()); case Instruction::Store: - return !IsMaskType(VPI.getOperand(0)->getType()); + return !IsMaskType(PredInst.getOperand(0)->getType()); // vp mask operations unsupported case Instruction::And: case Instruction::Or: case Instruction::Xor: - auto ITy = VPI.getType(); + auto ITy = PredInst.getType(); if (!ITy->isVectorTy()) break; if (!ITy->isIntOrIntVectorTy(1)) break; return false; } + // be optimistic by default return true; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index de79fe0ec507..30186f85264b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -24,6 +24,9 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" +#include "llvm/IR/VPBuilder.h" +#include "llvm/IR/MatcherCast.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/AlignOf.h" @@ -2186,6 +2189,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { /// This eliminates floating-point negation in either 'fneg(X)' or /// 'fsub(-0.0, X)' form by combining into a constant operand. +template static Instruction *foldFNegIntoConstant(Instruction &I) { // This is limited with one-use because fneg is assumed better for // reassociation and cheaper in codegen than fmul/fdiv. @@ -2197,17 +2201,20 @@ static Instruction *foldFNegIntoConstant(Instruction &I) { Value *X; Constant *C; + MatchContextType MC(cast(&I)); + MatchContextBuilder MCBuilder(MC); + // Fold negation into constant operand. 
// -(X * C) --> X * (-C) - if (match(FNegOp, m_FMul(m_Value(X), m_Constant(C)))) - return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(FNegOp, m_FMul(m_Value(X), m_Constant(C)))) + return MCBuilder.CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); // -(X / C) --> X / (-C) - if (match(FNegOp, m_FDiv(m_Value(X), m_Constant(C)))) - return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(FNegOp, m_FDiv(m_Value(X), m_Constant(C)))) + return MCBuilder.CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); // -(C / X) --> (-C) / X - if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X)))) { + if (MC.try_match(FNegOp, m_FDiv(m_Constant(C), m_Value(X)))) { Instruction *FDiv = - BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + MCBuilder.CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); // Intersect 'nsz' and 'ninf' because those special value exceptions may not // apply to the fdiv. Everything else propagates from the fneg. @@ -2220,8 +2227,8 @@ static Instruction *foldFNegIntoConstant(Instruction &I) { } // With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]: // -(X + C) --> -X + -C --> -C - X - if (I.hasNoSignedZeros() && match(FNegOp, m_FAdd(m_Value(X), m_Constant(C)))) - return BinaryOperator::CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I); + if (I.hasNoSignedZeros() && MC.try_match(FNegOp, m_FAdd(m_Value(X), m_Constant(C)))) + return MCBuilder.CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I); return nullptr; } @@ -2249,7 +2256,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { getSimplifyQuery().getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldFNegIntoConstant(I)) + if (Instruction *X = foldFNegIntoConstant(I)) return X; Value *X, *Y; @@ -2295,6 +2302,17 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { return nullptr; } +Instruction *InstCombinerImpl::visitPredicatedFSub(PredicatedBinaryOperator& I) { + auto * Inst = cast(&I); + PredicatedContext PC(&I); + if (Value *V = SimplifyPredicatedFSubInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + SQ.getWithInstruction(Inst), PC)) + return replaceInstUsesWith(*Inst, V); + + return visitFSubGeneric(*Inst); +} + Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -2307,6 +2325,14 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { if (Instruction *Phi = foldBinopWithPhiOperands(I)) return Phi; + return visitFSubGeneric(I); +} + +template +Instruction *InstCombinerImpl::visitFSubGeneric(BinaryOpTy &I) { + MatchContextType MC(cast(&I)); + MatchContextBuilder MCBuilder(MC); + // Subtraction from -0.0 is the canonical form of fneg. // fsub -0.0, X ==> fneg X // fsub nsz 0.0, X ==> fneg nsz X @@ -2315,10 +2341,10 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { // fsub -0.0, Denorm ==> +-0 // fneg Denorm ==> -Denorm Value *Op; - if (match(&I, m_FNeg(m_Value(Op)))) - return UnaryOperator::CreateFNegFMF(Op, &I); + if (MC.try_match(&I, m_FNeg(m_Value(Op)))) + return MCBuilder.CreateFNegFMF(Op, &I); - if (Instruction *X = foldFNegIntoConstant(I)) + if (Instruction *X = foldFNegIntoConstant(I)) return X; if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) @@ -2335,20 +2361,20 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { // killed later. 
We still limit that particular transform with 'hasOneUse' // because an fneg is assumed better/cheaper than a generic fsub. if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { - if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { - Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); - return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); + if (MC.try_match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + Value *NewSub = MCBuilder.CreateFSubFMF(Builder, Y, X, &I); + return MCBuilder.CreateFAddFMF(Op0, NewSub, &I); } } // (-X) - Op1 --> -(X + Op1) if (I.hasNoSignedZeros() && !isa(Op0) && - match(Op0, m_OneUse(m_FNeg(m_Value(X))))) { - Value *FAdd = Builder.CreateFAddFMF(X, Op1, &I); - return UnaryOperator::CreateFNegFMF(FAdd, &I); + MC.try_match(Op0, m_OneUse(m_FNeg(m_Value(X))))) { + Value *FAdd = MCBuilder.CreateFAddFMF(Builder, X, Op1, &I); + return MCBuilder.CreateFNegFMF(FAdd, &I); } - if (isa(Op0)) + if (MatchContextType::IsEmpty && isa(Op0)) if (SelectInst *SI = dyn_cast(Op1)) if (Instruction *NV = FoldOpIntoSelect(I, SI)) return NV; @@ -2356,22 +2382,22 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { // X - C --> X + (-C) // But don't transform constant expressions because there's an inverse fold // for X + (-Y) --> X - Y. - if (match(Op1, m_ImmConstant(C))) - return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(Op1, m_ImmConstant(C))) + return MCBuilder.CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); // X - (-Y) --> X + Y - if (match(Op1, m_FNeg(m_Value(Y)))) - return BinaryOperator::CreateFAddFMF(Op0, Y, &I); + if (MC.try_match(Op1, m_FNeg(m_Value(Y)))) + return MCBuilder.CreateFAddFMF(Op0, Y, &I); // Similar to above, but look through a cast of the negated value: // X - (fptrunc(-Y)) --> X + fptrunc(Y) Type *Ty = I.getType(); - if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPTrunc(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPTrunc(Builder, Y, Ty), &I); // X - (fpext(-Y)) --> X + fpext(Y) - if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPExt(Builder, Y, Ty), &I); // Similar to above, but look through fmul/fdiv of the negated value: // Op0 - (-X * Y) --> Op0 + (X * Y) @@ -2389,39 +2415,40 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { } // Handle special cases for FSub with selects feeding the operation - if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) - return replaceInstUsesWith(I, V); + if (auto * PlainBinOp = dyn_cast(&I)) + if (Value *V = SimplifySelectsFeedingBinaryOp(*PlainBinOp, Op0, Op1)) + return replaceInstUsesWith(I, V); if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { // (Y - X) - Y --> -X - if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) - return UnaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // Y - (X + Y) --> -X // Y - (Y + X) --> -X - if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) - return UnaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // (X * C) - X --> X * (C - 1.0) - if (match(Op0, 
m_FMul(m_Specific(Op1), m_Constant(C)))) { + if (MC.try_match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { Constant *CSubOne = ConstantExpr::getFSub(C, ConstantFP::get(Ty, 1.0)); - return BinaryOperator::CreateFMulFMF(Op1, CSubOne, &I); + return MCBuilder.CreateFMulFMF(Op1, CSubOne, &I); } // X - (X * C) --> X * (1.0 - C) - if (match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { + if (MC.try_match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { Constant *OneSubC = ConstantExpr::getFSub(ConstantFP::get(Ty, 1.0), C); - return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I); + return MCBuilder.CreateFMulFMF(Op0, OneSubC, &I); } // Reassociate fsub/fadd sequences to create more fadd instructions and // reduce dependency chains: // ((X - Y) + Z) - Op1 --> (X + Z) - (Y + Op1) Value *Z; - if (match(Op0, m_OneUse(m_c_FAdd(m_OneUse(m_FSub(m_Value(X), m_Value(Y))), + if (MC.try_match(Op0, m_OneUse(m_c_FAdd(m_OneUse(m_FSub(m_Value(X), m_Value(Y))), m_Value(Z))))) { Value *XZ = Builder.CreateFAddFMF(X, Z, &I); Value *YW = Builder.CreateFAddFMF(Y, Op1, &I); - return BinaryOperator::CreateFSubFMF(XZ, YW, &I); + return MCBuilder.CreateFSubFMF(XZ, YW, &I); } auto m_FaddRdx = [](Value *&Sum, Value *&Vec) { @@ -2439,8 +2466,12 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { return BinaryOperator::CreateFSubFMF(Rdx, A1, &I); } - if (Instruction *F = factorizeFAddFSub(I, Builder)) - return F; + auto *BinOp = dyn_cast(&I); + if (BinOp) { + auto *F = factorizeFAddFSub(*BinOp, Builder); + if (F) + return F; + } // TODO: This performs reassociative folds for FP ops. Some fraction of the // functionality has been subsumed by simple pattern matching here and in @@ -2450,9 +2481,9 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { return replaceInstUsesWith(I, V); // (X - Y) - Op1 --> X - (Y + Op1) - if (match(Op0, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + if (MC.try_match(Op0, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { Value *FAdd = Builder.CreateFAddFMF(Y, Op1, &I); - return BinaryOperator::CreateFSubFMF(X, FAdd, &I); + return MCBuilder.CreateFSubFMF(X, FAdd, &I); } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index eedc79f43cd4..58bd2f5034b0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -41,6 +41,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsAArch64.h" @@ -1066,6 +1067,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return &CI; } + // Predicated instruction patterns + auto * VPInst = dyn_cast(&CI); + if (VPInst) { + auto * PredInst = cast(VPInst); + auto Result = visitPredicatedInstruction(PredInst); + if (Result) return Result; + } + IntrinsicInst *II = dyn_cast(&CI); if (!II) return visitCallBase(CI); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index b6237a54ebc5..ed27944eb0df 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -21,6 +21,11 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" 
+#include "llvm/IR/PredicatedInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" #include "llvm/Support/Debug.h" @@ -91,6 +96,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final Value *OptimizePointerDifference( Value *LHS, Value *RHS, Type *Ty, bool isNUW); Instruction *visitSub(BinaryOperator &I); + template Instruction *visitFSubGeneric(BinaryOpTy &I); + Instruction *visitPredicatedFSub(PredicatedBinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); Instruction *visitFMul(BinaryOperator &I); @@ -175,6 +182,16 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final bool freezeDominatedUses(FreezeInst &FI); Instruction *visitFreeze(FreezeInst &I); + // Entry point to VPIntrinsic + Instruction *visitPredicatedInstruction(PredicatedInstruction * PI) { + switch (PI->getOpcode()) { + default: + return nullptr; + case Instruction::FSub: + return visitPredicatedFSub(cast(*PI)); + } + } + /// Specify what to return for unhandled instructions. Instruction *visitInstruction(Instruction &I) { return nullptr; } diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 2d6740831cdc..8dcf95cc1f50 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -2252,8 +2252,13 @@ template struct match_LoopInvariant { match_LoopInvariant(const SubPattern_t &SP, const Loop *L) : SubPattern(SP), L(L) {} + template bool match_context(ITy *V, MatcherContext &MC) { + return L->isLoopInvariant(V) && SubPattern.match_context(V, MC); + } + template bool match(ITy *V) { - return L->isLoopInvariant(V) && SubPattern.match(V); + PatternMatch::EmptyContext EC; + return match_context(V, EC); } }; diff --git a/llvm/test/Transforms/InstCombine/vp-fsub.ll b/llvm/test/Transforms/InstCombine/vp-fsub.ll new file mode 100644 index 000000000000..e560b8a56a8b --- /dev/null +++ b/llvm/test/Transforms/InstCombine/vp-fsub.ll @@ -0,0 +1,45 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s + +; PR4374 + +define <4 x float> @test1_vp(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @test1_vp( +; + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) #0 + %t2 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> , <4 x float> %t1, <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + +; Can't do anything with the test above because -0.0 - 0.0 = -0.0, but if we have nsz: +; -(X - Y) --> Y - X + +; TODO predicated FAdd folding +define <4 x float> @neg_sub_nsz_vp(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CH***-LABEL: @neg_sub_nsz_vp( +; + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) #0 + %t2 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> , <4 x float> %t1, <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + +; With nsz: Z - (X - Y) --> Z + (Y - X) + +define <4 x float> @sub_sub_nsz_vp(<4 x float> %x, <4 x float> %y, <4 x float> %z, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @sub_sub_nsz_vp( +; CHECK-NEXT: %1 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %y, <4 x float> %x, <4 x i1> %M, i32 %L) # +; CHECK-NEXT: %t2 = call nsz <4 x float> @llvm.vp.fadd.v4f32(<4 x float> %z, <4 x float> %1, <4 x i1> %M, i32 %L) # +; CHECK-NEXT: ret <4 x float> %t2 + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) #0 + %t2 = call nsz <4 x float> 
@llvm.vp.fsub.v4f32(<4 x float> %z, <4 x float> %t1, <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + + + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.vp.fadd.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32) + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.vp.fsub.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32) + +attributes #0 = { readnone } diff --git a/llvm/test/Transforms/InstSimplify/vp-fsub.ll b/llvm/test/Transforms/InstSimplify/vp-fsub.ll new file mode 100644 index 000000000000..45846769a415 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/vp-fsub.ll @@ -0,0 +1,55 @@ +; RUN: opt < %s -instsimplify -S | FileCheck %s + +define <8 x double> @fsub_fadd_fold_vp_xy(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) { +; CHECK-LABEL: fsub_fadd_fold_vp_xy +; CHECK: ret <8 x double> %x + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) + %res0 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len) + ret <8 x double> %res0 +} + +define <8 x double> @fsub_fadd_fold_vp_zw(<8 x double> %z, <8 x double> %w, <8 x i1> %m, i32 %len) { +; CHECK-LABEL: fsub_fadd_fold_vp_zw +; CHECK: ret <8 x double> %z + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %w, <8 x double> %z, <8 x i1> %m, i32 %len) + %res1 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %w, <8 x i1> %m, i32 %len) + ret <8 x double> %res1 +} + +; REQUIRES-CONSTRAINED-VP: define <8 x double> @fsub_fadd_fold_vp_yx_fpexcept(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) #0 { +; REQUIRES-CONSTRAINED-VP: ; *HECK-LABEL: fsub_fadd_fold_vp_yx +; REQUIRES-CONSTRAINED-VP: ; *HECK-NEXT: %tmp = +; REQUIRES-CONSTRAINED-VP: ; *HECK-NEXT: %res2 = +; REQUIRES-CONSTRAINED-VP: ; *HECK-NEXT: ret +; REQUIRES-CONSTRAINED-VP: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %len) [ "cfp-except"(metadata !"fpexcept.strict") ] +; REQUIRES-CONSTRAINED-VP: %res2 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len) [ "cfp-except"(metadata !"fpexcept.strict") ] +; REQUIRES-CONSTRAINED-VP: ret <8 x double> %res2 +; REQUIRES-CONSTRAINED-VP: } + +define <8 x double> @fsub_fadd_fold_vp_yx_olen(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, i32 %otherLen) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx_olen +; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %otherLen) +; CHECK-NEXT: %res3 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len) +; CHECK-NEXT: ret <8 x double> %res3 + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %otherLen) + %res3 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len) + ret <8 x double> %res3 +} + +define <8 x double> @fsub_fadd_fold_vp_yx_omask(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, <8 x i1> %othermask) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx_omask +; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %len) +; CHECK-NEXT: %res4 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %othermask, i32 %len) +; CHECK-NEXT: ret <8 x 
double> %res4 + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %len) + %res4 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %othermask, i32 %len) + ret <8 x double> %res4 +} + +; Function Attrs: nounwind readnone +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, <8 x i1>, i32) + +; Function Attrs: nounwind readnone +declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, <8 x i1>, i32) + +attributes #0 = { strictfp }
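// ---------------------------------------------------------------------------
// Editor's note (not part of the patch above): a minimal usage sketch of the
// match-context mechanism this patch introduces. EmptyContext,
// PredicatedContext and try_match are names defined by the patch;
// foldNegOfNeg is a hypothetical helper used only for illustration. A fold
// written once against a match-context template parameter applies both to
// plain instructions and to llvm.vp.* intrinsics, because the predicated
// context only accepts operands whose mask and explicit vector length agree
// with the root VP operation.
//
//   template <typename MatchContextType>
//   static Value *foldNegOfNeg(Value *Op, MatchContextType &MC) {
//     using namespace llvm::PatternMatch;
//     Value *X;
//     // fneg (fneg X) --> X, matched through the context so that a VP root
//     // only folds with VP operands carrying the same mask/EVL.
//     if (MC.try_match(Op, m_FNeg(m_FNeg(m_Value(X)))))
//       return X;
//     return nullptr;
//   }
//
//   // Plain IR root:           EmptyContext EC;           foldNegOfNeg(&I, EC);
//   // llvm.vp.* intrinsic root: PredicatedContext PC(&VPI); foldNegOfNeg(&VPI, PC);
// ---------------------------------------------------------------------------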