Skip to content

Commit

Permalink
[NVPTX] Constant-folding for f2i, d2ui, f2ll etc. (llvm#118965)
Browse files Browse the repository at this point in the history
Add constant-folding support for the NVVM intrinsics for converting
float/double to signed/unsigned int32/int64 types, including all
rounding-modes and ftz modifiers.
  • Loading branch information
LewisCrawford authored Jan 7, 2025
1 parent c274837 commit a629d9e
Show file tree
Hide file tree
Showing 7 changed files with 2,575 additions and 41 deletions.
39 changes: 0 additions & 39 deletions llvm/include/llvm/IR/NVVMIntrinsicFlags.h

This file was deleted.

176 changes: 176 additions & 0 deletions llvm/include/llvm/IR/NVVMIntrinsicUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
//===--- NVVMIntrinsicUtils.h -----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This file contains the definitions of the enumerations and flags
/// associated with NVVM Intrinsics, along with some helper functions.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_IR_NVVMINTRINSICUTILS_H
#define LLVM_IR_NVVMINTRINSICUTILS_H

#include <stdint.h>

#include "llvm/ADT/APFloat.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

namespace llvm {
namespace nvvm {

/// Reduction operations supported with TMA copy from shared to global
/// memory, i.e. by the "cp.reduce.async.bulk.tensor.*" family of PTX
/// instructions.
///
/// The underlying type is uint8_t and every enumerator carries an explicit
/// value, so the numbering does not depend on declaration order.
enum class TMAReductionOp : uint8_t {
  ADD = 0,
  MIN = 1,
  MAX = 2,
  INC = 3,
  DEC = 4,
  AND = 5,
  OR = 6,
  XOR = 7,
};

/// Report whether \p IntrinsicID is one of the "_ftz" variants of the NVVM
/// float to i32 / i64 conversion intrinsics (f2i, f2ui, f2ll, f2ull).
///
/// \param IntrinsicID the intrinsic to classify.
/// \returns true for the sixteen "_ftz" conversion intrinsics listed below,
///          false for every other intrinsic ID.
inline bool IntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  default:
    // Anything that is not an explicit "_ftz" conversion.
    return false;

  // Round-toward-negative ("rm") ftz conversions.
  case Intrinsic::nvvm_f2i_rm_ftz:
  case Intrinsic::nvvm_f2ui_rm_ftz:
  case Intrinsic::nvvm_f2ll_rm_ftz:
  case Intrinsic::nvvm_f2ull_rm_ftz:
  // Round-to-nearest ("rn") ftz conversions.
  case Intrinsic::nvvm_f2i_rn_ftz:
  case Intrinsic::nvvm_f2ui_rn_ftz:
  case Intrinsic::nvvm_f2ll_rn_ftz:
  case Intrinsic::nvvm_f2ull_rn_ftz:
  // Round-toward-positive ("rp") ftz conversions.
  case Intrinsic::nvvm_f2i_rp_ftz:
  case Intrinsic::nvvm_f2ui_rp_ftz:
  case Intrinsic::nvvm_f2ll_rp_ftz:
  case Intrinsic::nvvm_f2ull_rp_ftz:
  // Round-toward-zero ("rz") ftz conversions.
  case Intrinsic::nvvm_f2i_rz_ftz:
  case Intrinsic::nvvm_f2ui_rz_ftz:
  case Intrinsic::nvvm_f2ll_rz_ftz:
  case Intrinsic::nvvm_f2ull_rz_ftz:
    return true;
  }
}

/// Report whether \p IntrinsicID is one of the NVVM float/double conversion
/// intrinsics that produce a signed result (f2i, d2i, f2ll, d2ll, in every
/// rounding mode and, for the float sources, with or without "_ftz").
///
/// \param IntrinsicID the intrinsic to classify.
/// \returns true for the signed conversions listed below, false for every
///          other intrinsic ID (including the f2ui/d2ui/f2ull/d2ull forms).
inline bool IntrinsicConvertsToSignedInteger(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  default:
    return false;

  // float -> i32
  case Intrinsic::nvvm_f2i_rm:
  case Intrinsic::nvvm_f2i_rn:
  case Intrinsic::nvvm_f2i_rp:
  case Intrinsic::nvvm_f2i_rz:
  case Intrinsic::nvvm_f2i_rm_ftz:
  case Intrinsic::nvvm_f2i_rn_ftz:
  case Intrinsic::nvvm_f2i_rp_ftz:
  case Intrinsic::nvvm_f2i_rz_ftz:
  // double -> i32
  case Intrinsic::nvvm_d2i_rm:
  case Intrinsic::nvvm_d2i_rn:
  case Intrinsic::nvvm_d2i_rp:
  case Intrinsic::nvvm_d2i_rz:
  // float -> i64
  case Intrinsic::nvvm_f2ll_rm:
  case Intrinsic::nvvm_f2ll_rn:
  case Intrinsic::nvvm_f2ll_rp:
  case Intrinsic::nvvm_f2ll_rz:
  case Intrinsic::nvvm_f2ll_rm_ftz:
  case Intrinsic::nvvm_f2ll_rn_ftz:
  case Intrinsic::nvvm_f2ll_rp_ftz:
  case Intrinsic::nvvm_f2ll_rz_ftz:
  // double -> i64
  case Intrinsic::nvvm_d2ll_rm:
  case Intrinsic::nvvm_d2ll_rn:
  case Intrinsic::nvvm_d2ll_rp:
  case Intrinsic::nvvm_d2ll_rz:
    return true;
  }
}

/// Map an NVVM float/double to integer conversion intrinsic onto the
/// APFloat rounding mode named by its suffix:
///   - "rm" -> APFloat::rmTowardNegative
///   - "rn" -> APFloat::rmNearestTiesToEven
///   - "rp" -> APFloat::rmTowardPositive
///   - "rz" -> APFloat::rmTowardZero
///
/// \param IntrinsicID one of the f2i/f2ui/d2i/d2ui/f2ll/f2ull/d2ll/d2ull
///        intrinsics (with or without "_ftz" for the float sources).
/// \returns the rounding mode for that intrinsic. Passing any other ID hits
///          llvm_unreachable.
inline APFloat::roundingMode
IntrinsicGetRoundingMode(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // "rm": round toward negative infinity.
  case Intrinsic::nvvm_f2i_rm:
  case Intrinsic::nvvm_f2i_rm_ftz:
  case Intrinsic::nvvm_f2ui_rm:
  case Intrinsic::nvvm_f2ui_rm_ftz:
  case Intrinsic::nvvm_f2ll_rm:
  case Intrinsic::nvvm_f2ll_rm_ftz:
  case Intrinsic::nvvm_f2ull_rm:
  case Intrinsic::nvvm_f2ull_rm_ftz:
  case Intrinsic::nvvm_d2i_rm:
  case Intrinsic::nvvm_d2ui_rm:
  case Intrinsic::nvvm_d2ll_rm:
  case Intrinsic::nvvm_d2ull_rm:
    return APFloat::rmTowardNegative;

  // "rn": round to nearest, ties to even.
  case Intrinsic::nvvm_f2i_rn:
  case Intrinsic::nvvm_f2i_rn_ftz:
  case Intrinsic::nvvm_f2ui_rn:
  case Intrinsic::nvvm_f2ui_rn_ftz:
  case Intrinsic::nvvm_f2ll_rn:
  case Intrinsic::nvvm_f2ll_rn_ftz:
  case Intrinsic::nvvm_f2ull_rn:
  case Intrinsic::nvvm_f2ull_rn_ftz:
  case Intrinsic::nvvm_d2i_rn:
  case Intrinsic::nvvm_d2ui_rn:
  case Intrinsic::nvvm_d2ll_rn:
  case Intrinsic::nvvm_d2ull_rn:
    return APFloat::rmNearestTiesToEven;

  // "rp": round toward positive infinity.
  case Intrinsic::nvvm_f2i_rp:
  case Intrinsic::nvvm_f2i_rp_ftz:
  case Intrinsic::nvvm_f2ui_rp:
  case Intrinsic::nvvm_f2ui_rp_ftz:
  case Intrinsic::nvvm_f2ll_rp:
  case Intrinsic::nvvm_f2ll_rp_ftz:
  case Intrinsic::nvvm_f2ull_rp:
  case Intrinsic::nvvm_f2ull_rp_ftz:
  case Intrinsic::nvvm_d2i_rp:
  case Intrinsic::nvvm_d2ui_rp:
  case Intrinsic::nvvm_d2ll_rp:
  case Intrinsic::nvvm_d2ull_rp:
    return APFloat::rmTowardPositive;

  // "rz": round toward zero.
  case Intrinsic::nvvm_f2i_rz:
  case Intrinsic::nvvm_f2i_rz_ftz:
  case Intrinsic::nvvm_f2ui_rz:
  case Intrinsic::nvvm_f2ui_rz_ftz:
  case Intrinsic::nvvm_f2ll_rz:
  case Intrinsic::nvvm_f2ll_rz_ftz:
  case Intrinsic::nvvm_f2ull_rz:
  case Intrinsic::nvvm_f2ull_rz_ftz:
  case Intrinsic::nvvm_d2i_rz:
  case Intrinsic::nvvm_d2ui_rz:
  case Intrinsic::nvvm_d2ll_rz:
  case Intrinsic::nvvm_d2ull_rz:
    return APFloat::rmTowardZero;

  default:
    // Not a recognised conversion intrinsic; handled after the switch.
    break;
  }

  llvm_unreachable("Invalid f2i/d2i rounding mode intrinsic");
  return APFloat::roundingMode::Invalid;
}

} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICUTILS_H
139 changes: 139 additions & 0 deletions llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/IntrinsicsWebAssembly.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
Expand Down Expand Up @@ -1687,6 +1689,58 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::x86_avx512_cvttsd2usi64:
return !Call->isStrictFP();

// NVVM float/double to int32/uint32 conversion intrinsics
case Intrinsic::nvvm_f2i_rm:
case Intrinsic::nvvm_f2i_rn:
case Intrinsic::nvvm_f2i_rp:
case Intrinsic::nvvm_f2i_rz:
case Intrinsic::nvvm_f2i_rm_ftz:
case Intrinsic::nvvm_f2i_rn_ftz:
case Intrinsic::nvvm_f2i_rp_ftz:
case Intrinsic::nvvm_f2i_rz_ftz:
case Intrinsic::nvvm_f2ui_rm:
case Intrinsic::nvvm_f2ui_rn:
case Intrinsic::nvvm_f2ui_rp:
case Intrinsic::nvvm_f2ui_rz:
case Intrinsic::nvvm_f2ui_rm_ftz:
case Intrinsic::nvvm_f2ui_rn_ftz:
case Intrinsic::nvvm_f2ui_rp_ftz:
case Intrinsic::nvvm_f2ui_rz_ftz:
case Intrinsic::nvvm_d2i_rm:
case Intrinsic::nvvm_d2i_rn:
case Intrinsic::nvvm_d2i_rp:
case Intrinsic::nvvm_d2i_rz:
case Intrinsic::nvvm_d2ui_rm:
case Intrinsic::nvvm_d2ui_rn:
case Intrinsic::nvvm_d2ui_rp:
case Intrinsic::nvvm_d2ui_rz:

// NVVM float/double to int64/uint64 conversion intrinsics
case Intrinsic::nvvm_f2ll_rm:
case Intrinsic::nvvm_f2ll_rn:
case Intrinsic::nvvm_f2ll_rp:
case Intrinsic::nvvm_f2ll_rz:
case Intrinsic::nvvm_f2ll_rm_ftz:
case Intrinsic::nvvm_f2ll_rn_ftz:
case Intrinsic::nvvm_f2ll_rp_ftz:
case Intrinsic::nvvm_f2ll_rz_ftz:
case Intrinsic::nvvm_f2ull_rm:
case Intrinsic::nvvm_f2ull_rn:
case Intrinsic::nvvm_f2ull_rp:
case Intrinsic::nvvm_f2ull_rz:
case Intrinsic::nvvm_f2ull_rm_ftz:
case Intrinsic::nvvm_f2ull_rn_ftz:
case Intrinsic::nvvm_f2ull_rp_ftz:
case Intrinsic::nvvm_f2ull_rz_ftz:
case Intrinsic::nvvm_d2ll_rm:
case Intrinsic::nvvm_d2ll_rn:
case Intrinsic::nvvm_d2ll_rp:
case Intrinsic::nvvm_d2ll_rz:
case Intrinsic::nvvm_d2ull_rm:
case Intrinsic::nvvm_d2ull_rn:
case Intrinsic::nvvm_d2ull_rp:
case Intrinsic::nvvm_d2ull_rz:

// Sign operations are actually bitwise operations, they do not raise
// exceptions even for SNANs.
case Intrinsic::fabs:
Expand Down Expand Up @@ -1849,6 +1903,12 @@ inline bool llvm_fenv_testexcept() {
return false;
}

/// Flush-to-zero helper that keeps the sign of the input.
///
/// \param V the floating-point value to inspect.
/// \returns a zero with V's semantics and V's sign when V is denormal;
///          otherwise a copy of V unchanged.
///
/// The result is returned as a plain (non-const) value: a top-level const on
/// a by-value return type gives callers nothing and prevents the returned
/// temporary from being moved from.
static APFloat FTZPreserveSign(const APFloat &V) {
  if (V.isDenormal())
    return APFloat::getZero(V.getSemantics(), V.isNegative());
  return V;
}

Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V,
Type *Ty) {
llvm_fenv_clearexcept();
Expand Down Expand Up @@ -2309,6 +2369,85 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
return ConstantFP::get(Ty->getContext(), U);
}

// NVVM float/double to signed/unsigned int32/int64 conversions:
switch (IntrinsicID) {
// f2i
case Intrinsic::nvvm_f2i_rm:
case Intrinsic::nvvm_f2i_rn:
case Intrinsic::nvvm_f2i_rp:
case Intrinsic::nvvm_f2i_rz:
case Intrinsic::nvvm_f2i_rm_ftz:
case Intrinsic::nvvm_f2i_rn_ftz:
case Intrinsic::nvvm_f2i_rp_ftz:
case Intrinsic::nvvm_f2i_rz_ftz:
// f2ui
case Intrinsic::nvvm_f2ui_rm:
case Intrinsic::nvvm_f2ui_rn:
case Intrinsic::nvvm_f2ui_rp:
case Intrinsic::nvvm_f2ui_rz:
case Intrinsic::nvvm_f2ui_rm_ftz:
case Intrinsic::nvvm_f2ui_rn_ftz:
case Intrinsic::nvvm_f2ui_rp_ftz:
case Intrinsic::nvvm_f2ui_rz_ftz:
// d2i
case Intrinsic::nvvm_d2i_rm:
case Intrinsic::nvvm_d2i_rn:
case Intrinsic::nvvm_d2i_rp:
case Intrinsic::nvvm_d2i_rz:
// d2ui
case Intrinsic::nvvm_d2ui_rm:
case Intrinsic::nvvm_d2ui_rn:
case Intrinsic::nvvm_d2ui_rp:
case Intrinsic::nvvm_d2ui_rz:
// f2ll
case Intrinsic::nvvm_f2ll_rm:
case Intrinsic::nvvm_f2ll_rn:
case Intrinsic::nvvm_f2ll_rp:
case Intrinsic::nvvm_f2ll_rz:
case Intrinsic::nvvm_f2ll_rm_ftz:
case Intrinsic::nvvm_f2ll_rn_ftz:
case Intrinsic::nvvm_f2ll_rp_ftz:
case Intrinsic::nvvm_f2ll_rz_ftz:
// f2ull
case Intrinsic::nvvm_f2ull_rm:
case Intrinsic::nvvm_f2ull_rn:
case Intrinsic::nvvm_f2ull_rp:
case Intrinsic::nvvm_f2ull_rz:
case Intrinsic::nvvm_f2ull_rm_ftz:
case Intrinsic::nvvm_f2ull_rn_ftz:
case Intrinsic::nvvm_f2ull_rp_ftz:
case Intrinsic::nvvm_f2ull_rz_ftz:
// d2ll
case Intrinsic::nvvm_d2ll_rm:
case Intrinsic::nvvm_d2ll_rn:
case Intrinsic::nvvm_d2ll_rp:
case Intrinsic::nvvm_d2ll_rz:
// d2ull
case Intrinsic::nvvm_d2ull_rm:
case Intrinsic::nvvm_d2ull_rn:
case Intrinsic::nvvm_d2ull_rp:
case Intrinsic::nvvm_d2ull_rz: {
// In float-to-integer conversion, NaN inputs are converted to 0.
if (U.isNaN())
return ConstantInt::get(Ty, 0);

APFloat::roundingMode RMode = nvvm::IntrinsicGetRoundingMode(IntrinsicID);
bool IsFTZ = nvvm::IntrinsicShouldFTZ(IntrinsicID);
bool IsSigned = nvvm::IntrinsicConvertsToSignedInteger(IntrinsicID);

APSInt ResInt(Ty->getIntegerBitWidth(), !IsSigned);
auto FloatToRound = IsFTZ ? FTZPreserveSign(U) : U;

bool IsExact = false;
APFloat::opStatus Status =
FloatToRound.convertToInteger(ResInt, RMode, &IsExact);

if (Status != APFloat::opInvalidOp)
return ConstantInt::get(Ty, ResInt);
return nullptr;
}
}

/// We only fold functions with finite arguments. Folding NaN and inf is
/// likely to be aborted with an exception anyway, and some host libms
/// have known errors raising exceptions.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/NVVMIntrinsicFlags.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstrInfo.h"
Expand Down
Loading

0 comments on commit a629d9e

Please sign in to comment.