Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][Lowering] Add the concept of simple lowering and use it for unary fp2fp operations #806

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 66 additions & 18 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,48 @@ include "mlir/IR/CommonAttrConstraints.td"
// CIR Ops
//===----------------------------------------------------------------------===//

// LLVMLoweringInfo is used by cir-tablegen to generate LLVM lowering logic
// automatically for CIR operations. The `llvmOp` field gives the name of the
// LLVM IR dialect operation that the CIR operation will be lowered to. The
// input arguments of the CIR operation will be passed in the same order to the
// lowered LLVM IR operation.
//
// Example:
//
// For the following CIR operation definition:
//
// def FooOp : CIR_Op<"foo"> {
// // ...
// let arguments = (ins CIR_AnyType:$arg1, CIR_AnyType:$arg2);
// let llvmOp = "BarOp";
// }
//
// cir-tablegen will generate LLVM lowering code for the FooOp similar to the
// following:
//
// class CIRFooOpLowering
// : public mlir::OpConversionPattern<mlir::cir::FooOp> {
// public:
// using OpConversionPattern<mlir::cir::FooOp>::OpConversionPattern;
//
// mlir::LogicalResult matchAndRewrite(
// mlir::cir::FooOp op,
// OpAdaptor adaptor,
// mlir::ConversionPatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<mlir::LLVM::BarOp>(
// op, adaptor.getOperands()[0], adaptor.getOperands()[1]);
// return mlir::success();
// }
// }
//
// If you want fully customized LLVM IR lowering logic, simply exclude the
// `llvmOp` field from your CIR operation definition.
class LLVMLoweringInfo {
string llvmOp = "";
}

class CIR_Op<string mnemonic, list<Trait> traits = []> :
Op<CIR_Dialect, mnemonic, traits>;
Op<CIR_Dialect, mnemonic, traits>, LLVMLoweringInfo;

//===----------------------------------------------------------------------===//
// CIR Op Traits
Expand Down Expand Up @@ -2708,6 +2748,8 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,
}];

let hasVerifier = 0;

let llvmOp = "InsertElementOp";
}

//===----------------------------------------------------------------------===//
Expand All @@ -2732,6 +2774,8 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
}];

let hasVerifier = 0;

let llvmOp = "ExtractElementOp";
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3762,30 +3806,32 @@ def LLroundOp : UnaryFPToIntBuiltinOp<"llround">;
def LrintOp : UnaryFPToIntBuiltinOp<"lrint">;
def LLrintOp : UnaryFPToIntBuiltinOp<"llrint">;

class UnaryFPToFPBuiltinOp<string mnemonic>
class UnaryFPToFPBuiltinOp<string mnemonic, string llvmOpName>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let arguments = (ins CIR_AnyFloat:$src);
let results = (outs CIR_AnyFloat:$result);
let summary = "libc builtin equivalent ignoring "
"floating point exceptions and errno";
let assemblyFormat = "$src `:` type($src) attr-dict";
}

def CeilOp : UnaryFPToFPBuiltinOp<"ceil">;
def CosOp : UnaryFPToFPBuiltinOp<"cos">;
def ExpOp : UnaryFPToFPBuiltinOp<"exp">;
def Exp2Op : UnaryFPToFPBuiltinOp<"exp2">;
def FloorOp : UnaryFPToFPBuiltinOp<"floor">;
def FAbsOp : UnaryFPToFPBuiltinOp<"fabs">;
def LogOp : UnaryFPToFPBuiltinOp<"log">;
def Log10Op : UnaryFPToFPBuiltinOp<"log10">;
def Log2Op : UnaryFPToFPBuiltinOp<"log2">;
def NearbyintOp : UnaryFPToFPBuiltinOp<"nearbyint">;
def RintOp : UnaryFPToFPBuiltinOp<"rint">;
def RoundOp : UnaryFPToFPBuiltinOp<"round">;
def SinOp : UnaryFPToFPBuiltinOp<"sin">;
def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt">;
def TruncOp : UnaryFPToFPBuiltinOp<"trunc">;
let llvmOp = llvmOpName;
}

def CeilOp : UnaryFPToFPBuiltinOp<"ceil", "FCeilOp">;
def CosOp : UnaryFPToFPBuiltinOp<"cos", "CosOp">;
def ExpOp : UnaryFPToFPBuiltinOp<"exp", "ExpOp">;
def Exp2Op : UnaryFPToFPBuiltinOp<"exp2", "Exp2Op">;
def FloorOp : UnaryFPToFPBuiltinOp<"floor", "FFloorOp">;
def FAbsOp : UnaryFPToFPBuiltinOp<"fabs", "FAbsOp">;
def LogOp : UnaryFPToFPBuiltinOp<"log", "LogOp">;
def Log10Op : UnaryFPToFPBuiltinOp<"log10", "Log10Op">;
def Log2Op : UnaryFPToFPBuiltinOp<"log2", "Log2Op">;
def NearbyintOp : UnaryFPToFPBuiltinOp<"nearbyint", "NearbyintOp">;
def RintOp : UnaryFPToFPBuiltinOp<"rint", "RintOp">;
def RoundOp : UnaryFPToFPBuiltinOp<"round", "RoundOp">;
def SinOp : UnaryFPToFPBuiltinOp<"sin", "SinOp">;
def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt", "SqrtOp">;
def TruncOp : UnaryFPToFPBuiltinOp<"trunc", "FTruncOp">;

class BinaryFPToFPBuiltinOp<string mnemonic>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
Expand Down Expand Up @@ -3987,6 +4033,8 @@ def StackRestoreOp : CIR_Op<"stack_restore"> {

let arguments = (ins CIR_PointerType:$ptr);
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr))";

let llvmOp = "StackRestoreOp";
}

def AsmATT : I32EnumAttrCase<"x86_att", 0>;
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ mlir_tablegen(CIROpsStructs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(CIROpsAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRCIREnumsGen)

clang_tablegen(CIRBuiltinsLowering.inc -gen-cir-builtins-lowering
SOURCE CIROps.td
TARGET CIRBuiltinsLowering)
112 changes: 13 additions & 99 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1330,34 +1330,6 @@ class CIRVectorCreateLowering
}
};

class CIRVectorInsertLowering
: public mlir::OpConversionPattern<mlir::cir::VecInsertOp> {
public:
using OpConversionPattern<mlir::cir::VecInsertOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VecInsertOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>(
op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex());
return mlir::success();
}
};

class CIRVectorExtractLowering
: public mlir::OpConversionPattern<mlir::cir::VecExtractOp> {
public:
using OpConversionPattern<mlir::cir::VecExtractOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VecExtractOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
op, adaptor.getVec(), adaptor.getIndex());
return mlir::success();
}
};

class CIRVectorCmpOpLowering
: public mlir::OpConversionPattern<mlir::cir::VecCmpOp> {
public:
Expand Down Expand Up @@ -3154,19 +3126,6 @@ class CIRPtrDiffOpLowering
}
};

class CIRFAbsOpLowering : public mlir::OpConversionPattern<mlir::cir::FAbsOp> {
public:
using OpConversionPattern<mlir::cir::FAbsOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::FAbsOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::LLVM::FAbsOp>(
op, adaptor.getOperands().front());
return mlir::success();
}
};

class CIRExpectOpLowering
: public mlir::OpConversionPattern<mlir::cir::ExpectOp> {
public:
Expand Down Expand Up @@ -3246,19 +3205,8 @@ class CIRStackSaveLowering
}
};

class CIRStackRestoreLowering
: public mlir::OpConversionPattern<mlir::cir::StackRestoreOp> {
public:
using OpConversionPattern<mlir::cir::StackRestoreOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::StackRestoreOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::LLVM::StackRestoreOp>(op,
adaptor.getPtr());
return mlir::success();
}
};
#define GET_BUILTIN_LOWERING_CLASSES
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"

class CIRUnreachableLowering
: public mlir::OpConversionPattern<mlir::cir::UnreachableOp> {
Expand Down Expand Up @@ -3601,38 +3549,6 @@ class CIRUnaryFPBuiltinOpLowering : public mlir::OpConversionPattern<CIROp> {
}
};

using CIRCeilOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::CeilOp, mlir::LLVM::FCeilOp>;
using CIRCosOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::CosOp, mlir::LLVM::CosOp>;
using CIRExpOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::ExpOp, mlir::LLVM::ExpOp>;
using CIRExp2OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Exp2Op, mlir::LLVM::Exp2Op>;
using CIRFloorOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::FloorOp, mlir::LLVM::FFloorOp>;
using CIRFabsOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::FAbsOp, mlir::LLVM::FAbsOp>;
using CIRLogOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::LogOp, mlir::LLVM::LogOp>;
using CIRLog10OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log10Op, mlir::LLVM::Log10Op>;
using CIRLog2OpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::Log2Op, mlir::LLVM::Log2Op>;
using CIRNearbyintOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::NearbyintOp,
mlir::LLVM::NearbyintOp>;
using CIRRintOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::RintOp, mlir::LLVM::RintOp>;
using CIRRoundOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::RoundOp, mlir::LLVM::RoundOp>;
using CIRSinOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::SinOp, mlir::LLVM::SinOp>;
using CIRSqrtOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::SqrtOp, mlir::LLVM::SqrtOp>;
using CIRTruncOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::TruncOp, mlir::LLVM::FTruncOp>;

using CIRLroundOpLowering =
CIRUnaryFPBuiltinOpLowering<mlir::cir::LroundOp, mlir::LLVM::LroundOp>;
using CIRLLroundOpLowering =
Expand Down Expand Up @@ -3906,23 +3822,21 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
CIRVectorSplatLowering, CIRVectorTernaryLowering,
CIRVectorCmpOpLowering, CIRVectorSplatLowering, CIRVectorTernaryLowering,
CIRVectorShuffleIntsLowering, CIRVectorShuffleVecLowering,
CIRStackSaveLowering, CIRStackRestoreLowering, CIRUnreachableLowering,
CIRTrapLowering, CIRInlineAsmOpLowering, CIRSetBitfieldLowering,
CIRGetBitfieldLowering, CIRPrefetchLowering, CIRObjSizeOpLowering,
CIRIsConstantOpLowering, CIRCmpThreeWayOpLowering, CIRLroundOpLowering,
CIRLLroundOpLowering, CIRLrintOpLowering, CIRLLrintOpLowering,
CIRCeilOpLowering, CIRCosOpLowering, CIRExpOpLowering, CIRExp2OpLowering,
CIRFloorOpLowering, CIRFAbsOpLowering, CIRLogOpLowering,
CIRLog10OpLowering, CIRLog2OpLowering, CIRNearbyintOpLowering,
CIRRintOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
CIRSqrtOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering,
CIRStackSaveLowering, CIRUnreachableLowering, CIRTrapLowering,
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering,
CIRCmpThreeWayOpLowering, CIRLroundOpLowering, CIRLLroundOpLowering,
CIRLrintOpLowering, CIRLLrintOpLowering, CIRCopysignOpLowering,
CIRFModOpLowering, CIRFMaxOpLowering, CIRFMinOpLowering, CIRPowOpLowering,
CIRClearCacheOpLowering, CIRUndefOpLowering, CIREhTypeIdOpLowering,
CIRCatchParamOpLowering, CIRResumeOpLowering, CIRAllocExceptionOpLowering,
CIRThrowOpLowering>(converter, patterns.getContext());
CIRThrowOpLowering
#define GET_BUILTIN_LOWERING_LIST
#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc"
#undef GET_BUILTIN_LOWERING_LIST
>(converter, patterns.getContext());
}

namespace {
Expand Down
50 changes: 50 additions & 0 deletions clang/test/CIR/Lowering/builtin-floating-point.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: cir-opt %s -cir-to-llvm -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s

module {
cir.func @test(%arg0 : !cir.float) {
%1 = cir.cos %arg0 : !cir.float
// CHECK: llvm.intr.cos(%arg0) : (f32) -> f32

%2 = cir.ceil %arg0 : !cir.float
// CHECK: llvm.intr.ceil(%arg0) : (f32) -> f32

%3 = cir.exp %arg0 : !cir.float
// CHECK: llvm.intr.exp(%arg0) : (f32) -> f32

%4 = cir.exp2 %arg0 : !cir.float
// CHECK: llvm.intr.exp2(%arg0) : (f32) -> f32

%5 = cir.fabs %arg0 : !cir.float
// CHECK: llvm.intr.fabs(%arg0) : (f32) -> f32

%6 = cir.floor %arg0 : !cir.float
// CHECK: llvm.intr.floor(%arg0) : (f32) -> f32

%7 = cir.log %arg0 : !cir.float
// CHECK: llvm.intr.log(%arg0) : (f32) -> f32

%8 = cir.log10 %arg0 : !cir.float
// CHECK: llvm.intr.log10(%arg0) : (f32) -> f32

%9 = cir.log2 %arg0 : !cir.float
// CHECK: llvm.intr.log2(%arg0) : (f32) -> f32

%10 = cir.nearbyint %arg0 : !cir.float
// CHECK: llvm.intr.nearbyint(%arg0) : (f32) -> f32

%11 = cir.rint %arg0 : !cir.float
// CHECK: llvm.intr.rint(%arg0) : (f32) -> f32

%12 = cir.round %arg0 : !cir.float
// CHECK: llvm.intr.round(%arg0) : (f32) -> f32

%13 = cir.sin %arg0 : !cir.float
// CHECK: llvm.intr.sin(%arg0) : (f32) -> f32

%14 = cir.sqrt %arg0 : !cir.float
// CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32

cir.return
}
}
Loading