Skip to content

Commit c3afb00

Browse files
philnik777Lancern
authored andcommitted
[CIR][Lowering] Add the concept of simple lowering and useit to implement FP intrincis
1 parent 826abe4 commit c3afb00

File tree

8 files changed

+205
-117
lines changed

8 files changed

+205
-117
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+66-18
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,48 @@ include "mlir/IR/CommonAttrConstraints.td"
3838
// CIR Ops
3939
//===----------------------------------------------------------------------===//
4040

41+
// LLVMLoweringInfo is used by cir-tablegen to generate LLVM lowering logic
42+
// automatically for CIR operations. The `llvmOp` field gives the name of the
43+
// LLVM IR dialect operation that the CIR operation will be lowered to. The
44+
// input arguments of the CIR operation will be passed in the same order to the
45+
// lowered LLVM IR operation.
46+
//
47+
// Example:
48+
//
49+
// For the following CIR operation definition:
50+
//
51+
// def FooOp : CIR_Op<"foo"> {
52+
// // ...
53+
// let arguments = (ins CIR_AnyType:$arg1, CIR_AnyType:$arg2);
54+
// let llvmOp = "BarOp";
55+
// }
56+
//
57+
// cir-tablegen will generate LLVM lowering code for the FooOp similar to the
58+
// following:
59+
//
60+
// class CIRFooOpLowering
61+
// : public mlir::OpConversionPattern<mlir::cir::FooOp> {
62+
// public:
63+
// using OpConversionPattern<mlir::cir::FooOp>::OpConversionPattern;
64+
//
65+
// mlir::LogicalResult matchAndRewrite(
66+
// mlir::cir::FooOp op,
67+
// OpAdaptor adaptor,
68+
// mlir::ConversionPatternRewriter &rewriter) const override {
69+
// rewriter.replaceOpWithNewOp<mlir::LLVM::BarOp>(
70+
// op, adaptor.getOperands()[0], adaptor.getOperands()[1]);
71+
// return mlir::success();
72+
// }
73+
// }
74+
//
75+
// If you want fully customized LLVM IR lowering logic, simply exclude the
76+
// `llvmOp` field from your CIR operation definition.
77+
class LLVMLoweringInfo {
78+
string llvmOp = "";
79+
}
80+
4181
class CIR_Op<string mnemonic, list<Trait> traits = []> :
42-
Op<CIR_Dialect, mnemonic, traits>;
82+
Op<CIR_Dialect, mnemonic, traits>, LLVMLoweringInfo;
4383

4484
//===----------------------------------------------------------------------===//
4585
// CIR Op Traits
@@ -2708,6 +2748,8 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,
27082748
}];
27092749

27102750
let hasVerifier = 0;
2751+
2752+
let llvmOp = "InsertElementOp";
27112753
}
27122754

27132755
//===----------------------------------------------------------------------===//
@@ -2732,6 +2774,8 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
27322774
}];
27332775

27342776
let hasVerifier = 0;
2777+
2778+
let llvmOp = "ExtractElementOp";
27352779
}
27362780

27372781
//===----------------------------------------------------------------------===//
@@ -3762,30 +3806,32 @@ def LLroundOp : UnaryFPToIntBuiltinOp<"llround">;
37623806
def LrintOp : UnaryFPToIntBuiltinOp<"lrint">;
37633807
def LLrintOp : UnaryFPToIntBuiltinOp<"llrint">;
37643808

3765-
class UnaryFPToFPBuiltinOp<string mnemonic>
3809+
class UnaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
37663810
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
37673811
let arguments = (ins CIR_AnyFloat:$src);
37683812
let results = (outs CIR_AnyFloat:$result);
37693813
let summary = "libc builtin equivalent ignoring "
37703814
"floating point exceptions and errno";
37713815
let assemblyFormat = "$src `:` type($src) attr-dict";
3772-
}
37733816

3774-
def CeilOp : UnaryFPToFPBuiltinOp<"ceil">;
3775-
def CosOp : UnaryFPToFPBuiltinOp<"cos">;
3776-
def ExpOp : UnaryFPToFPBuiltinOp<"exp">;
3777-
def Exp2Op : UnaryFPToFPBuiltinOp<"exp2">;
3778-
def FloorOp : UnaryFPToFPBuiltinOp<"floor">;
3779-
def FAbsOp : UnaryFPToFPBuiltinOp<"fabs">;
3780-
def LogOp : UnaryFPToFPBuiltinOp<"log">;
3781-
def Log10Op : UnaryFPToFPBuiltinOp<"log10">;
3782-
def Log2Op : UnaryFPToFPBuiltinOp<"log2">;
3783-
def NearbyintOp : UnaryFPToFPBuiltinOp<"nearbyint">;
3784-
def RintOp : UnaryFPToFPBuiltinOp<"rint">;
3785-
def RoundOp : UnaryFPToFPBuiltinOp<"round">;
3786-
def SinOp : UnaryFPToFPBuiltinOp<"sin">;
3787-
def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt">;
3788-
def TruncOp : UnaryFPToFPBuiltinOp<"trunc">;
3817+
let llvmOp = llvmOpName;
3818+
}
3819+
3820+
def CeilOp : UnaryFPToFPBuiltinOp<"ceil", "FCeilOp">;
3821+
def CosOp : UnaryFPToFPBuiltinOp<"cos", "CosOp">;
3822+
def ExpOp : UnaryFPToFPBuiltinOp<"exp", "ExpOp">;
3823+
def Exp2Op : UnaryFPToFPBuiltinOp<"exp2", "Exp2Op">;
3824+
def FloorOp : UnaryFPToFPBuiltinOp<"floor", "FFloorOp">;
3825+
def FAbsOp : UnaryFPToFPBuiltinOp<"fabs", "FAbsOp">;
3826+
def LogOp : UnaryFPToFPBuiltinOp<"log", "LogOp">;
3827+
def Log10Op : UnaryFPToFPBuiltinOp<"log10", "Log10Op">;
3828+
def Log2Op : UnaryFPToFPBuiltinOp<"log2", "Log2Op">;
3829+
def NearbyintOp : UnaryFPToFPBuiltinOp<"nearbyint", "NearbyintOp">;
3830+
def RintOp : UnaryFPToFPBuiltinOp<"rint", "RintOp">;
3831+
def RoundOp : UnaryFPToFPBuiltinOp<"round", "RoundOp">;
3832+
def SinOp : UnaryFPToFPBuiltinOp<"sin", "SinOp">;
3833+
def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt", "SqrtOp">;
3834+
def TruncOp : UnaryFPToFPBuiltinOp<"trunc", "FTruncOp">;
37893835

37903836
class BinaryFPToFPBuiltinOp<string mnemonic>
37913837
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
@@ -3987,6 +4033,8 @@ def StackRestoreOp : CIR_Op<"stack_restore"> {
39874033

39884034
let arguments = (ins CIR_PointerType:$ptr);
39894035
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr))";
4036+
4037+
let llvmOp = "StackRestoreOp";
39904038
}
39914039

39924040
def AsmATT : I32EnumAttrCase<"x86_att", 0>;

clang/include/clang/CIR/Dialect/IR/CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,7 @@ mlir_tablegen(CIROpsStructs.cpp.inc -gen-attrdef-defs)
2727
mlir_tablegen(CIROpsAttributes.h.inc -gen-attrdef-decls)
2828
mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs)
2929
add_public_tablegen_target(MLIRCIREnumsGen)
30+
31+
clang_tablegen(CIRBuiltinsLowering.inc -gen-cir-builtins-lowering
32+
SOURCE CIROps.td
33+
TARGET CIRBuiltinsLowering)

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+13-99
Original file line numberDiff line numberDiff line change
@@ -1330,34 +1330,6 @@ class CIRVectorCreateLowering
13301330
}
13311331
};
13321332

1333-
class CIRVectorInsertLowering
1334-
: public mlir::OpConversionPattern<mlir::cir::VecInsertOp> {
1335-
public:
1336-
using OpConversionPattern<mlir::cir::VecInsertOp>::OpConversionPattern;
1337-
1338-
mlir::LogicalResult
1339-
matchAndRewrite(mlir::cir::VecInsertOp op, OpAdaptor adaptor,
1340-
mlir::ConversionPatternRewriter &rewriter) const override {
1341-
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>(
1342-
op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex());
1343-
return mlir::success();
1344-
}
1345-
};
1346-
1347-
class CIRVectorExtractLowering
1348-
: public mlir::OpConversionPattern<mlir::cir::VecExtractOp> {
1349-
public:
1350-
using OpConversionPattern<mlir::cir::VecExtractOp>::OpConversionPattern;
1351-
1352-
mlir::LogicalResult
1353-
matchAndRewrite(mlir::cir::VecExtractOp op, OpAdaptor adaptor,
1354-
mlir::ConversionPatternRewriter &rewriter) const override {
1355-
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
1356-
op, adaptor.getVec(), adaptor.getIndex());
1357-
return mlir::success();
1358-
}
1359-
};
1360-
13611333
class CIRVectorCmpOpLowering
13621334
: public mlir::OpConversionPattern<mlir::cir::VecCmpOp> {
13631335
public:
@@ -3154,19 +3126,6 @@ class CIRPtrDiffOpLowering
31543126
}
31553127
};
31563128

3157-
class CIRFAbsOpLowering : public mlir::OpConversionPattern<mlir::cir::FAbsOp> {
3158-
public:
3159-
using OpConversionPattern<mlir::cir::FAbsOp>::OpConversionPattern;
3160-
3161-
mlir::LogicalResult
3162-
matchAndRewrite(mlir::cir::FAbsOp op, OpAdaptor adaptor,
3163-
mlir::ConversionPatternRewriter &rewriter) const override {
3164-
rewriter.replaceOpWithNewOp<mlir::LLVM::FAbsOp>(
3165-
op, adaptor.getOperands().front());
3166-
return mlir::success();
3167-
}
3168-
};
3169-
31703129
class CIRExpectOpLowering
31713130
: public mlir::OpConversionPattern<mlir::cir::ExpectOp> {
31723131
public:
@@ -3246,19 +3205,8 @@ class CIRStackSaveLowering
32463205
}
32473206
};
32483207

3249-
class CIRStackRestoreLowering
3250-
: public mlir::OpConversionPattern<mlir::cir::StackRestoreOp> {
3251-
public:
3252-
using OpConversionPattern<mlir::cir::StackRestoreOp>::OpConversionPattern;
3253-
3254-
mlir::LogicalResult
3255-
matchAndRewrite(mlir::cir::StackRestoreOp op, OpAdaptor adaptor,
3256-
mlir::ConversionPatternRewriter &rewriter) const override {
3257-
rewriter.replaceOpWithNewOp<mlir::LLVM::StackRestoreOp>(op,
3258-
adaptor.getPtr());
3259-
return mlir::success();
3260-
}
3261-
};
3208+
#define GET_BUILTIN_LOWERING_CLASSES
3209+
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"
32623210

32633211
class CIRUnreachableLowering
32643212
: public mlir::OpConversionPattern<mlir::cir::UnreachableOp> {
@@ -3601,38 +3549,6 @@ class CIRUnaryFPBuiltinOpLowering : public mlir::OpConversionPattern<CIROp> {
36013549
}
36023550
};
36033551

3604-
using CIRCeilOpLowering =
3605-
CIRUnaryFPBuiltinOpLowering<mlir::cir::CeilOp, mlir::LLVM::FCeilOp>;
3606-
using CIRCosOpLowering =
3607-
CIRUnaryFPBuiltinOpLowering<mlir::cir::CosOp, mlir::LLVM::CosOp>;
3608-
using CIRExpOpLowering =
3609-
CIRUnaryFPBuiltinOpLowering<mlir::cir::ExpOp, mlir::LLVM::ExpOp>;
3610-
using CIRExp2OpLowering =
3611-
CIRUnaryFPBuiltinOpLowering<mlir::cir::Exp2Op, mlir::LLVM::Exp2Op>;
3612-
using CIRFloorOpLowering =
3613-
CIRUnaryFPBuiltinOpLowering<mlir::cir::FloorOp, mlir::LLVM::FFloorOp>;
3614-
using CIRFabsOpLowering =
3615-
CIRUnaryFPBuiltinOpLowering<mlir::cir::FAbsOp, mlir::LLVM::FAbsOp>;
3616-
using CIRLogOpLowering =
3617-
CIRUnaryFPBuiltinOpLowering<mlir::cir::LogOp, mlir::LLVM::LogOp>;
3618-
using CIRLog10OpLowering =
3619-
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log10Op, mlir::LLVM::Log10Op>;
3620-
using CIRLog2OpLowering =
3621-
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log2Op, mlir::LLVM::Log2Op>;
3622-
using CIRNearbyintOpLowering =
3623-
CIRUnaryFPBuiltinOpLowering<mlir::cir::NearbyintOp,
3624-
mlir::LLVM::NearbyintOp>;
3625-
using CIRRintOpLowering =
3626-
CIRUnaryFPBuiltinOpLowering<mlir::cir::RintOp, mlir::LLVM::RintOp>;
3627-
using CIRRoundOpLowering =
3628-
CIRUnaryFPBuiltinOpLowering<mlir::cir::RoundOp, mlir::LLVM::RoundOp>;
3629-
using CIRSinOpLowering =
3630-
CIRUnaryFPBuiltinOpLowering<mlir::cir::SinOp, mlir::LLVM::SinOp>;
3631-
using CIRSqrtOpLowering =
3632-
CIRUnaryFPBuiltinOpLowering<mlir::cir::SqrtOp, mlir::LLVM::SqrtOp>;
3633-
using CIRTruncOpLowering =
3634-
CIRUnaryFPBuiltinOpLowering<mlir::cir::TruncOp, mlir::LLVM::FTruncOp>;
3635-
36363552
using CIRLroundOpLowering =
36373553
CIRUnaryFPBuiltinOpLowering<mlir::cir::LroundOp, mlir::LLVM::LroundOp>;
36383554
using CIRLLroundOpLowering =
@@ -3906,23 +3822,21 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
39063822
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
39073823
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
39083824
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
3909-
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
3910-
CIRVectorSplatLowering, CIRVectorTernaryLowering,
3825+
CIRVectorCmpOpLowering, CIRVectorSplatLowering, CIRVectorTernaryLowering,
39113826
CIRVectorShuffleIntsLowering, CIRVectorShuffleVecLowering,
3912-
CIRStackSaveLowering, CIRStackRestoreLowering, CIRUnreachableLowering,
3913-
CIRTrapLowering, CIRInlineAsmOpLowering, CIRSetBitfieldLowering,
3914-
CIRGetBitfieldLowering, CIRPrefetchLowering, CIRObjSizeOpLowering,
3915-
CIRIsConstantOpLowering, CIRCmpThreeWayOpLowering, CIRLroundOpLowering,
3916-
CIRLLroundOpLowering, CIRLrintOpLowering, CIRLLrintOpLowering,
3917-
CIRCeilOpLowering, CIRCosOpLowering, CIRExpOpLowering, CIRExp2OpLowering,
3918-
CIRFloorOpLowering, CIRFAbsOpLowering, CIRLogOpLowering,
3919-
CIRLog10OpLowering, CIRLog2OpLowering, CIRNearbyintOpLowering,
3920-
CIRRintOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
3921-
CIRSqrtOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering,
3827+
CIRStackSaveLowering, CIRUnreachableLowering, CIRTrapLowering,
3828+
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
3829+
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
3830+
CIRCmpThreeWayOpLowering, CIRLroundOpLowering, CIRLLroundOpLowering,
3831+
CIRLrintOpLowering, CIRLLrintOpLowering, CIRCopysignOpLowering,
39223832
CIRFModOpLowering, CIRFMaxOpLowering, CIRFMinOpLowering, CIRPowOpLowering,
39233833
CIRClearCacheOpLowering, CIRUndefOpLowering, CIREhTypeIdOpLowering,
39243834
CIRCatchParamOpLowering, CIRResumeOpLowering, CIRAllocExceptionOpLowering,
3925-
CIRThrowOpLowering>(converter, patterns.getContext());
3835+
CIRThrowOpLowering
3836+
#define GET_BUILTIN_LOWERING_LIST
3837+
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"
3838+
#undef GET_BUILTIN_LOWERING_LIST
3839+
>(converter, patterns.getContext());
39263840
}
39273841

39283842
namespace {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: cir-opt %s -cir-to-llvm -o %t.ll
2+
// RUN: FileCheck --input-file=%t.ll %s
3+
4+
module {
5+
cir.func @test(%arg0 : !cir.float) {
6+
%1 = cir.cos %arg0 : !cir.float
7+
// CHECK: llvm.intr.cos(%arg0) : (f32) -> f32
8+
9+
%2 = cir.ceil %arg0 : !cir.float
10+
// CHECK: llvm.intr.ceil(%arg0) : (f32) -> f32
11+
12+
%3 = cir.exp %arg0 : !cir.float
13+
// CHECK: llvm.intr.exp(%arg0) : (f32) -> f32
14+
15+
%4 = cir.exp2 %arg0 : !cir.float
16+
// CHECK: llvm.intr.exp2(%arg0) : (f32) -> f32
17+
18+
%5 = cir.fabs %arg0 : !cir.float
19+
// CHECK: llvm.intr.fabs(%arg0) : (f32) -> f32
20+
21+
%6 = cir.floor %arg0 : !cir.float
22+
// CHECK: llvm.intr.floor(%arg0) : (f32) -> f32
23+
24+
%7 = cir.log %arg0 : !cir.float
25+
// CHECK: llvm.intr.log(%arg0) : (f32) -> f32
26+
27+
%8 = cir.log10 %arg0 : !cir.float
28+
// CHECK: llvm.intr.log10(%arg0) : (f32) -> f32
29+
30+
%9 = cir.log2 %arg0 : !cir.float
31+
// CHECK: llvm.intr.log2(%arg0) : (f32) -> f32
32+
33+
%10 = cir.nearbyint %arg0 : !cir.float
34+
// CHECK: llvm.intr.nearbyint(%arg0) : (f32) -> f32
35+
36+
%11 = cir.rint %arg0 : !cir.float
37+
// CHECK: llvm.intr.rint(%arg0) : (f32) -> f32
38+
39+
%12 = cir.round %arg0 : !cir.float
40+
// CHECK: llvm.intr.round(%arg0) : (f32) -> f32
41+
42+
%13 = cir.sin %arg0 : !cir.float
43+
// CHECK: llvm.intr.sin(%arg0) : (f32) -> f32
44+
45+
%14 = cir.sqrt %arg0 : !cir.float
46+
// CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32
47+
48+
cir.return
49+
}
50+
}

0 commit comments

Comments
 (0)