Skip to content

Commit

Permalink
Add pattern to transform xetile.convert_layout op from wg to sg (#954)
Browse files Browse the repository at this point in the history
* Add pattern to transform xetile.convert_layout op from wg to sg

* Fix pre-commit
  • Loading branch information
nbpatel authored Nov 6, 2024
1 parent c741d7d commit b2742c3
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 4 deletions.
139 changes: 135 additions & 4 deletions lib/Dialect/XeTile/Transforms/WgToSg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,137 @@ class WGToSGVectorTranspose
};


// This pattern transforms the convert layout op in the following manner:
// 1. Store the original vector to slm using input operand layout
// 2. Add barrier
// 3. Load the vector from slm using the result layout

// Example:
// WG IR
// #wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
// #wg_map_a = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>
// %vector_a = xetile.convert_layout %vector_b {wg_map_result = #wg_map_a, wg_map_source = #wg_map_b} : vector<256x256xf32>

// SG IR
// %slm = memref.alloc() : memref<256x256xf32, 3>
// %tile = xetile.init_tile %slm[offset_x, offset_y] : memref<256x256xf32, 3> -> xetile.tile<32x64xf32>
// xetile.store_tile %vector_b, %tile :vector<32x64xf32>, !xetile.tile<32x64xf32>
// gpu.barrier
// %remapped_tile = xetile.init_tile %slm[offsetX, offsetY] : memref<256x256xf32, 3> -> xetile.tile<8x256xf32>
// %remapped_vector = xetile.load_tile %remapped_tile : xetile.tile<8x256xf32> -> vector<8x256xf32>
/// Lowers a work-group level xetile.convert_layout to subgroup level by
/// round-tripping the data through a memory-space-3 (SLM) buffer:
///   1. each subgroup stores its piece using the source wg_map,
///   2. a gpu.barrier is emitted,
///   3. each subgroup loads its piece back using the result wg_map.
///
/// The pattern fails to match (without creating any IR) when:
///   - the source is not rank 2,
///   - the result type is not a vector,
///   - the "wg_map_result" attribute is missing, or
///   - no source map can be found, neither as "wg_map_source" on the op nor
///     as "map" on the op defining the source value.
class WGToSGXeTileConvertLayout
    : public XeOneToNConversion<xetile::ConvertLayoutOp> {
  using XeOneToNConversion<xetile::ConvertLayoutOp>::XeOneToNConversion;

  mlir::LogicalResult
  matchAndRewrite(xetile::ConvertLayoutOp op, OpAdaptor adaptor,
                  XeOneToNPatternRewriter &rewriter) const override {
    if (op.getSource().getType().getRank() != 2)
      return mlir::failure();

    auto loc = op.getLoc();
    auto resType = mlir::dyn_cast<mlir::VectorType>(op.getResult().getType());
    // dyn_cast yields a null type on mismatch; bail out instead of
    // dereferencing it below.
    if (!resType)
      return mlir::failure();
    auto elemTy = resType.getElementType();
    auto resShape = resType.getShape();

    auto dstMapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
        op->getAttr("wg_map_result"));
    if (!dstMapAttr)
      return mlir::failure();

    auto srcMapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
        op->getAttr("wg_map_source"));
    if (!srcMapAttr) {
      // Fall back to the map attached to the producer of the source value.
      // getDefiningOp() is null when the source is a block argument, in which
      // case there is nothing to read the map from.
      mlir::Operation *producer = op.getSource().getDefiningOp();
      if (!producer)
        return mlir::failure();
      srcMapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
          producer->getAttr("map"));
      if (!srcMapAttr)
        return mlir::failure();
    }

    auto srcMapSgData = srcMapAttr.getSgData();
    auto dstMapSgData = dstMapAttr.getSgData();

    auto createIndexConstant = [&](mlir::Type type, int64_t value) {
      auto attr = rewriter.getIndexAttr(value);
      return rewriter.create<mlir::arith::ConstantOp>(loc, type, attr);
    };

    rewriter.setInsertionPoint(op);
    // Allocate SLM
    // TODO: Allocate slm as 1D array of i8, and then create the expected view
    // on it.
    auto slmTy =
        mlir::MemRefType::get({resShape[0], resShape[1]}, elemTy, {}, 3);
    auto slm = rewriter.create<mlir::memref::AllocOp>(loc, slmTy);

    // Get SG id
    auto sgId = rewriter.create<mlir::gpu::SubgroupIdOp>(
        loc, rewriter.getIndexType(), nullptr);

    auto indexType = rewriter.getIndexType();

    // Emits the IR that converts the linear (1D) subgroup id into the 2D
    // element offsets of that subgroup's piece under `map`:
    //   row = id / cols, col = id % cols   (cols = sg_layout[1])
    //   offset = {row * sg_data[0], col * sg_data[1]}
    // TODO: FloorDivS and RemU are expensive. Find an alternative.
    auto computeSgOffsets = [&](xetile::WorkGroupMapAttr map) {
      auto sgLayout = map.getSgLayout();
      auto sgData = map.getSgData();
      auto layoutCols = createIndexConstant(indexType, sgLayout[1]);
      auto sgRow =
          rewriter.create<mlir::index::FloorDivSOp>(loc, sgId, layoutCols);
      auto sgCol = rewriter.create<mlir::index::RemUOp>(loc, sgId, layoutCols);
      mlir::Value rowOffset = rewriter.createOrFold<mlir::index::MulOp>(
          loc, sgRow, createIndexConstant(indexType, sgData[0]));
      mlir::Value colOffset = rewriter.createOrFold<mlir::index::MulOp>(
          loc, sgCol, createIndexConstant(indexType, sgData[1]));
      return llvm::SmallVector<mlir::OpFoldResult, 2>{rowOffset, colOffset};
    };

    // Tile attribute shared by the store-side and load-side tiles: order
    // [1, 0] and memory space 3, matching the buffer allocated above.
    auto memoryScopeAttr =
        mlir::IntegerAttr::get(rewriter.getIntegerType(32), 3);
    auto order = mlir::DenseI32ArrayAttr::get(op.getContext(), {1, 0});
    auto attr = imex::xetile::XeTileAttr::get(
        op.getContext(), nullptr /*sgMap*/, nullptr /*wgMap*/, order /*order*/,
        nullptr /*innerblocks*/, memoryScopeAttr /*memoryscope*/,
        nullptr /*scatterAttr*/);

    // Store to SLM using src map
    xetile::TileType srcTileTy = imex::xetile::TileType::get(
        {srcMapSgData[0], srcMapSgData[1]}, elemTy, attr);
    auto storeOffsets = computeSgOffsets(srcMapAttr);
    auto storeInitTileOp = rewriter.create<xetile::InitTileOp>(
        loc, srcTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>(storeOffsets));
    rewriter.create<xetile::StoreTileOp>(loc, adaptor.getSource()[0],
                                         storeInitTileOp);

    // Add barrier so the stores are emitted before any load from the buffer.
    rewriter.create<mlir::gpu::BarrierOp>(loc);

    // Load from SLM with result map
    xetile::TileType dstTileTy = imex::xetile::TileType::get(
        {dstMapSgData[0], dstMapSgData[1]}, elemTy, attr);
    auto newResTy =
        mlir::VectorType::get({dstMapSgData[0], dstMapSgData[1]}, elemTy);
    auto loadOffsets = computeSgOffsets(dstMapAttr);
    auto loadInitTileOp = rewriter.create<xetile::InitTileOp>(
        loc, dstTileTy, slm, llvm::ArrayRef<mlir::OpFoldResult>(loadOffsets));
    auto loadTile = rewriter.create<xetile::LoadTileOp>(
        loc, newResTy, loadInitTileOp, mlir::Attribute());

    rewriter.replaceOp(op, loadTile);
    return mlir::success();
  }
};

class WGToSGVectorBroadcast
:public XeOneToNConversion<mlir::vector::BroadcastOp> {
Expand Down Expand Up @@ -704,16 +835,14 @@ void analyzeInitTileOps(mlir::Operation *op) {
}



void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter,
mlir::RewritePatternSet &patterns,
TileUsageAnalysis &analysis) {
patterns.insert<WGToSGInitTileOpPattern, WGToSGLoadTileOpPattern,
WGToSGTileMMAOpPattern, WGToSGStoreTileOpPattern,
WGToSGSCFForOpPattern, WGToSGUpdateTileOffsetOpPattern,
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose,
WGToSGVectorBroadcast>(patterns.getContext(), converter,
analysis);
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose, WGToSGVectorBroadcast,
WGToSGXeTileConvertLayout>(patterns.getContext(), converter, analysis);
patterns.insert<WGToSGElementWiseOpPattern<mlir::math::ExpOp, 1>,
WGToSGElementWiseOpPattern<mlir::arith::AddFOp, 2>,
WGToSGArithConstantOpPattern>(patterns.getContext(),
Expand Down Expand Up @@ -824,6 +953,8 @@ class XeTileWgToSgPass
return false;
});

target.addIllegalOp<xetile::ConvertLayoutOp>();

target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

populateXeTileWgToSgPatterns(typeConverter, patterns, analysis);
Expand Down
36 changes: 36 additions & 0 deletions test/Dialect/XeTile/Transforms/wg_to_sg_convert_layout.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s

// Exercises the wg-to-sg lowering of xetile.convert_layout on a 256x256xf32
// value. The expected output stores the per-subgroup vector into a
// memory-space-3 buffer using the source map (sg_layout = [8, 4],
// sg_data = [32, 64]), emits a barrier, and reloads it using the result map
// (sg_layout = [32, 1], sg_data = [8, 256]).
gpu.module @test_convert_layout{
//CHECK: gpu.func @test_kernel()
gpu.func @test_kernel() {
//CHECK: %[[c0:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
//CHECK: %[[c0_0:.*]] = arith.constant dense<0.000000e+00> : vector<8x256xf32>
//CHECK: %[[SLM:.*]] = memref.alloc() : memref<256x256xf32, 3>
//CHECK: %[[R0:.*]] = gpu.subgroup_id : index
// Store side: the subgroup id is split by the source sg_layout column count
// (4) and scaled by the source sg_data (32, 64) to form the tile offsets.
//CHECK: %[[c4:.*]] = arith.constant 4 : index
//CHECK: %[[R1:.*]] = index.floordivs %[[R0]], %[[c4]]
//CHECK: %[[R2:.*]] = index.remu %[[R0]], %[[c4]]
//CHECK: %[[c32:.*]] = arith.constant 32 : index
//CHECK: %[[R3:.*]] = index.mul %[[R1]], %[[c32]]
//CHECK: %[[c64:.*]] = arith.constant 64 : index
//CHECK: %[[R4:.*]] = index.mul %[[R2]], %[[c64]]
//CHECK: %[[INITTILESRCMAP:.*]] = xetile.init_tile %[[SLM]][%[[R3]], %[[R4]]] : memref<256x256xf32, 3> -> !xetile.tile<32x64xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: xetile.store_tile %[[c0]], %[[INITTILESRCMAP]] : vector<32x64xf32>, !xetile.tile<32x64xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: gpu.barrier
// Load side: the same subgroup id is split by the result sg_layout column
// count (1) and scaled by the result sg_data (8, 256).
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[R5:.*]] = index.floordivs %[[R0]], %[[c1]]
//CHECK: %[[R6:.*]] = index.remu %[[R0]], %[[c1]]
//CHECK: %[[c8:.*]] = arith.constant 8 : index
//CHECK: %[[R7:.*]] = index.mul %[[R5]], %[[c8]]
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[R8:.*]] = index.mul %[[R6]], %[[c256]]
//CHECK: %[[INITTILEDSTMAP:.*]] = xetile.init_tile %[[SLM]][%[[R7]], %[[R8]]] : memref<256x256xf32, 3> -> !xetile.tile<8x256xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: %[[LOADTILE:.*]] = xetile.load_tile %[[INITTILEDSTMAP]] : !xetile.tile<8x256xf32, #xetile.tile_attr<memory_space = 3 : i32>> -> vector<8x256xf32>

// Input IR: %cst carries the source map, %cst_temp carries the result map,
// and the addf consumes the converted value under the result map.
%cst = arith.constant {map = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>} dense<0.000000e+00> : vector<256x256xf32>
%cst_temp = arith.constant {map = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>} dense<0.000000e+00> : vector<256x256xf32>
%convert_layout = xetile.convert_layout %cst {wg_map_result = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>, wg_map_source = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>} : vector<256x256xf32>
%add = arith.addf %cst_temp, %convert_layout {map = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>} : vector<256x256xf32>
gpu.return
}
}

0 comments on commit b2742c3

Please sign in to comment.