From 03f08f2af21ff25ac49c185983370c19f64bb2cf Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 14 Nov 2024 16:45:40 +0000 Subject: [PATCH] Add pattern for arith.extf and arith.truncf ops --- lib/Dialect/XeTile/Transforms/WgToSg.cpp | 68 +++++++++++++++++-- .../Dialect/XeTile/Transforms/unit_tests.mlir | 14 ++++ 2 files changed, 76 insertions(+), 6 deletions(-) create mode 100644 test/Dialect/XeTile/Transforms/unit_tests.mlir diff --git a/lib/Dialect/XeTile/Transforms/WgToSg.cpp b/lib/Dialect/XeTile/Transforms/WgToSg.cpp index e64c13071..7fed049e5 100644 --- a/lib/Dialect/XeTile/Transforms/WgToSg.cpp +++ b/lib/Dialect/XeTile/Transforms/WgToSg.cpp @@ -538,6 +538,63 @@ class WGToSGArithConstantOpPattern } }; +// TODO: Templatize this pattern for similar elementwise ops +class WGToSGArithExtFOpPattern + : public XeOneToNConversion { + using XeOneToNConversion::XeOneToNConversion; + + mlir::LogicalResult + matchAndRewrite(mlir::arith::ExtFOp op, OpAdaptor adaptor, + XeOneToNPatternRewriter &rewriter) const override { + + auto res = op.getResult(); + auto resType = mlir::dyn_cast(res.getType()); + + auto mapAttr = + llvm::dyn_cast_or_null(op->getAttr("map")); + if (!mapAttr) { + return mlir::failure(); + } + + auto sgData = mapAttr.getSgData(); + + auto newTy = + mlir::VectorType::get({sgData[0], sgData[1]}, resType.getElementType()); + + auto newOp = rewriter.create(op.getLoc(), newTy, adaptor.getOperands()[0]); + rewriter.replaceOp(op, newOp); + return mlir::success(); + } +}; + +class WGToSGArithTruncFOpPattern + : public XeOneToNConversion { + using XeOneToNConversion::XeOneToNConversion; + + mlir::LogicalResult + matchAndRewrite(mlir::arith::TruncFOp op, OpAdaptor adaptor, + XeOneToNPatternRewriter &rewriter) const override { + + auto res = op.getResult(); + auto resType = mlir::dyn_cast(res.getType()); + + auto mapAttr = + llvm::dyn_cast_or_null(op->getAttr("map")); + if (!mapAttr) { + return mlir::failure(); + } + + auto sgData = mapAttr.getSgData(); + + auto newTy = + mlir::VectorType::get({sgData[0], sgData[1]}, resType.getElementType()); + + auto newOp = rewriter.create(op.getLoc(), newTy, adaptor.getOperands()[0]); + rewriter.replaceOp(op, newOp); + return mlir::success(); + } +}; + class WGToSGVectorTranspose :public XeOneToNConversion { using XeOneToNConversion::XeOneToNConversion; @@ -578,7 +635,6 @@ class WGToSGVectorTranspose } }; - // This pattern transforms the convert layout op in the following manner: // 1. Store the original vector to slm using input operand layout // 2. Add barrier @@ -853,10 +909,9 @@ void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter, patterns.insert(patterns.getContext(), - converter, analysis); + WGToSGSCFYieldOpPattern, WGToSGVectorTranspose, WGToSGVectorBroadcast, + WGToSGXeTileConvertLayout, WGToSGPrefetchOpPattern, WGToSGArithExtFOpPattern, + WGToSGArithTruncFOpPattern>(patterns.getContext(), converter, analysis); patterns.insert, WGToSGElementWiseOpPattern, WGToSGArithConstantOpPattern>(patterns.getContext(), @@ -955,7 +1010,8 @@ class XeTileWgToSgPass }); target.addDynamicallyLegalOp( [&](mlir::Operation *op) -> bool { auto mapAttr = llvm::dyn_cast_or_null( diff --git a/test/Dialect/XeTile/Transforms/unit_tests.mlir b/test/Dialect/XeTile/Transforms/unit_tests.mlir new file mode 100644 index 000000000..ebb120480 --- /dev/null +++ b/test/Dialect/XeTile/Transforms/unit_tests.mlir @@ -0,0 +1,14 @@ +// RUN: imex-opt --split-input-file --xetile-wg-to-sg %s -verify-diagnostics | FileCheck %s + +gpu.module @test_arith_extf { + gpu.func @test_kernel(%arg0: memref<128x32xf16>) { + %c0 = arith.constant 0 : index + %tile = xetile.init_tile %arg0[%c0, %c0] : memref<128x32xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %load_tile = xetile.load_tile %tile : !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<128x32xf16> + //CHECK: arith.extf {{%.*}} : vector<32x32xf16> to vector<32x32xf32> + //CHECK: arith.truncf {{%.*}} : vector<32x32xf32> to vector<32x32xf16> + %extf = arith.extf %load_tile {map = #xetile.wg_map} : vector<128x32xf16> to vector<128x32xf32> + %trucf = arith.truncf %extf {map = #xetile.wg_map} : vector<128x32xf32> to vector<128x32xf16> + gpu.return + } +}