diff --git a/third_party/amd/CMakeLists.txt b/third_party/amd/CMakeLists.txt index b42c09214034..8228c3d39111 100644 --- a/third_party/amd/CMakeLists.txt +++ b/third_party/amd/CMakeLists.txt @@ -3,7 +3,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) - add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms AMDGPUToLLVM) + add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms TritonAMDGPUDialectToLLVM) endif() if(TRITON_BUILD_UT) add_subdirectory(unittest) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index b039c006d3b1..9ec33d1aca53 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -212,7 +212,6 @@ def make_llir(src, metadata, options): ## For now it is used as a controller for developers only. __HIP_FTZ = True amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ) - amd.passes.ttgpuir.add_amdgpu_to_llvm(pm) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) diff --git a/third_party/amd/include/AMDGPUToLLVM/AMDGPUToLLVMPass.h b/third_party/amd/include/AMDGPUToLLVM/AMDGPUToLLVMPass.h deleted file mode 100644 index 909d7f7a307d..000000000000 --- a/third_party/amd/include/AMDGPUToLLVM/AMDGPUToLLVMPass.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef TRITON_CONVERSION_AMDGPU_TO_LLVM_PASS_H -#define TRITON_CONVERSION_AMDGPU_TO_LLVM_PASS_H - -#include -#include -#include -#include - -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { - -class ModuleOp; -template class OperationPass; - -namespace triton { - -namespace amdgpu {} // namespace amdgpu - -std::unique_ptr> createConvertAMDGPUToLLVMPass(); - -} // namespace triton - -} // namespace mlir - -#endif diff --git a/third_party/amd/include/AMDGPUToLLVM/CMakeLists.txt b/third_party/amd/include/AMDGPUToLLVM/CMakeLists.txt deleted file mode 100644 index 39f64085ced1..000000000000 --- a/third_party/amd/include/AMDGPUToLLVM/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls --name AMDGPUToLLVM) -add_public_tablegen_target(AMDGPUConversionPassIncGen) diff --git a/third_party/amd/include/AMDGPUToLLVM/Passes.h b/third_party/amd/include/AMDGPUToLLVM/Passes.h deleted file mode 100644 index 30f8d29c3986..000000000000 --- a/third_party/amd/include/AMDGPUToLLVM/Passes.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef AMDGPU_CONVERSION_PASSES_H -#define AMDGPU_CONVERSION_PASSES_H - -#include "amd/include/AMDGPUToLLVM/AMDGPUToLLVMPass.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace triton { - -#define GEN_PASS_REGISTRATION -#include "AMD/include/AMDGPUToLLVM/Passes.h.inc" - -} // namespace triton -} // namespace mlir - -#endif diff --git a/third_party/amd/include/AMDGPUToLLVM/Passes.td b/third_party/amd/include/AMDGPUToLLVM/Passes.td deleted file mode 100644 index 5fd74b6fd417..000000000000 --- a/third_party/amd/include/AMDGPUToLLVM/Passes.td +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef AMDGPU_CONVERSION_PASSES -#define AMDGPU_CONVERSION_PASSES - -include "mlir/Pass/PassBase.td" - - -def ConvertAMDGPUToLLVM : Pass<"convert-amd-gpu-to-llvm", "mlir::ModuleOp"> { - let summary = "Convert AMDGPU to LLVM"; - let description = [{ - }]; - let constructor = "mlir::triton::createConvertAMDGPUToLLVMPass()"; - - let dependentDialects = ["mlir::LLVM::LLVMDialect"]; -} - -#endif diff --git a/third_party/amd/include/CMakeLists.txt b/third_party/amd/include/CMakeLists.txt index 4e7089f327cc..08707d601b23 100644 --- a/third_party/amd/include/CMakeLists.txt +++ b/third_party/amd/include/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(AMDGPUToLLVM) +add_subdirectory(Dialect) add_subdirectory(TritonAMDGPUToLLVM) add_subdirectory(TritonAMDGPUTransforms) -add_subdirectory(Dialect) diff --git a/third_party/amd/include/Dialect/AMDGPU/IR/CMakeLists.txt b/third_party/amd/include/Dialect/AMDGPU/IR/CMakeLists.txt deleted file mode 100644 index 819cc768b9b0..000000000000 --- a/third_party/amd/include/Dialect/AMDGPU/IR/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) - -set(LLVM_TARGET_DEFINITIONS AMDGPUOps.td) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=amdgpu) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=amdgpu) -mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) -mlir_tablegen(Ops.h.inc -gen-op-decls) -mlir_tablegen(Ops.cpp.inc -gen-op-defs) -add_mlir_doc(AMDGPUDialect AMDGPUDialect dialects/ -gen-dialect-doc) -add_mlir_doc(AMDGPUOps AMDGPUOps dialects/ -gen-op-doc) -add_public_tablegen_target(AMDGPUTableGen) - -set(LLVM_TARGET_DEFINITIONS AMDGPUAttrDefs.td) -mlir_tablegen(AMDGPUAttrDefs.h.inc -gen-attrdef-decls) -mlir_tablegen(AMDGPUAttrDefs.cpp.inc -gen-attrdef-defs) -add_public_tablegen_target(AMDGPUAttrDefsIncGen) diff --git a/third_party/amd/include/Dialect/CMakeLists.txt b/third_party/amd/include/Dialect/CMakeLists.txt index 4122b4f2009b..4f9163bdf016 100644 --- a/third_party/amd/include/Dialect/CMakeLists.txt +++ b/third_party/amd/include/Dialect/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(AMDGPU) +add_subdirectory(TritonAMDGPU) diff --git a/third_party/amd/include/Dialect/AMDGPU/CMakeLists.txt b/third_party/amd/include/Dialect/TritonAMDGPU/CMakeLists.txt similarity index 100% rename from third_party/amd/include/Dialect/AMDGPU/CMakeLists.txt rename to third_party/amd/include/Dialect/TritonAMDGPU/CMakeLists.txt diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt new file mode 100644 index 000000000000..25a57075be01 --- /dev/null +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt @@ -0,0 +1,16 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonAMDGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=amdgpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=amdgpu) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +add_mlir_doc(TritonAMDGPUDialect TritonAMDGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonAMDGPUOps TritonAMDGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonAMDGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonAMDGPUAttrDefs.td) +mlir_tablegen(TritonAMDGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonAMDGPUAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(TritonAMDGPUAttrDefsIncGen) diff --git a/third_party/amd/include/Dialect/AMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h similarity index 85% rename from third_party/amd/include/Dialect/AMDGPU/IR/Dialect.h rename to third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h index d89431737200..f7f824e30b23 100644 --- a/third_party/amd/include/Dialect/AMDGPU/IR/Dialect.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -24,17 +24,20 @@ #ifndef TRITON_DIALECT_AMDGPU_IR_DIALECT_H_ #define TRITON_DIALECT_AMDGPU_IR_DIALECT_H_ -#include "amd/include/Dialect/AMDGPU/IR/Dialect.h.inc" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +// clang-format off +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" +// clang-format on #define GET_ATTRDEF_CLASSES -#include "amd/include/Dialect/AMDGPU/IR/AMDGPUAttrDefs.h.inc" +#include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.h.inc" #define GET_OP_CLASSES -#include "amd/include/Dialect/AMDGPU/IR/Ops.h.inc" +#include "amd/include/Dialect/TritonAMDGPU/IR/Ops.h.inc" namespace mlir { namespace triton { diff --git a/third_party/amd/include/Dialect/AMDGPU/IR/AMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td similarity index 85% rename from third_party/amd/include/Dialect/AMDGPU/IR/AMDGPUAttrDefs.td rename to third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index ce0610fba008..31a43acd2f89 100644 --- a/third_party/amd/include/Dialect/AMDGPU/IR/AMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -21,15 +21,15 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -#ifndef AMDGPU_ATTRDEFS -#define AMDGPU_ATTRDEFS +#ifndef TRITON_AMDGPU_ATTRDEFS +#define TRITON_AMDGPU_ATTRDEFS include "mlir/IR/AttrTypeBase.td" -include "AMDGPUDialect.td" +include "TritonAMDGPUDialect.td" -class AMDGPU_Attr traits = [], +class TritonAMDGPU_Attr traits = [], string baseCppClass = "::mlir::Attribute"> - : AttrDef { + : AttrDef { } #endif diff --git a/third_party/amd/include/Dialect/AMDGPU/IR/AMDGPUDialect.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td similarity index 88% rename from third_party/amd/include/Dialect/AMDGPU/IR/AMDGPUDialect.td rename to third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td index 9f099cd9692e..d5956cf7a33c 100644 --- a/third_party/amd/include/Dialect/AMDGPU/IR/AMDGPUDialect.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td @@ -21,17 +21,17 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -#ifndef AMDGPU_DIALECT -#define AMDGPU_DIALECT +#ifndef TRITON_AMDGPU_DIALECT +#define TRITON_AMDGPU_DIALECT include "mlir/IR/OpBase.td" -def AMDGPU_Dialect : Dialect { +def TritonAMDGPU_Dialect : Dialect { let name = "amdgpu"; let cppNamespace = "::mlir::triton::amdgpu"; let description = [{ - AMDGPU Dialect. + TritonAMDGPU Dialect hosts AMD specific ops at TritonGPU abstraction level. }]; let dependentDialects = []; diff --git a/third_party/amd/include/Dialect/AMDGPU/IR/AMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td similarity index 91% rename from third_party/amd/include/Dialect/AMDGPU/IR/AMDGPUOps.td rename to third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index f93b3c967d6d..8960c0514359 100644 --- a/third_party/amd/include/Dialect/AMDGPU/IR/AMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -22,14 +22,14 @@ */ -#ifndef AMDGPU_OPS -#define AMDGPU_OPS +#ifndef TRITON_AMDGPU_OPS +#define TRITON_AMDGPU_OPS include "mlir/IR/OpBase.td" include "mlir/IR/EnumAttr.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "AMDGPUDialect.td" -include "AMDGPUAttrDefs.td" +include "TritonAMDGPUDialect.td" +include "TritonAMDGPUAttrDefs.td" #endif diff --git a/third_party/amd/lib/AMDGPUToLLVM/AMDGPUToLLVMPass.cpp b/third_party/amd/lib/AMDGPUToLLVM/AMDGPUToLLVMPass.cpp deleted file mode 100644 index 82631f238898..000000000000 --- a/third_party/amd/lib/AMDGPUToLLVM/AMDGPUToLLVMPass.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include "AMDGPUToLLVM/AMDGPUToLLVMPass.h" -#include "Dialect/AMDGPU/IR/Dialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" -// clang-format off -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "amd/lib/TritonAMDGPUToLLVM/Utility.h" -// clang-format on - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::gpu; -namespace tta = mlir::triton::amdgpu; - -#define GEN_PASS_CLASSES -#include "AMDGPUToLLVM/Passes.h.inc" - -namespace mlir::triton::AMD { - -class AMDDialectLLVMConversionTarget : public ConversionTarget { -public: - explicit AMDDialectLLVMConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addIllegalDialect(); - addIllegalDialect(); - addLegalOp(); - } -}; - -void populateAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - ModuleAxisInfoAnalysis &axisInfoAnalysis, - PatternBenefit benefit) { - // TODO: Add some actual patterns to lower -} -} // namespace mlir::triton::AMD - -class ConvertAMDGPUToLLVM - : public ConvertAMDGPUToLLVMBase { - -public: - explicit ConvertAMDGPUToLLVM() {} - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - RewritePatternSet patterns(context); - mlir::LowerToLLVMOptions option(context); - - TritonGPUToLLVMTypeConverter typeConverter(context, option); - ModuleAxisInfoAnalysis axisInfoAnalysis(mod); - - constexpr int benefit = 1; - AMD::populateAMDGPUToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis, - benefit); - - AMD::AMDDialectLLVMConversionTarget convTarget(*context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -namespace mlir { -namespace triton { -std::unique_ptr> createConvertAMDGPUToLLVMPass() { - return std::make_unique<::ConvertAMDGPUToLLVM>(); -} -} // namespace triton -} // namespace mlir diff --git a/third_party/amd/lib/AMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/AMDGPUToLLVM/CMakeLists.txt deleted file mode 100644 index 0be169d1e0bc..000000000000 --- a/third_party/amd/lib/AMDGPUToLLVM/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -add_triton_library(AMDGPUToLLVM - AMDGPUToLLVMPass.cpp - - DEPENDS - AMDGPUConversionPassIncGen - AMDGPUIR -) diff --git a/third_party/amd/lib/CMakeLists.txt b/third_party/amd/lib/CMakeLists.txt index 269d13b1fd8b..15c000ab8886 100644 --- a/third_party/amd/lib/CMakeLists.txt +++ b/third_party/amd/lib/CMakeLists.txt @@ -1,4 +1,4 @@ -add_subdirectory(AMDGPUToLLVM) add_subdirectory(Dialect) add_subdirectory(TritonAMDGPUToLLVM) +add_subdirectory(TritonAMDGPUDialectToLLVM) add_subdirectory(TritonAMDGPUTransforms) diff --git a/third_party/amd/lib/Dialect/CMakeLists.txt b/third_party/amd/lib/Dialect/CMakeLists.txt index 4122b4f2009b..4f9163bdf016 100644 --- a/third_party/amd/lib/Dialect/CMakeLists.txt +++ b/third_party/amd/lib/Dialect/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(AMDGPU) +add_subdirectory(TritonAMDGPU) diff --git a/third_party/amd/lib/Dialect/AMDGPU/CMakeLists.txt b/third_party/amd/lib/Dialect/TritonAMDGPU/CMakeLists.txt similarity index 100% rename from third_party/amd/lib/Dialect/AMDGPU/CMakeLists.txt rename to third_party/amd/lib/Dialect/TritonAMDGPU/CMakeLists.txt diff --git a/third_party/amd/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/CMakeLists.txt similarity index 51% rename from third_party/amd/lib/Dialect/AMDGPU/IR/CMakeLists.txt rename to third_party/amd/lib/Dialect/TritonAMDGPU/IR/CMakeLists.txt index 7afadd36319a..f550b6e20c9f 100644 --- a/third_party/amd/lib/Dialect/AMDGPU/IR/CMakeLists.txt +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/CMakeLists.txt @@ -1,9 +1,9 @@ -add_triton_library(AMDGPUIR +add_triton_library(TritonAMDGPUIR Dialect.cpp DEPENDS - AMDGPUTableGen - AMDGPUAttrDefsIncGen + TritonAMDGPUTableGen + TritonAMDGPUAttrDefsIncGen LINK_LIBS PUBLIC MLIRLLVMDialect diff --git a/third_party/amd/lib/Dialect/AMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp similarity index 82% rename from third_party/amd/lib/Dialect/AMDGPU/IR/Dialect.cpp rename to third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index a52d1974f32d..5631d56b24b6 100644 --- a/third_party/amd/lib/Dialect/AMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -25,24 +25,24 @@ #include "mlir/IR/OpImplementation.h" // clang-format off -#include "Dialect/AMDGPU/IR/Dialect.h" -#include "Dialect/AMDGPU/IR/Dialect.cpp.inc" +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "Dialect/TritonAMDGPU/IR/Dialect.cpp.inc" // clang-format on using namespace mlir; using namespace mlir::triton::amdgpu; -void mlir::triton::amdgpu::AMDGPUDialect::initialize() { +void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST -#include "Dialect/AMDGPU/IR/AMDGPUAttrDefs.cpp.inc" +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc" >(); addOperations< #define GET_OP_LIST -#include "Dialect/AMDGPU/IR/Ops.cpp.inc" +#include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" >(); } #define GET_OP_CLASSES -#include "Dialect/AMDGPU/IR/Ops.cpp.inc" +#include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..e6da8f28777e --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt @@ -0,0 +1,6 @@ +add_triton_library(TritonAMDGPUDialectToLLVM + TritonAMDGPUToLLVMPatterns.cpp + + DEPENDS + TritonAMDGPUIR +) diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp new file mode 100644 index 000000000000..5d172fea9cfa --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -0,0 +1,9 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +namespace mlir::triton::AMD { +void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { + // TODO: Insert TrtionAMDGPU dialect patterns. +} +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index 67e5369b8650..764f31a610e1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -30,6 +30,9 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); +void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 8649911a7c2d..08631e211ec2 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -200,6 +200,10 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); AMD::populateSPMDOpToLLVMPattern(typeConverter, patterns, AMDBenefit); + + mlir::triton::AMD::populateTritonAMDGPUToLLVMPatterns(typeConverter, + patterns, AMDBenefit); + // TODO(thomas): this should probably be done in a separate step to not // interfere with our own lowering of arith ops. Add arith/math's patterns // to help convert scalar expression to LLVM. diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 25599d688eda..da5718ac6efe 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -1,5 +1,4 @@ -#include "AMDGPUToLLVM/AMDGPUToLLVMPass.h" -#include "Dialect/AMDGPU/IR/Dialect.h" +#include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "TritonAMDGPUToLLVM/Passes.h" #include "TritonAMDGPUToLLVM/TargetUtils.h" #include "TritonAMDGPUTransforms/Passes.h" @@ -66,8 +65,6 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUStreamPipelinePass); ADD_PASS_WRAPPER_1("add_stream_pipelinev2", mlir::createTritonAMDGPUStreamPipelineV2Pass, int); - ADD_PASS_WRAPPER_0("add_amdgpu_to_llvm", - mlir::triton::createConvertAMDGPUToLLVMPass); } void addControlConstant(llvm::Module *module, const char *name, @@ -100,7 +97,7 @@ void init_triton_amd(py::module &&m) { m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); // registry.insert(); mlir::registerROCDLDialectTranslation(registry); context.appendDialectRegistry(registry);