Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pattern to transform xetile.convert_layout op from wg to sg #954

Merged
merged 2 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 135 additions & 4 deletions lib/Dialect/XeTile/Transforms/WgToSg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,137 @@ class WGToSGVectorTranspose
};


// This pattern transforms the convert layout op in the following manner:
// 1. Store the original vector to slm using input operand layout
// 2. Add barrier
// 3. Load the vector from slm using the result layout

// Example:
// WG IR
// #wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
// #wg_map_a = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>
// %vector_a = xetile.convert_layout %vector_b {wg_map_result = #wg_map_a, wg_map_source = #wg_map_b} : vector<256x256xf32>

// SG IR
// %slm = memref.alloc() : memref<256x256xf32, 3>
// %tile = xetile.init_tile %slm[offset_x, offset_y] : memref<256x256xf32, 3> -> xetile.tile<32x64xf32>
// xetile.store_tile %vector_b, %tile :vector<32x64xf32>, !xetile.tile<32x64xf32>
// gpu.barrier
// %remapped_tile = xetile.init_tile %slm[offsetX, offsetY] : memref<256x256xf32, 3> -> xetile.tile<8x256xf32>
// %remapped_vector = xetile.load_tile %remapped_tile : xetile.tile<8x256xf32> -> vector<8x256xf32>
/// Rewrites a workgroup-level xetile.convert_layout into subgroup-level IR
/// that round-trips the data through a memref allocated in memory space 3
/// (SLM): store with the source map, barrier, load with the result map.
class WGToSGXeTileConvertLayout
    : public XeOneToNConversion<xetile::ConvertLayoutOp> {
  using XeOneToNConversion<xetile::ConvertLayoutOp>::XeOneToNConversion;

  /// Matches 2D convert_layout ops that carry a "wg_map_result" attribute and
  /// whose source map is available either as "wg_map_source" on the op or as
  /// a "map" attribute on the op defining the source value.
  ///
  /// Returns failure (leaving the IR untouched) when the op is not 2D, the
  /// result is not a vector, or either workgroup map cannot be found.
  mlir::LogicalResult
  matchAndRewrite(xetile::ConvertLayoutOp op, OpAdaptor adaptor,
                  XeOneToNPatternRewriter &rewriter) const override {
    if (op.getSource().getType().getRank() != 2)
      return mlir::failure();

    auto loc = op.getLoc();
    auto res = op.getResult();
    // dyn_cast yields a null type on mismatch; bail out instead of
    // dereferencing it. The rank guard protects the resShape[0..1] accesses.
    auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
    if (!resType || resType.getRank() != 2)
      return mlir::failure();
    auto elemTy = resType.getElementType();
    auto resShape = resType.getShape();

    auto dstMapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
        op->getAttr("wg_map_result"));
    if (!dstMapAttr)
      return mlir::failure();

    auto srcMapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
        op->getAttr("wg_map_source"));
    if (!srcMapAttr) {
      // Fall back to the map attached to the op producing the source value.
      // A block argument has no defining op, so there is nothing to query.
      mlir::Operation *producer = op.getSource().getDefiningOp();
      if (!producer)
        return mlir::failure();
      srcMapAttr = llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(
          producer->getAttr("map"));
      if (!srcMapAttr)
        return mlir::failure();
    }

    auto srcMapSgData = srcMapAttr.getSgData();
    auto srcSgLayout = srcMapAttr.getSgLayout();
    auto dstMapSgData = dstMapAttr.getSgData();
    auto dstSgLayout = dstMapAttr.getSgLayout();

    auto createIndexConstant = [&](mlir::Type type, int64_t value) {
      auto attr = rewriter.getIndexAttr(value);
      return rewriter.create<mlir::arith::ConstantOp>(loc, type, attr);
    };

    rewriter.setInsertionPoint(op);
    // Allocate SLM sized like the full workgroup-level result.
    // TODO: Allocate slm as 1D array of i8, and then create the expected view
    // on it.
    auto slmTy =
        mlir::MemRefType::get({resShape[0], resShape[1]}, elemTy, {}, 3);
    auto slm = rewriter.create<mlir::memref::AllocOp>(loc, slmTy);

    // Get SG id
    auto sgId = rewriter.create<mlir::gpu::SubgroupIdOp>(
        loc, rewriter.getIndexType(), nullptr);

    auto indexType = rewriter.getIndexType();
    auto srcMapDimY = createIndexConstant(indexType, srcSgLayout[1]);

    // The sgId is a linear (1D) id. Convert it to 2D to get the x and y
    // coordinates of the subgroup:
    //   row = i / cols
    //   col = i % cols
    // x is row, y is col.
    // TODO: FloorDivS and RemU are expensive; find an alternative.
    auto storeSgIdX =
        rewriter.create<mlir::index::FloorDivSOp>(loc, sgId, srcMapDimY);
    auto storeSgIdY =
        rewriter.create<mlir::index::RemUOp>(loc, sgId, srcMapDimY);

    // Store to SLM using the source map. The same tile attribute (order
    // [1, 0], memory space 3) is reused for the load-side tile below.
    auto memoryScopeAttr =
        mlir::IntegerAttr::get(rewriter.getIntegerType(32), 3);
    auto order = mlir::DenseI32ArrayAttr::get(op.getContext(), {1, 0});
    auto attr = imex::xetile::XeTileAttr::get(
        op.getContext(), nullptr /*sgMap*/, nullptr /*wgMap*/,
        order /*order*/, nullptr /*innerblocks*/,
        memoryScopeAttr /*memoryscope*/, nullptr /*scatterAttr*/);
    xetile::TileType srcTileTy = imex::xetile::TileType::get(
        {srcMapSgData[0], srcMapSgData[1]}, elemTy, attr);

    // Each subgroup writes its sg_data-sized piece at (row, col) * sg_data.
    auto storeOffsetX = rewriter.createOrFold<mlir::index::MulOp>(
        loc, storeSgIdX, createIndexConstant(indexType, srcMapSgData[0]));
    auto storeOffsetY = rewriter.createOrFold<mlir::index::MulOp>(
        loc, storeSgIdY, createIndexConstant(indexType, srcMapSgData[1]));
    auto storeInitTileOp = rewriter.create<xetile::InitTileOp>(
        loc, srcTileTy, slm,
        llvm::ArrayRef<mlir::OpFoldResult>({storeOffsetX, storeOffsetY}));
    rewriter.create<xetile::StoreTileOp>(loc, adaptor.getSource()[0],
                                         storeInitTileOp);

    // Add barrier between the store and the load, which use different
    // subgroup-to-data mappings.
    rewriter.create<mlir::gpu::BarrierOp>(loc);

    // Load from SLM with the result map.
    xetile::TileType dstTileTy = imex::xetile::TileType::get(
        {dstMapSgData[0], dstMapSgData[1]}, elemTy, attr);
    auto newResTy =
        mlir::VectorType::get({dstMapSgData[0], dstMapSgData[1]}, elemTy);

    auto dstMapDimY = createIndexConstant(indexType, dstSgLayout[1]);
    auto loadSgIdX =
        rewriter.create<mlir::index::FloorDivSOp>(loc, sgId, dstMapDimY);
    auto loadSgIdY =
        rewriter.create<mlir::index::RemUOp>(loc, sgId, dstMapDimY);
    auto loadOffsetX = rewriter.createOrFold<mlir::index::MulOp>(
        loc, loadSgIdX, createIndexConstant(indexType, dstMapSgData[0]));
    auto loadOffsetY = rewriter.createOrFold<mlir::index::MulOp>(
        loc, loadSgIdY, createIndexConstant(indexType, dstMapSgData[1]));
    auto loadInitTileOp = rewriter.create<xetile::InitTileOp>(
        loc, dstTileTy, slm,
        llvm::ArrayRef<mlir::OpFoldResult>({loadOffsetX, loadOffsetY}));
    auto loadTile = rewriter.create<xetile::LoadTileOp>(
        loc, newResTy, loadInitTileOp, mlir::Attribute());

    rewriter.replaceOp(op, loadTile);
    return mlir::success();
  }
};

class WGToSGVectorBroadcast
:public XeOneToNConversion<mlir::vector::BroadcastOp> {
Expand Down Expand Up @@ -704,16 +835,14 @@ void analyzeInitTileOps(mlir::Operation *op) {
}



void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter,
mlir::RewritePatternSet &patterns,
TileUsageAnalysis &analysis) {
patterns.insert<WGToSGInitTileOpPattern, WGToSGLoadTileOpPattern,
WGToSGTileMMAOpPattern, WGToSGStoreTileOpPattern,
WGToSGSCFForOpPattern, WGToSGUpdateTileOffsetOpPattern,
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose,
WGToSGVectorBroadcast>(patterns.getContext(), converter,
analysis);
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose, WGToSGVectorBroadcast,
WGToSGXeTileConvertLayout>(patterns.getContext(), converter, analysis);
patterns.insert<WGToSGElementWiseOpPattern<mlir::math::ExpOp, 1>,
WGToSGElementWiseOpPattern<mlir::arith::AddFOp, 2>,
WGToSGArithConstantOpPattern>(patterns.getContext(),
Expand Down Expand Up @@ -824,6 +953,8 @@ class XeTileWgToSgPass
return false;
});

target.addIllegalOp<xetile::ConvertLayoutOp>();

target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

populateXeTileWgToSgPatterns(typeConverter, patterns, analysis);
Expand Down
36 changes: 36 additions & 0 deletions test/Dialect/XeTile/Transforms/wg_to_sg_convert_layout.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s

// Exercises the wg-to-sg lowering of xetile.convert_layout: the expected
// output stores the source piece into a memref in memory space 3, issues a
// gpu.barrier, and then loads the piece described by the result map.
gpu.module @test_convert_layout{
//CHECK: gpu.func @test_kernel()
gpu.func @test_kernel() {
//CHECK: %[[c0:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
//CHECK: %[[c0_0:.*]] = arith.constant dense<0.000000e+00> : vector<8x256xf32>
//CHECK: %[[SLM:.*]] = memref.alloc() : memref<256x256xf32, 3>
//CHECK: %[[R0:.*]] = gpu.subgroup_id : index
//CHECK: %[[c4:.*]] = arith.constant 4 : index
//CHECK: %[[R1:.*]] = index.floordivs %[[R0]], %[[c4]]
//CHECK: %[[R2:.*]] = index.remu %[[R0]], %[[c4]]
//CHECK: %[[c32:.*]] = arith.constant 32 : index
//CHECK: %[[R3:.*]] = index.mul %[[R1]], %[[c32]]
//CHECK: %[[c64:.*]] = arith.constant 64 : index
//CHECK: %[[R4:.*]] = index.mul %[[R2]], %[[c64]]
//CHECK: %[[INITTILESRCMAP:.*]] = xetile.init_tile %[[SLM]][%[[R3]], %[[R4]]] : memref<256x256xf32, 3> -> !xetile.tile<32x64xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: xetile.store_tile %[[c0]], %[[INITTILESRCMAP]] : vector<32x64xf32>, !xetile.tile<32x64xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: gpu.barrier
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[R5:.*]] = index.floordivs %[[R0]], %[[c1]]
//CHECK: %[[R6:.*]] = index.remu %[[R0]], %[[c1]]
//CHECK: %[[c8:.*]] = arith.constant 8 : index
//CHECK: %[[R7:.*]] = index.mul %[[R5]], %[[c8]]
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[R8:.*]] = index.mul %[[R6]], %[[c256]]
//CHECK: %[[INITTILEDSTMAP:.*]] = xetile.init_tile %[[SLM]][%[[R7]], %[[R8]]] : memref<256x256xf32, 3> -> !xetile.tile<8x256xf32, #xetile.tile_attr<memory_space = 3 : i32>>
//CHECK: %[[LOADTILE:.*]] = xetile.load_tile %[[INITTILEDSTMAP]] : !xetile.tile<8x256xf32, #xetile.tile_attr<memory_space = 3 : i32>> -> vector<8x256xf32>

// Source operand, mapped with sg_layout = [8, 4] and sg_data = [32, 64].
%cst = arith.constant {map = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>} dense<0.000000e+00> : vector<256x256xf32>
// Second addf operand, already mapped with sg_layout = [32, 1] and sg_data = [8, 256].
%cst_temp = arith.constant {map = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>} dense<0.000000e+00> : vector<256x256xf32>
// Op under test: converts %cst from the [8, 4] x [32, 64] map to the [32, 1] x [8, 256] map.
%convert_layout = xetile.convert_layout %cst {wg_map_result = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>, wg_map_source = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>} : vector<256x256xf32>
// Consumer of the converted value, using the result map.
%add = arith.addf %cst_temp, %convert_layout {map = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>} : vector<256x256xf32>
gpu.return
}
}
Loading