diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index c508ec6412b6..0ea2edb58407 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -38,8 +38,48 @@ include "mlir/IR/CommonAttrConstraints.td" // CIR Ops //===----------------------------------------------------------------------===// +// LLVMLoweringInfo is used by cir-tablegen to generate LLVM lowering logic +// automatically for CIR operations. The `llvmOp` field gives the name of the +// LLVM IR dialect operation that the CIR operation will be lowered to. The +// input arguments of the CIR operation will be passed in the same order to the +// lowered LLVM IR operation. +// +// Example: +// +// For the following CIR operation definition: +// +// def FooOp : CIR_Op<"foo"> { +// // ... +// let arguments = (ins CIR_AnyType:$arg1, CIR_AnyType:$arg2); +// let llvmOp = "BarOp"; +// } +// +// cir-tablegen will generate LLVM lowering code for the FooOp similar to the +// following: +// +// class CIRFooOpLowering +// : public mlir::OpConversionPattern { +// public: +// using OpConversionPattern::OpConversionPattern; +// +// mlir::LogicalResult matchAndRewrite( +// mlir::cir::FooOp op, +// OpAdaptor adaptor, +// mlir::ConversionPatternRewriter &rewriter) const override { +// rewriter.replaceOpWithNewOp( +// op, adaptor.getOperands()[0], adaptor.getOperands()[1]); +// return mlir::success(); +// } +// } +// +// If you want fully customized LLVM IR lowering logic, simply exclude the +// `llvmOp` field from your CIR operation definition. +class LLVMLoweringInfo { + string llvmOp = ""; +} + class CIR_Op traits = []> : - Op; + Op, LLVMLoweringInfo; //===----------------------------------------------------------------------===// // CIR Op Traits @@ -2708,6 +2748,8 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure, }]; let hasVerifier = 0; + + let llvmOp = "InsertElementOp"; } //===----------------------------------------------------------------------===// @@ -2732,6 +2774,8 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure, }]; let hasVerifier = 0; + + let llvmOp = "ExtractElementOp"; } //===----------------------------------------------------------------------===// @@ -3762,30 +3806,32 @@ def LLroundOp : UnaryFPToIntBuiltinOp<"llround">; def LrintOp : UnaryFPToIntBuiltinOp<"lrint">; def LLrintOp : UnaryFPToIntBuiltinOp<"llrint">; -class UnaryFPToFPBuiltinOp +class UnaryFPToFPBuiltinOp : CIR_Op { let arguments = (ins CIR_AnyFloat:$src); let results = (outs CIR_AnyFloat:$result); let summary = "libc builtin equivalent ignoring " "floating point exceptions and errno"; let assemblyFormat = "$src `:` type($src) attr-dict"; -} -def CeilOp : UnaryFPToFPBuiltinOp<"ceil">; -def CosOp : UnaryFPToFPBuiltinOp<"cos">; -def ExpOp : UnaryFPToFPBuiltinOp<"exp">; -def Exp2Op : UnaryFPToFPBuiltinOp<"exp2">; -def FloorOp : UnaryFPToFPBuiltinOp<"floor">; -def FAbsOp : UnaryFPToFPBuiltinOp<"fabs">; -def LogOp : UnaryFPToFPBuiltinOp<"log">; -def Log10Op : UnaryFPToFPBuiltinOp<"log10">; -def Log2Op : UnaryFPToFPBuiltinOp<"log2">; -def NearbyintOp : UnaryFPToFPBuiltinOp<"nearbyint">; -def RintOp : UnaryFPToFPBuiltinOp<"rint">; -def RoundOp : UnaryFPToFPBuiltinOp<"round">; -def SinOp : UnaryFPToFPBuiltinOp<"sin">; -def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt">; -def TruncOp : UnaryFPToFPBuiltinOp<"trunc">; + let llvmOp = llvmOpName; +} + +def CeilOp : UnaryFPToFPBuiltinOp<"ceil", "FCeilOp">; +def CosOp : UnaryFPToFPBuiltinOp<"cos", "CosOp">; +def ExpOp : UnaryFPToFPBuiltinOp<"exp", "ExpOp">; +def Exp2Op : UnaryFPToFPBuiltinOp<"exp2", "Exp2Op">; +def FloorOp : UnaryFPToFPBuiltinOp<"floor", "FFloorOp">; +def FAbsOp : UnaryFPToFPBuiltinOp<"fabs", "FAbsOp">; +def LogOp : UnaryFPToFPBuiltinOp<"log", "LogOp">; +def Log10Op : UnaryFPToFPBuiltinOp<"log10", "Log10Op">; +def Log2Op : UnaryFPToFPBuiltinOp<"log2", "Log2Op">; +def NearbyintOp : UnaryFPToFPBuiltinOp<"nearbyint", "NearbyintOp">; +def RintOp : UnaryFPToFPBuiltinOp<"rint", "RintOp">; +def RoundOp : UnaryFPToFPBuiltinOp<"round", "RoundOp">; +def SinOp : UnaryFPToFPBuiltinOp<"sin", "SinOp">; +def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt", "SqrtOp">; +def TruncOp : UnaryFPToFPBuiltinOp<"trunc", "FTruncOp">; class BinaryFPToFPBuiltinOp : CIR_Op { @@ -3987,6 +4033,8 @@ def StackRestoreOp : CIR_Op<"stack_restore"> { let arguments = (ins CIR_PointerType:$ptr); let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr))"; + + let llvmOp = "StackRestoreOp"; } def AsmATT : I32EnumAttrCase<"x86_att", 0>; diff --git a/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt b/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt index c502525d30e8..3d43b06c6217 100644 --- a/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt +++ b/clang/include/clang/CIR/Dialect/IR/CMakeLists.txt @@ -27,3 +27,7 @@ mlir_tablegen(CIROpsStructs.cpp.inc -gen-attrdef-defs) mlir_tablegen(CIROpsAttributes.h.inc -gen-attrdef-decls) mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(MLIRCIREnumsGen) + +clang_tablegen(CIRBuiltinsLowering.inc -gen-cir-builtins-lowering + SOURCE CIROps.td + TARGET CIRBuiltinsLowering) diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index e250d8a35132..2283f2cc22a2 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1330,34 +1330,6 @@ class CIRVectorCreateLowering } }; -class CIRVectorInsertLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(mlir::cir::VecInsertOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex()); - return mlir::success(); - } -}; - -class CIRVectorExtractLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(mlir::cir::VecExtractOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, adaptor.getVec(), adaptor.getIndex()); - return mlir::success(); - } -}; - class CIRVectorCmpOpLowering : public mlir::OpConversionPattern { public: @@ -3156,19 +3128,6 @@ class CIRPtrDiffOpLowering } }; -class CIRFAbsOpLowering : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(mlir::cir::FAbsOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, adaptor.getOperands().front()); - return mlir::success(); - } -}; - class CIRExpectOpLowering : public mlir::OpConversionPattern { public: @@ -3248,19 +3207,8 @@ class CIRStackSaveLowering } }; -class CIRStackRestoreLowering - : public mlir::OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(mlir::cir::StackRestoreOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, - adaptor.getPtr()); - return mlir::success(); - } -}; +#define GET_BUILTIN_LOWERING_CLASSES +#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc" class CIRUnreachableLowering : public mlir::OpConversionPattern { @@ -3603,38 +3551,6 @@ class CIRUnaryFPBuiltinOpLowering : public mlir::OpConversionPattern { } }; -using CIRCeilOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRCosOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRExpOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRExp2OpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRFloorOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRFabsOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRLogOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRLog10OpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRLog2OpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRNearbyintOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRRintOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRRoundOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRSinOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRSqrtOpLowering = - CIRUnaryFPBuiltinOpLowering; -using CIRTruncOpLowering = - CIRUnaryFPBuiltinOpLowering; - using CIRLroundOpLowering = CIRUnaryFPBuiltinOpLowering; using CIRLLroundOpLowering = @@ -3909,23 +3825,21 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns, CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering, CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering, CIRVTableAddrPointOpLowering, CIRVectorCreateLowering, - CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering, - CIRVectorSplatLowering, CIRVectorTernaryLowering, + CIRVectorCmpOpLowering, CIRVectorSplatLowering, CIRVectorTernaryLowering, CIRVectorShuffleIntsLowering, CIRVectorShuffleVecLowering, - CIRStackSaveLowering, CIRStackRestoreLowering, CIRUnreachableLowering, - CIRTrapLowering, CIRInlineAsmOpLowering, CIRSetBitfieldLowering, - CIRGetBitfieldLowering, CIRPrefetchLowering, CIRObjSizeOpLowering, - CIRIsConstantOpLowering, CIRCmpThreeWayOpLowering, CIRLroundOpLowering, - CIRLLroundOpLowering, CIRLrintOpLowering, CIRLLrintOpLowering, - CIRCeilOpLowering, CIRCosOpLowering, CIRExpOpLowering, CIRExp2OpLowering, - CIRFloorOpLowering, CIRFAbsOpLowering, CIRLogOpLowering, - CIRLog10OpLowering, CIRLog2OpLowering, CIRNearbyintOpLowering, - CIRRintOpLowering, CIRRoundOpLowering, CIRSinOpLowering, - CIRSqrtOpLowering, CIRTruncOpLowering, CIRCopysignOpLowering, + CIRStackSaveLowering, CIRUnreachableLowering, CIRTrapLowering, + CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering, + CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering, + CIRCmpThreeWayOpLowering, CIRLroundOpLowering, CIRLLroundOpLowering, + CIRLrintOpLowering, CIRLLrintOpLowering, CIRCopysignOpLowering, CIRFModOpLowering, CIRFMaxOpLowering, CIRFMinOpLowering, CIRPowOpLowering, CIRClearCacheOpLowering, CIRUndefOpLowering, CIREhTypeIdOpLowering, CIRCatchParamOpLowering, CIRResumeOpLowering, CIRAllocExceptionOpLowering, - CIRThrowOpLowering>(converter, patterns.getContext()); + CIRThrowOpLowering +#define GET_BUILTIN_LOWERING_LIST +#include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc" +#undef GET_BUILTIN_LOWERING_LIST + >(converter, patterns.getContext()); } namespace { diff --git a/clang/test/CIR/Lowering/builtin-floating-point.cir b/clang/test/CIR/Lowering/builtin-floating-point.cir new file mode 100644 index 000000000000..82b733233da3 --- /dev/null +++ b/clang/test/CIR/Lowering/builtin-floating-point.cir @@ -0,0 +1,50 @@ +// RUN: cir-opt %s -cir-to-llvm -o %t.ll +// RUN: FileCheck --input-file=%t.ll %s + +module { + cir.func @test(%arg0 : !cir.float) { + %1 = cir.cos %arg0 : !cir.float + // CHECK: llvm.intr.cos(%arg0) : (f32) -> f32 + + %2 = cir.ceil %arg0 : !cir.float + // CHECK: llvm.intr.ceil(%arg0) : (f32) -> f32 + + %3 = cir.exp %arg0 : !cir.float + // CHECK: llvm.intr.exp(%arg0) : (f32) -> f32 + + %4 = cir.exp2 %arg0 : !cir.float + // CHECK: llvm.intr.exp2(%arg0) : (f32) -> f32 + + %5 = cir.fabs %arg0 : !cir.float + // CHECK: llvm.intr.fabs(%arg0) : (f32) -> f32 + + %6 = cir.floor %arg0 : !cir.float + // CHECK: llvm.intr.floor(%arg0) : (f32) -> f32 + + %7 = cir.log %arg0 : !cir.float + // CHECK: llvm.intr.log(%arg0) : (f32) -> f32 + + %8 = cir.log10 %arg0 : !cir.float + // CHECK: llvm.intr.log10(%arg0) : (f32) -> f32 + + %9 = cir.log2 %arg0 : !cir.float + // CHECK: llvm.intr.log2(%arg0) : (f32) -> f32 + + %10 = cir.nearbyint %arg0 : !cir.float + // CHECK: llvm.intr.nearbyint(%arg0) : (f32) -> f32 + + %11 = cir.rint %arg0 : !cir.float + // CHECK: llvm.intr.rint(%arg0) : (f32) -> f32 + + %12 = cir.round %arg0 : !cir.float + // CHECK: llvm.intr.round(%arg0) : (f32) -> f32 + + %13 = cir.sin %arg0 : !cir.float + // CHECK: llvm.intr.sin(%arg0) : (f32) -> f32 + + %14 = cir.sqrt %arg0 : !cir.float + // CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32 + + cir.return + } +} diff --git a/clang/utils/TableGen/CIRLoweringEmitter.cpp b/clang/utils/TableGen/CIRLoweringEmitter.cpp new file mode 100644 index 000000000000..29daa63be86b --- /dev/null +++ b/clang/utils/TableGen/CIRLoweringEmitter.cpp @@ -0,0 +1,64 @@ +//===- CIRBuiltinsEmitter.cpp - Generate lowering of builtins --=-*- C++ -*--=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TableGenBackends.h" +#include "llvm/TableGen/TableGenBackend.h" + +using namespace llvm; + +namespace { +std::string ClassDefinitions; +std::string ClassList; + +void GenerateLowering(raw_ostream &OS, const Record *Operation) { + using namespace std::string_literals; + std::string Name = Operation->getName().str(); + std::string LLVMOp = Operation->getValueAsString("llvmOp").str(); + ClassDefinitions += + "class CIR" + Name + + "Lowering : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::cir::)C++" + + Name + + R"C++( op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op"; + + auto ArgCount = Operation->getValueAsDag("arguments")->getNumArgs(); + for (size_t i = 0; i != ArgCount; ++i) + ClassDefinitions += ", adaptor.getOperands()[" + std::to_string(i) + ']'; + + ClassDefinitions += R"C++(); + return mlir::success(); + } +}; +)C++"; + + ClassList += ", CIR" + Name + "Lowering\n"; +} +} // namespace + +void clang::EmitCIRBuiltinsLowering(const RecordKeeper &Records, + raw_ostream &OS) { + emitSourceFileHeader("Lowering of ClangIR builtins to LLVM IR builtins", OS); + for (const auto *Builtin : + Records.getAllDerivedDefinitions("LLVMLoweringInfo")) { + if (!Builtin->getValueAsString("llvmOp").empty()) + GenerateLowering(OS, Builtin); + } + + OS << "#ifdef GET_BUILTIN_LOWERING_CLASSES\n" + << ClassDefinitions << "\n#undef GET_BUILTIN_LOWERING_CLASSES\n#endif\n"; + OS << "#ifdef GET_BUILTIN_LOWERING_LIST\n" + << ClassList << "\n#undef GET_BUILTIN_LOWERING_LIST\n#endif\n"; +} diff --git a/clang/utils/TableGen/CMakeLists.txt b/clang/utils/TableGen/CMakeLists.txt index 5b072a1ac196..df5d8c03f5a5 100644 --- a/clang/utils/TableGen/CMakeLists.txt +++ b/clang/utils/TableGen/CMakeLists.txt @@ -4,6 +4,7 @@ add_tablegen(clang-tblgen CLANG DESTINATION "${CLANG_TOOLS_INSTALL_DIR}" EXPORT Clang ASTTableGen.cpp + CIRLoweringEmitter.cpp ClangASTNodesEmitter.cpp ClangASTPropertiesEmitter.cpp ClangAttrEmitter.cpp diff --git a/clang/utils/TableGen/TableGen.cpp b/clang/utils/TableGen/TableGen.cpp index 39c178bc4f9b..60bc1bbbb1a6 100644 --- a/clang/utils/TableGen/TableGen.cpp +++ b/clang/utils/TableGen/TableGen.cpp @@ -26,6 +26,7 @@ using namespace clang; enum ActionType { PrintRecords, DumpJSON, + GenCIRBuiltinsLowering, GenClangAttrClasses, GenClangAttrParserStringSwitches, GenClangAttrSubjectMatchRulesParserStringSwitches, @@ -120,6 +121,9 @@ cl::opt Action( "Print all records to stdout (default)"), clEnumValN(DumpJSON, "dump-json", "Dump all records as machine-readable JSON"), + clEnumValN(GenCIRBuiltinsLowering, "gen-cir-builtins-lowering", + "Generate lowering of ClangIR builtins to equivalent LLVM " + "IR builtins"), clEnumValN(GenClangAttrClasses, "gen-clang-attr-classes", "Generate clang attribute clases"), clEnumValN(GenClangAttrParserStringSwitches, @@ -325,6 +329,9 @@ bool ClangTableGenMain(raw_ostream &OS, const RecordKeeper &Records) { case DumpJSON: EmitJSON(Records, OS); break; + case GenCIRBuiltinsLowering: + EmitCIRBuiltinsLowering(Records, OS); + break; case GenClangAttrClasses: EmitClangAttrClass(Records, OS); break; diff --git a/clang/utils/TableGen/TableGenBackends.h b/clang/utils/TableGen/TableGenBackends.h index f7527ac535a8..972357bac4d8 100644 --- a/clang/utils/TableGen/TableGenBackends.h +++ b/clang/utils/TableGen/TableGenBackends.h @@ -24,6 +24,8 @@ class RecordKeeper; namespace clang { +void EmitCIRBuiltinsLowering(const llvm::RecordKeeper &RK, + llvm::raw_ostream &OS); void EmitClangDeclContext(const llvm::RecordKeeper &RK, llvm::raw_ostream &OS); /** @param PriorizeIfSubclassOf These classes should be prioritized in the output.