Skip to content

Commit

Permalink
[ValueTracking] Improve KnownBits for signed min-max clamping (llvm#1…
Browse files Browse the repository at this point in the history
…20576)

A signed min-max clamp is the sequence of smin and smax intrinsics,
which constrain a signed value into the range: smin <= value <= smax.
The patch improves the calculation of KnownBits for a value subjected to
the signed clamping.
  • Loading branch information
adam-bzowski authored Dec 25, 2024
1 parent 3469996 commit 6d7cf52
Show file tree
Hide file tree
Showing 2 changed files with 325 additions and 49 deletions.
108 changes: 59 additions & 49 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,63 @@ void llvm::adjustKnownBitsForSelectArm(KnownBits &Known, Value *Cond,
Known = CondRes;
}

// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow).
// Returns the input and lower/upper bounds.
static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
const APInt *&CLow, const APInt *&CHigh) {
assert(isa<Operator>(Select) &&
cast<Operator>(Select)->getOpcode() == Instruction::Select &&
"Input should be a Select!");

const Value *LHS = nullptr, *RHS = nullptr;
SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor;
if (SPF != SPF_SMAX && SPF != SPF_SMIN)
return false;

if (!match(RHS, m_APInt(CLow)))
return false;

const Value *LHS2 = nullptr, *RHS2 = nullptr;
SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor;
if (getInverseMinMaxFlavor(SPF) != SPF2)
return false;

if (!match(RHS2, m_APInt(CHigh)))
return false;

if (SPF == SPF_SMIN)
std::swap(CLow, CHigh);

In = LHS2;
return CLow->sle(*CHigh);
}

static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
const APInt *&CLow,
const APInt *&CHigh) {
assert((II->getIntrinsicID() == Intrinsic::smin ||
II->getIntrinsicID() == Intrinsic::smax) &&
"Must be smin/smax");

Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
auto *InnerII = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
if (!InnerII || InnerII->getIntrinsicID() != InverseID ||
!match(II->getArgOperand(1), m_APInt(CLow)) ||
!match(InnerII->getArgOperand(1), m_APInt(CHigh)))
return false;

if (II->getIntrinsicID() == Intrinsic::smin)
std::swap(CLow, CHigh);
return CLow->sle(*CHigh);
}

static void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II,
KnownBits &Known) {
const APInt *CLow, *CHigh;
if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
Known = Known.unionWith(ConstantRange(*CLow, *CHigh + 1).toKnownBits());
}

static void computeKnownBitsFromOperator(const Operator *I,
const APInt &DemandedElts,
KnownBits &Known, unsigned Depth,
Expand Down Expand Up @@ -1804,11 +1861,13 @@ static void computeKnownBitsFromOperator(const Operator *I,
computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::smin(Known, Known2);
unionWithMinMaxIntrinsicClamp(II, Known);
break;
case Intrinsic::smax:
computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
Known = KnownBits::smax(Known, Known2);
unionWithMinMaxIntrinsicClamp(II, Known);
break;
case Intrinsic::ptrmask: {
computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
Expand Down Expand Up @@ -3751,55 +3810,6 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2,
return false;
}

// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow).
// Returns the input and lower/upper bounds.
static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
const APInt *&CLow, const APInt *&CHigh) {
assert(isa<Operator>(Select) &&
cast<Operator>(Select)->getOpcode() == Instruction::Select &&
"Input should be a Select!");

const Value *LHS = nullptr, *RHS = nullptr;
SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor;
if (SPF != SPF_SMAX && SPF != SPF_SMIN)
return false;

if (!match(RHS, m_APInt(CLow)))
return false;

const Value *LHS2 = nullptr, *RHS2 = nullptr;
SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor;
if (getInverseMinMaxFlavor(SPF) != SPF2)
return false;

if (!match(RHS2, m_APInt(CHigh)))
return false;

if (SPF == SPF_SMIN)
std::swap(CLow, CHigh);

In = LHS2;
return CLow->sle(*CHigh);
}

static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
const APInt *&CLow,
const APInt *&CHigh) {
assert((II->getIntrinsicID() == Intrinsic::smin ||
II->getIntrinsicID() == Intrinsic::smax) && "Must be smin/smax");

Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
auto *InnerII = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
if (!InnerII || InnerII->getIntrinsicID() != InverseID ||
!match(II->getArgOperand(1), m_APInt(CLow)) ||
!match(InnerII->getArgOperand(1), m_APInt(CHigh)))
return false;

if (II->getIntrinsicID() == Intrinsic::smin)
std::swap(CLow, CHigh);
return CLow->sle(*CHigh);
}

/// For vector constants, loop over the elements and find the constant with the
/// minimum number of sign bits. Return 0 if the value is not a vector constant
/// or if any element was not analyzed; otherwise, return the count for the
Expand Down
Loading

0 comments on commit 6d7cf52

Please sign in to comment.