Skip to content

Commit f76dbd4

Browse files
Lancernphilnik777
authored andcommitted
[CIR][Lowering] Add the concept of simple lowering and use it for unary fp2fp operations (#806)
This PR is the continuation and refinement of PR #434 which is originally authored by @philnik777 . Does not update it in-place since I don't have commit access to Nikolas' repo. This PR basically just rebases #434 onto the latest `main`. I also updated some naming used in the original PR to keep naming styles consistent. Co-authored-by: Nikolas Klauser <[email protected]>
1 parent 92bf9ae commit f76dbd4

File tree

8 files changed

+207
-117
lines changed

8 files changed

+207
-117
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+66-18
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,48 @@ include "mlir/IR/CommonAttrConstraints.td"
3838
// CIR Ops
3939
//===----------------------------------------------------------------------===//
4040

41+
// LLVMLoweringInfo is used by cir-tablegen to generate LLVM lowering logic
42+
// automatically for CIR operations. The `llvmOp` field gives the name of the
43+
// LLVM IR dialect operation that the CIR operation will be lowered to. The
44+
// input arguments of the CIR operation will be passed in the same order to the
45+
// lowered LLVM IR operation.
46+
//
47+
// Example:
48+
//
49+
// For the following CIR operation definition:
50+
//
51+
// def FooOp : CIR_Op<"foo"> {
52+
// // ...
53+
// let arguments = (ins CIR_AnyType:$arg1, CIR_AnyType:$arg2);
54+
// let llvmOp = "BarOp";
55+
// }
56+
//
57+
// cir-tablegen will generate LLVM lowering code for the FooOp similar to the
58+
// following:
59+
//
60+
// class CIRFooOpLowering
61+
// : public mlir::OpConversionPattern<mlir::cir::FooOp> {
62+
// public:
63+
// using OpConversionPattern<mlir::cir::FooOp>::OpConversionPattern;
64+
//
65+
// mlir::LogicalResult matchAndRewrite(
66+
// mlir::cir::FooOp op,
67+
// OpAdaptor adaptor,
68+
// mlir::ConversionPatternRewriter &rewriter) const override {
69+
// rewriter.replaceOpWithNewOp<mlir::LLVM::BarOp>(
70+
// op, adaptor.getOperands()[0], adaptor.getOperands()[1]);
71+
// return mlir::success();
72+
// }
73+
// }
74+
//
75+
// If you want fully customized LLVM IR lowering logic, simply exclude the
76+
// `llvmOp` field from your CIR operation definition.
77+
class LLVMLoweringInfo {
78+
string llvmOp = "";
79+
}
80+
4181
class CIR_Op<string mnemonic, list<Trait> traits = []> :
42-
Op<CIR_Dialect, mnemonic, traits>;
82+
Op<CIR_Dialect, mnemonic, traits>, LLVMLoweringInfo;
4383

4484
//===----------------------------------------------------------------------===//
4585
// CIR Op Traits
@@ -2708,6 +2748,8 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,
27082748
}];
27092749

27102750
let hasVerifier = 0;
2751+
2752+
let llvmOp = "InsertElementOp";
27112753
}
27122754

27132755
//===----------------------------------------------------------------------===//
@@ -2732,6 +2774,8 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
27322774
}];
27332775

27342776
let hasVerifier = 0;
2777+
2778+
let llvmOp = "ExtractElementOp";
27352779
}
27362780

27372781
//===----------------------------------------------------------------------===//
@@ -3771,30 +3815,32 @@ def LLroundOp : UnaryFPToIntBuiltinOp<"llround">;
37713815
def LrintOp : UnaryFPToIntBuiltinOp<"lrint">;
37723816
def LLrintOp : UnaryFPToIntBuiltinOp<"llrint">;
37733817

3774-
class UnaryFPToFPBuiltinOp<string mnemonic>
3818+
class UnaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
37753819
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
37763820
let arguments = (ins CIR_AnyFloat:$src);
37773821
let results = (outs CIR_AnyFloat:$result);
37783822
let summary = "libc builtin equivalent ignoring "
37793823
"floating point exceptions and errno";
37803824
let assemblyFormat = "$src `:` type($src) attr-dict";
3781-
}
37823825

3783-
def CeilOp : UnaryFPToFPBuiltinOp<"ceil">;
3784-
def CosOp : UnaryFPToFPBuiltinOp<"cos">;
3785-
def ExpOp : UnaryFPToFPBuiltinOp<"exp">;
3786-
def Exp2Op : UnaryFPToFPBuiltinOp<"exp2">;
3787-
def FloorOp : UnaryFPToFPBuiltinOp<"floor">;
3788-
def FAbsOp : UnaryFPToFPBuiltinOp<"fabs">;
3789-
def LogOp : UnaryFPToFPBuiltinOp<"log">;
3790-
def Log10Op : UnaryFPToFPBuiltinOp<"log10">;
3791-
def Log2Op : UnaryFPToFPBuiltinOp<"log2">;
3792-
def NearbyintOp : UnaryFPToFPBuiltinOp<"nearbyint">;
3793-
def RintOp : UnaryFPToFPBuiltinOp<"rint">;
3794-
def RoundOp : UnaryFPToFPBuiltinOp<"round">;
3795-
def SinOp : UnaryFPToFPBuiltinOp<"sin">;
3796-
def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt">;
3797-
def TruncOp : UnaryFPToFPBuiltinOp<"trunc">;
3826+
let llvmOp = llvmOpName;
3827+
}
3828+
3829+
def CeilOp : UnaryFPToFPBuiltinOp<"ceil", "FCeilOp">;
3830+
def CosOp : UnaryFPToFPBuiltinOp<"cos", "CosOp">;
3831+
def ExpOp : UnaryFPToFPBuiltinOp<"exp", "ExpOp">;
3832+
def Exp2Op : UnaryFPToFPBuiltinOp<"exp2", "Exp2Op">;
3833+
def FloorOp : UnaryFPToFPBuiltinOp<"floor", "FFloorOp">;
3834+
def FAbsOp : UnaryFPToFPBuiltinOp<"fabs", "FAbsOp">;
3835+
def LogOp : UnaryFPToFPBuiltinOp<"log", "LogOp">;
3836+
def Log10Op : UnaryFPToFPBuiltinOp<"log10", "Log10Op">;
3837+
def Log2Op : UnaryFPToFPBuiltinOp<"log2", "Log2Op">;
3838+
def NearbyintOp : UnaryFPToFPBuiltinOp<"nearbyint", "NearbyintOp">;
3839+
def RintOp : UnaryFPToFPBuiltinOp<"rint", "RintOp">;
3840+
def RoundOp : UnaryFPToFPBuiltinOp<"round", "RoundOp">;
3841+
def SinOp : UnaryFPToFPBuiltinOp<"sin", "SinOp">;
3842+
def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt", "SqrtOp">;
3843+
def TruncOp : UnaryFPToFPBuiltinOp<"trunc", "FTruncOp">;
37983844

37993845
class BinaryFPToFPBuiltinOp<string mnemonic>
38003846
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
@@ -3996,6 +4042,8 @@ def StackRestoreOp : CIR_Op<"stack_restore"> {
39964042

39974043
let arguments = (ins CIR_PointerType:$ptr);
39984044
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr))";
4045+
4046+
let llvmOp = "StackRestoreOp";
39994047
}
40004048

40014049
def AsmATT : I32EnumAttrCase<"x86_att", 0>;

clang/include/clang/CIR/Dialect/IR/CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,7 @@ mlir_tablegen(CIROpsStructs.cpp.inc -gen-attrdef-defs)
2727
mlir_tablegen(CIROpsAttributes.h.inc -gen-attrdef-decls)
2828
mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs)
2929
add_public_tablegen_target(MLIRCIREnumsGen)
30+
31+
clang_tablegen(CIRBuiltinsLowering.inc -gen-cir-builtins-lowering
32+
SOURCE CIROps.td
33+
TARGET CIRBuiltinsLowering)

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+13-99
Original file line numberDiff line numberDiff line change
@@ -1331,34 +1331,6 @@ class CIRVectorCreateLowering
13311331
}
13321332
};
13331333

1334-
class CIRVectorInsertLowering
1335-
: public mlir::OpConversionPattern<mlir::cir::VecInsertOp> {
1336-
public:
1337-
using OpConversionPattern<mlir::cir::VecInsertOp>::OpConversionPattern;
1338-
1339-
mlir::LogicalResult
1340-
matchAndRewrite(mlir::cir::VecInsertOp op, OpAdaptor adaptor,
1341-
mlir::ConversionPatternRewriter &rewriter) const override {
1342-
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>(
1343-
op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex());
1344-
return mlir::success();
1345-
}
1346-
};
1347-
1348-
class CIRVectorExtractLowering
1349-
: public mlir::OpConversionPattern<mlir::cir::VecExtractOp> {
1350-
public:
1351-
using OpConversionPattern<mlir::cir::VecExtractOp>::OpConversionPattern;
1352-
1353-
mlir::LogicalResult
1354-
matchAndRewrite(mlir::cir::VecExtractOp op, OpAdaptor adaptor,
1355-
mlir::ConversionPatternRewriter &rewriter) const override {
1356-
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
1357-
op, adaptor.getVec(), adaptor.getIndex());
1358-
return mlir::success();
1359-
}
1360-
};
1361-
13621334
class CIRVectorCmpOpLowering
13631335
: public mlir::OpConversionPattern<mlir::cir::VecCmpOp> {
13641336
public:
@@ -3155,19 +3127,6 @@ class CIRPtrDiffOpLowering
31553127
}
31563128
};
31573129

3158-
class CIRFAbsOpLowering : public mlir::OpConversionPattern<mlir::cir::FAbsOp> {
3159-
public:
3160-
using OpConversionPattern<mlir::cir::FAbsOp>::OpConversionPattern;
3161-
3162-
mlir::LogicalResult
3163-
matchAndRewrite(mlir::cir::FAbsOp op, OpAdaptor adaptor,
3164-
mlir::ConversionPatternRewriter &rewriter) const override {
3165-
rewriter.replaceOpWithNewOp<mlir::LLVM::FAbsOp>(
3166-
op, adaptor.getOperands().front());
3167-
return mlir::success();
3168-
}
3169-
};
3170-
31713130
class CIRExpectOpLowering
31723131
: public mlir::OpConversionPattern<mlir::cir::ExpectOp> {
31733132
public:
@@ -3247,19 +3206,8 @@ class CIRStackSaveLowering
32473206
}
32483207
};
32493208

3250-
class CIRStackRestoreLowering
3251-
: public mlir::OpConversionPattern<mlir::cir::StackRestoreOp> {
3252-
public:
3253-
using OpConversionPattern<mlir::cir::StackRestoreOp>::OpConversionPattern;
3254-
3255-
mlir::LogicalResult
3256-
matchAndRewrite(mlir::cir::StackRestoreOp op, OpAdaptor adaptor,
3257-
mlir::ConversionPatternRewriter &rewriter) const override {
3258-
rewriter.replaceOpWithNewOp<mlir::LLVM::StackRestoreOp>(op,
3259-
adaptor.getPtr());
3260-
return mlir::success();
3261-
}
3262-
};
3209+
#define GET_BUILTIN_LOWERING_CLASSES
3210+
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"
32633211

32643212
class CIRUnreachableLowering
32653213
: public mlir::OpConversionPattern<mlir::cir::UnreachableOp> {
@@ -3602,38 +3550,6 @@ class CIRUnaryFPBuiltinOpLowering : public mlir::OpConversionPattern<CIROp> {
36023550
}
36033551
};
36043552

3605-
using CIRCeilOpLowering =
3606-
CIRUnaryFPBuiltinOpLowering<mlir::cir::CeilOp, mlir::LLVM::FCeilOp>;
3607-
using CIRCosOpLowering =
3608-
CIRUnaryFPBuiltinOpLowering<mlir::cir::CosOp, mlir::LLVM::CosOp>;
3609-
using CIRExpOpLowering =
3610-
CIRUnaryFPBuiltinOpLowering<mlir::cir::ExpOp, mlir::LLVM::ExpOp>;
3611-
using CIRExp2OpLowering =
3612-
CIRUnaryFPBuiltinOpLowering<mlir::cir::Exp2Op, mlir::LLVM::Exp2Op>;
3613-
using CIRFloorOpLowering =
3614-
CIRUnaryFPBuiltinOpLowering<mlir::cir::FloorOp, mlir::LLVM::FFloorOp>;
3615-
using CIRFabsOpLowering =
3616-
CIRUnaryFPBuiltinOpLowering<mlir::cir::FAbsOp, mlir::LLVM::FAbsOp>;
3617-
using CIRLogOpLowering =
3618-
CIRUnaryFPBuiltinOpLowering<mlir::cir::LogOp, mlir::LLVM::LogOp>;
3619-
using CIRLog10OpLowering =
3620-
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log10Op, mlir::LLVM::Log10Op>;
3621-
using CIRLog2OpLowering =
3622-
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log2Op, mlir::LLVM::Log2Op>;
3623-
using CIRNearbyintOpLowering =
3624-
CIRUnaryFPBuiltinOpLowering<mlir::cir::NearbyintOp,
3625-
mlir::LLVM::NearbyintOp>;
3626-
using CIRRintOpLowering =
3627-
CIRUnaryFPBuiltinOpLowering<mlir::cir::RintOp, mlir::LLVM::RintOp>;
3628-
using CIRRoundOpLowering =
3629-
CIRUnaryFPBuiltinOpLowering<mlir::cir::RoundOp, mlir::LLVM::RoundOp>;
3630-
using CIRSinOpLowering =
3631-
CIRUnaryFPBuiltinOpLowering<mlir::cir::SinOp, mlir::LLVM::SinOp>;
3632-
using CIRSqrtOpLowering =
3633-
CIRUnaryFPBuiltinOpLowering<mlir::cir::SqrtOp, mlir::LLVM::SqrtOp>;
3634-
using CIRTruncOpLowering =
3635-
CIRUnaryFPBuiltinOpLowering<mlir::cir::TruncOp, mlir::LLVM::FTruncOp>;
3636-
36373553
using CIRLroundOpLowering =
36383554
CIRUnaryFPBuiltinOpLowering<mlir::cir::LroundOp, mlir::LLVM::LroundOp>;
36393555
using CIRLLroundOpLowering =
@@ -3907,23 +3823,21 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
39073823
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
39083824
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
39093825
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
3910-
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
3911-
CIRVectorSplatLowering, CIRVectorTernaryLowering,
3826+
CIRVectorCmpOpLowering, CIRVectorSplatLowering, CIRVectorTernaryLowering,
39123827
CIRVectorShuffleIntsLowering, CIRVectorShuffleVecLowering,
3913-
CIRStackSaveLowering, CIRStackRestoreLowering, CIRUnreachableLowering,
3914-
CIRTrapLowering, CIRInlineAsmOpLowering, CIRSetBitfieldLowering,
3915-
CIRGetBitfieldLowering, CIRPrefetchLowering, CIRObjSizeOpLowering,
3916-
CIRIsConstantOpLowering, CIRCmpThreeWayOpLowering, CIRLroundOpLowering,
3917-
CIRLLroundOpLowering, CIRLrintOpLowering, CIRLLrintOpLowering,
3918-
CIRCeilOpLowering, CIRCosOpLowering, CIRExpOpLowering, CIRExp2OpLowering,
3919-
CIRFloorOpLowering, CIRFAbsOpLowering, CIRLogOpLowering,
3920-
CIRLog10OpLowering, CIRLog2OpLowering, CIRNearbyintOpLowering,
3921-
CIRRintOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
3922-
CIRSqrtOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering,
3828+
CIRStackSaveLowering, CIRUnreachableLowering, CIRTrapLowering,
3829+
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
3830+
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
3831+
CIRCmpThreeWayOpLowering, CIRLroundOpLowering, CIRLLroundOpLowering,
3832+
CIRLrintOpLowering, CIRLLrintOpLowering, CIRCopysignOpLowering,
39233833
CIRFModOpLowering, CIRFMaxOpLowering, CIRFMinOpLowering, CIRPowOpLowering,
39243834
CIRClearCacheOpLowering, CIRUndefOpLowering, CIREhTypeIdOpLowering,
39253835
CIRCatchParamOpLowering, CIRResumeOpLowering, CIRAllocExceptionOpLowering,
3926-
CIRThrowOpLowering>(converter, patterns.getContext());
3836+
CIRThrowOpLowering
3837+
#define GET_BUILTIN_LOWERING_LIST
3838+
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"
3839+
#undef GET_BUILTIN_LOWERING_LIST
3840+
>(converter, patterns.getContext());
39273841
}
39283842

39293843
namespace {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: cir-opt %s -cir-to-llvm -o %t.ll
2+
// RUN: FileCheck --input-file=%t.ll %s
3+
4+
module {
5+
cir.func @test(%arg0 : !cir.float) {
6+
%1 = cir.cos %arg0 : !cir.float
7+
// CHECK: llvm.intr.cos(%arg0) : (f32) -> f32
8+
9+
%2 = cir.ceil %arg0 : !cir.float
10+
// CHECK: llvm.intr.ceil(%arg0) : (f32) -> f32
11+
12+
%3 = cir.exp %arg0 : !cir.float
13+
// CHECK: llvm.intr.exp(%arg0) : (f32) -> f32
14+
15+
%4 = cir.exp2 %arg0 : !cir.float
16+
// CHECK: llvm.intr.exp2(%arg0) : (f32) -> f32
17+
18+
%5 = cir.fabs %arg0 : !cir.float
19+
// CHECK: llvm.intr.fabs(%arg0) : (f32) -> f32
20+
21+
%6 = cir.floor %arg0 : !cir.float
22+
// CHECK: llvm.intr.floor(%arg0) : (f32) -> f32
23+
24+
%7 = cir.log %arg0 : !cir.float
25+
// CHECK: llvm.intr.log(%arg0) : (f32) -> f32
26+
27+
%8 = cir.log10 %arg0 : !cir.float
28+
// CHECK: llvm.intr.log10(%arg0) : (f32) -> f32
29+
30+
%9 = cir.log2 %arg0 : !cir.float
31+
// CHECK: llvm.intr.log2(%arg0) : (f32) -> f32
32+
33+
%10 = cir.nearbyint %arg0 : !cir.float
34+
// CHECK: llvm.intr.nearbyint(%arg0) : (f32) -> f32
35+
36+
%11 = cir.rint %arg0 : !cir.float
37+
// CHECK: llvm.intr.rint(%arg0) : (f32) -> f32
38+
39+
%12 = cir.round %arg0 : !cir.float
40+
// CHECK: llvm.intr.round(%arg0) : (f32) -> f32
41+
42+
%13 = cir.sin %arg0 : !cir.float
43+
// CHECK: llvm.intr.sin(%arg0) : (f32) -> f32
44+
45+
%14 = cir.sqrt %arg0 : !cir.float
46+
// CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32
47+
48+
cir.return
49+
}
50+
}

0 commit comments

Comments
 (0)