Skip to content

Commit

Permalink
Add pattern for arith.extf and arith.truncf ops
Browse files Browse the repository at this point in the history
  • Loading branch information
nbpatel committed Nov 14, 2024
1 parent ab065d2 commit 03f08f2
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
68 changes: 62 additions & 6 deletions lib/Dialect/XeTile/Transforms/WgToSg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,63 @@ class WGToSGArithConstantOpPattern
}
};

// TODO: Templatize this pattern for similar elementwise ops
class WGToSGArithExtFOpPattern
: public XeOneToNConversion<mlir::arith::ExtFOp> {
using XeOneToNConversion<mlir::arith::ExtFOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::arith::ExtFOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {

auto res = op.getResult();
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());

auto mapAttr =
llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(op->getAttr("map"));
if (!mapAttr) {
return mlir::failure();
}

auto sgData = mapAttr.getSgData();

auto newTy =
mlir::VectorType::get({sgData[0], sgData[1]}, resType.getElementType());

auto newOp = rewriter.create<mlir::arith::ExtFOp>(op.getLoc(), newTy, adaptor.getOperands()[0]);
rewriter.replaceOp(op, newOp);
return mlir::success();
}
};

class WGToSGArithTruncFOpPattern
: public XeOneToNConversion<mlir::arith::TruncFOp> {
using XeOneToNConversion<mlir::arith::TruncFOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(mlir::arith::TruncFOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {

auto res = op.getResult();
auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());

auto mapAttr =
llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(op->getAttr("map"));
if (!mapAttr) {
return mlir::failure();
}

auto sgData = mapAttr.getSgData();

auto newTy =
mlir::VectorType::get({sgData[0], sgData[1]}, resType.getElementType());

auto newOp = rewriter.create<mlir::arith::TruncFOp>(op.getLoc(), newTy, adaptor.getOperands()[0]);
rewriter.replaceOp(op, newOp);
return mlir::success();
}
};

class WGToSGVectorTranspose
:public XeOneToNConversion<mlir::vector::TransposeOp> {
using XeOneToNConversion<mlir::vector::TransposeOp>::XeOneToNConversion;
Expand Down Expand Up @@ -578,7 +635,6 @@ class WGToSGVectorTranspose
}
};


// This pattern transforms the convert layout op in the following manner:
// 1. Store the original vector to slm using input operand layout
// 2. Add barrier
Expand Down Expand Up @@ -853,10 +909,9 @@ void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter,
patterns.insert<WGToSGInitTileOpPattern, WGToSGLoadTileOpPattern,
WGToSGTileMMAOpPattern, WGToSGStoreTileOpPattern,
WGToSGSCFForOpPattern, WGToSGUpdateTileOffsetOpPattern,
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose,
WGToSGVectorBroadcast, WGToSGPrefetchOpPattern,
WGToSGXeTileConvertLayout>(patterns.getContext(),
converter, analysis);
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose, WGToSGVectorBroadcast,
WGToSGXeTileConvertLayout, WGToSGPrefetchOpPattern, WGToSGArithExtFOpPattern,
WGToSGArithTruncFOpPattern>(patterns.getContext(), converter, analysis);
patterns.insert<WGToSGElementWiseOpPattern<mlir::math::ExpOp, 1>,
WGToSGElementWiseOpPattern<mlir::arith::AddFOp, 2>,
WGToSGArithConstantOpPattern>(patterns.getContext(),
Expand Down Expand Up @@ -955,7 +1010,8 @@ class XeTileWgToSgPass
});

target.addDynamicallyLegalOp<mlir::arith::ConstantOp, mlir::arith::AddFOp,
mlir::math::ExpOp, mlir::vector::TransposeOp,
mlir::math::ExpOp, mlir::arith::ExtFOp,
mlir::arith::TruncFOp, mlir::vector::TransposeOp,
mlir::vector::BroadcastOp>(
[&](mlir::Operation *op) -> bool {
auto mapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
Expand Down
14 changes: 14 additions & 0 deletions test/Dialect/XeTile/Transforms/unit_tests.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: imex-opt --split-input-file --xetile-wg-to-sg %s -verify-diagnostics | FileCheck %s

gpu.module @test_arith_extf {
gpu.func @test_kernel(%arg0: memref<128x32xf16>) {
%c0 = arith.constant 0 : index
%tile = xetile.init_tile %arg0[%c0, %c0] : memref<128x32xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%load_tile = xetile.load_tile %tile : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<128x32xf16>
//CHECK: arith.extf {{%.*}} : vector<32x32xf16> to vector<32x32xf32>
//CHECK: arith.truncf {{%.*}} : vector<32x32xf32> to vector<32x32xf16>
%extf = arith.extf %load_tile {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<128x32xf16> to vector<128x32xf32>
%trucf = arith.truncf %extf {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<128x32xf32> to vector<128x32xf16>
gpu.return
}
}

0 comments on commit 03f08f2

Please sign in to comment.