Skip to content

Commit

Permalink
[CIR][Lowering] Add the concept of simple lowering and use it for una…
Browse files Browse the repository at this point in the history
…ry fp2fp operations (llvm#806)

This PR is the continuation and refinement of PR llvm#434 which is
originally authored by @philnik777 . Does not update it in-place since I
don't have commit access to Nikolas' repo.

This PR basically just rebases llvm#434 onto the latest `main`. I also
updated some naming used in the original PR to keep naming styles
consistent.

Co-authored-by: Nikolas Klauser <[email protected]>
  • Loading branch information
2 people authored and smeenai committed Oct 9, 2024
1 parent 54a854b commit d1016d6
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 117 deletions.
84 changes: 66 additions & 18 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,48 @@ include "mlir/IR/CommonAttrConstraints.td"
// CIR Ops
//===----------------------------------------------------------------------===//

// LLVMLoweringInfo is used by cir-tablegen to generate LLVM lowering logic
// automatically for CIR operations. The `llvmOp` field gives the name of the
// LLVM IR dialect operation that the CIR operation will be lowered to. The
// input arguments of the CIR operation will be passed in the same order to the
// lowered LLVM IR operation.
//
// Example:
//
// For the following CIR operation definition:
//
// def FooOp : CIR_Op<"foo"> {
// // ...
// let arguments = (ins CIR_AnyType:$arg1, CIR_AnyType:$arg2);
// let llvmOp = "BarOp";
// }
//
// cir-tablegen will generate LLVM lowering code for the FooOp similar to the
// following:
//
// class CIRFooOpLowering
// : public mlir::OpConversionPattern<mlir::cir::FooOp> {
// public:
// using OpConversionPattern<mlir::cir::FooOp>::OpConversionPattern;
//
// mlir::LogicalResult matchAndRewrite(
// mlir::cir::FooOp op,
// OpAdaptor adaptor,
// mlir::ConversionPatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<mlir::LLVM::BarOp>(
// op, adaptor.getOperands()[0], adaptor.getOperands()[1]);
// return mlir::success();
// }
// }
//
// If you want fully customized LLVM IR lowering logic, simply exclude the
// `llvmOp` field from your CIR operation definition.
class LLVMLoweringInfo {
string llvmOp = "";
}

class CIR_Op<string mnemonic, list<Trait> traits = []> :
Op<CIR_Dialect, mnemonic, traits>;
Op<CIR_Dialect, mnemonic, traits>, LLVMLoweringInfo;

//===----------------------------------------------------------------------===//
// CIR Op Traits
Expand Down Expand Up @@ -2708,6 +2748,8 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,
}];

let hasVerifier = 0;

let llvmOp = "InsertElementOp";
}

//===----------------------------------------------------------------------===//
Expand All @@ -2732,6 +2774,8 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
}];

let hasVerifier = 0;

let llvmOp = "ExtractElementOp";
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3762,30 +3806,32 @@ def LLroundOp : UnaryFPToIntBuiltinOp<"llround">;
def LrintOp : UnaryFPToIntBuiltinOp<"lrint">;
def LLrintOp : UnaryFPToIntBuiltinOp<"llrint">;

class UnaryFPToFPBuiltinOp<string mnemonic>
class UnaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let arguments = (ins CIR_AnyFloat:$src);
let results = (outs CIR_AnyFloat:$result);
let summary = "libc builtin equivalent ignoring "
"floating point exceptions and errno";
let assemblyFormat = "$src `:` type($src) attr-dict";
}

def CeilOp : UnaryFPToFPBuiltinOp<"ceil">;
def CosOp : UnaryFPToFPBuiltinOp<"cos">;
def ExpOp : UnaryFPToFPBuiltinOp<"exp">;
def Exp2Op : UnaryFPToFPBuiltinOp<"exp2">;
def FloorOp : UnaryFPToFPBuiltinOp<"floor">;
def FAbsOp : UnaryFPToFPBuiltinOp<"fabs">;
def LogOp : UnaryFPToFPBuiltinOp<"log">;
def Log10Op : UnaryFPToFPBuiltinOp<"log10">;
def Log2Op : UnaryFPToFPBuiltinOp<"log2">;
def NearbyintOp : UnaryFPToFPBuiltinOp<"nearbyint">;
def RintOp : UnaryFPToFPBuiltinOp<"rint">;
def RoundOp : UnaryFPToFPBuiltinOp<"round">;
def SinOp : UnaryFPToFPBuiltinOp<"sin">;
def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt">;
def TruncOp : UnaryFPToFPBuiltinOp<"trunc">;
let llvmOp = llvmOpName;
}

def CeilOp : UnaryFPToFPBuiltinOp<"ceil", "FCeilOp">;
def CosOp : UnaryFPToFPBuiltinOp<"cos", "CosOp">;
def ExpOp : UnaryFPToFPBuiltinOp<"exp", "ExpOp">;
def Exp2Op : UnaryFPToFPBuiltinOp<"exp2", "Exp2Op">;
def FloorOp : UnaryFPToFPBuiltinOp<"floor", "FFloorOp">;
def FAbsOp : UnaryFPToFPBuiltinOp<"fabs", "FAbsOp">;
def LogOp : UnaryFPToFPBuiltinOp<"log", "LogOp">;
def Log10Op : UnaryFPToFPBuiltinOp<"log10", "Log10Op">;
def Log2Op : UnaryFPToFPBuiltinOp<"log2", "Log2Op">;
def NearbyintOp : UnaryFPToFPBuiltinOp<"nearbyint", "NearbyintOp">;
def RintOp : UnaryFPToFPBuiltinOp<"rint", "RintOp">;
def RoundOp : UnaryFPToFPBuiltinOp<"round", "RoundOp">;
def SinOp : UnaryFPToFPBuiltinOp<"sin", "SinOp">;
def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt", "SqrtOp">;
def TruncOp : UnaryFPToFPBuiltinOp<"trunc", "FTruncOp">;

class BinaryFPToFPBuiltinOp<string mnemonic>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
Expand Down Expand Up @@ -3987,6 +4033,8 @@ def StackRestoreOp : CIR_Op<"stack_restore"> {

let arguments = (ins CIR_PointerType:$ptr);
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr))";

let llvmOp = "StackRestoreOp";
}

def AsmATT : I32EnumAttrCase<"x86_att", 0>;
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ mlir_tablegen(CIROpsStructs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(CIROpsAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRCIREnumsGen)

clang_tablegen(CIRBuiltinsLowering.inc -gen-cir-builtins-lowering
SOURCE CIROps.td
TARGET CIRBuiltinsLowering)
112 changes: 13 additions & 99 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1330,34 +1330,6 @@ class CIRVectorCreateLowering
}
};

class CIRVectorInsertLowering
: public mlir::OpConversionPattern<mlir::cir::VecInsertOp> {
public:
using OpConversionPattern<mlir::cir::VecInsertOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VecInsertOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>(
op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex());
return mlir::success();
}
};

class CIRVectorExtractLowering
: public mlir::OpConversionPattern<mlir::cir::VecExtractOp> {
public:
using OpConversionPattern<mlir::cir::VecExtractOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VecExtractOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
op, adaptor.getVec(), adaptor.getIndex());
return mlir::success();
}
};

class CIRVectorCmpOpLowering
: public mlir::OpConversionPattern<mlir::cir::VecCmpOp> {
public:
Expand Down Expand Up @@ -3156,19 +3128,6 @@ class CIRPtrDiffOpLowering
}
};

class CIRFAbsOpLowering : public mlir::OpConversionPattern<mlir::cir::FAbsOp> {
public:
using OpConversionPattern<mlir::cir::FAbsOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::FAbsOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::LLVM::FAbsOp>(
op, adaptor.getOperands().front());
return mlir::success();
}
};

class CIRExpectOpLowering
: public mlir::OpConversionPattern<mlir::cir::ExpectOp> {
public:
Expand Down Expand Up @@ -3248,19 +3207,8 @@ class CIRStackSaveLowering
}
};

class CIRStackRestoreLowering
: public mlir::OpConversionPattern<mlir::cir::StackRestoreOp> {
public:
using OpConversionPattern<mlir::cir::StackRestoreOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::StackRestoreOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::LLVM::StackRestoreOp>(op,
adaptor.getPtr());
return mlir::success();
}
};
#define GET_BUILTIN_LOWERING_CLASSES
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"

class CIRUnreachableLowering
: public mlir::OpConversionPattern<mlir::cir::UnreachableOp> {
Expand Down Expand Up @@ -3603,38 +3551,6 @@ class CIRUnaryFPBuiltinOpLowering : public mlir::OpConversionPattern<CIROp> {
}
};

using CIRCeilOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::CeilOp, mlir::LLVM::FCeilOp>;
using CIRCosOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::CosOp, mlir::LLVM::CosOp>;
using CIRExpOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::ExpOp, mlir::LLVM::ExpOp>;
using CIRExp2OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Exp2Op, mlir::LLVM::Exp2Op>;
using CIRFloorOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::FloorOp, mlir::LLVM::FFloorOp>;
using CIRFabsOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::FAbsOp, mlir::LLVM::FAbsOp>;
using CIRLogOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::LogOp, mlir::LLVM::LogOp>;
using CIRLog10OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log10Op, mlir::LLVM::Log10Op>;
using CIRLog2OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log2Op, mlir::LLVM::Log2Op>;
using CIRNearbyintOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::NearbyintOp,
mlir::LLVM::NearbyintOp>;
using CIRRintOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::RintOp, mlir::LLVM::RintOp>;
using CIRRoundOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::RoundOp, mlir::LLVM::RoundOp>;
using CIRSinOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::SinOp, mlir::LLVM::SinOp>;
using CIRSqrtOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::SqrtOp, mlir::LLVM::SqrtOp>;
using CIRTruncOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::TruncOp, mlir::LLVM::FTruncOp>;

using CIRLroundOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::LroundOp, mlir::LLVM::LroundOp>;
using CIRLLroundOpLowering =
Expand Down Expand Up @@ -3909,23 +3825,21 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
CIRVectorSplatLowering, CIRVectorTernaryLowering,
CIRVectorCmpOpLowering, CIRVectorSplatLowering, CIRVectorTernaryLowering,
CIRVectorShuffleIntsLowering, CIRVectorShuffleVecLowering,
CIRStackSaveLowering, CIRStackRestoreLowering, CIRUnreachableLowering,
CIRTrapLowering, CIRInlineAsmOpLowering, CIRSetBitfieldLowering,
CIRGetBitfieldLowering, CIRPrefetchLowering, CIRObjSizeOpLowering,
CIRIsConstantOpLowering, CIRCmpThreeWayOpLowering, CIRLroundOpLowering,
CIRLLroundOpLowering, CIRLrintOpLowering, CIRLLrintOpLowering,
CIRCeilOpLowering, CIRCosOpLowering, CIRExpOpLowering, CIRExp2OpLowering,
CIRFloorOpLowering, CIRFAbsOpLowering, CIRLogOpLowering,
CIRLog10OpLowering, CIRLog2OpLowering, CIRNearbyintOpLowering,
CIRRintOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
CIRSqrtOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering,
CIRStackSaveLowering, CIRUnreachableLowering, CIRTrapLowering,
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
CIRCmpThreeWayOpLowering, CIRLroundOpLowering, CIRLLroundOpLowering,
CIRLrintOpLowering, CIRLLrintOpLowering, CIRCopysignOpLowering,
CIRFModOpLowering, CIRFMaxOpLowering, CIRFMinOpLowering, CIRPowOpLowering,
CIRClearCacheOpLowering, CIRUndefOpLowering, CIREhTypeIdOpLowering,
CIRCatchParamOpLowering, CIRResumeOpLowering, CIRAllocExceptionOpLowering,
CIRThrowOpLowering>(converter, patterns.getContext());
CIRThrowOpLowering
#define GET_BUILTIN_LOWERING_LIST
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"
#undef GET_BUILTIN_LOWERING_LIST
>(converter, patterns.getContext());
}

namespace {
Expand Down
50 changes: 50 additions & 0 deletions clang/test/CIR/Lowering/builtin-floating-point.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: cir-opt %s -cir-to-llvm -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s

module {
cir.func @test(%arg0 : !cir.float) {
%1 = cir.cos %arg0 : !cir.float
// CHECK: llvm.intr.cos(%arg0) : (f32) -> f32

%2 = cir.ceil %arg0 : !cir.float
// CHECK: llvm.intr.ceil(%arg0) : (f32) -> f32

%3 = cir.exp %arg0 : !cir.float
// CHECK: llvm.intr.exp(%arg0) : (f32) -> f32

%4 = cir.exp2 %arg0 : !cir.float
// CHECK: llvm.intr.exp2(%arg0) : (f32) -> f32

%5 = cir.fabs %arg0 : !cir.float
// CHECK: llvm.intr.fabs(%arg0) : (f32) -> f32

%6 = cir.floor %arg0 : !cir.float
// CHECK: llvm.intr.floor(%arg0) : (f32) -> f32

%7 = cir.log %arg0 : !cir.float
// CHECK: llvm.intr.log(%arg0) : (f32) -> f32

%8 = cir.log10 %arg0 : !cir.float
// CHECK: llvm.intr.log10(%arg0) : (f32) -> f32

%9 = cir.log2 %arg0 : !cir.float
// CHECK: llvm.intr.log2(%arg0) : (f32) -> f32

%10 = cir.nearbyint %arg0 : !cir.float
// CHECK: llvm.intr.nearbyint(%arg0) : (f32) -> f32

%11 = cir.rint %arg0 : !cir.float
// CHECK: llvm.intr.rint(%arg0) : (f32) -> f32

%12 = cir.round %arg0 : !cir.float
// CHECK: llvm.intr.round(%arg0) : (f32) -> f32

%13 = cir.sin %arg0 : !cir.float
// CHECK: llvm.intr.sin(%arg0) : (f32) -> f32

%14 = cir.sqrt %arg0 : !cir.float
// CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32

cir.return
}
}
Loading

0 comments on commit d1016d6

Please sign in to comment.