diff --git a/lib/Dialect/XeTile/Transforms/WgToSg.cpp b/lib/Dialect/XeTile/Transforms/WgToSg.cpp index 5264a8621..77ef9f7d2 100644 --- a/lib/Dialect/XeTile/Transforms/WgToSg.cpp +++ b/lib/Dialect/XeTile/Transforms/WgToSg.cpp @@ -579,6 +579,137 @@ class WGToSGVectorTranspose }; +// This pattern transforms the convert layout op in the following manner: +// 1. Store the original vector to slm using input operand layout +// 2. Add barrier +// 3. Load the vector from slm using the result layout + +// Example: +// WG IR +// #wg_map_b = #xetile.wg_map +// #wg_map_a = #xetile.wg_map +// %vector_a = xetile.convert_layout %vector_b {wg_map_result = #wg_map_a, wg_map_source = #wg_map_b}: vector<256x256xfloat> into vector<256x256xfloat> + +// SG IR +// %slm = memref.alloc() : memref<256x256xf32, 3> +// %tile = xetile.init_tile %slm[offset_x, offset_y] : memref<256x256xf32, 3> -> xetile.tile<32x64xf32> +// xetile.store_tile %vector_b, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32> +// gpu.barrier +// %remapped_tile = xetile.init_tile %slm[offsetX, offsetY] : memref<256x256xf32, 3> -> xetile.tile<8x256xf32> +// %remapped_vector = xetile.load_tile %remapped_tile : xetile.tile<8x256xf32> -> vector<8x256xf32> +class WGToSGXeTileConvertLayout + :public XeOneToNConversion { + using XeOneToNConversion::XeOneToNConversion; + + mlir::LogicalResult + matchAndRewrite(xetile::ConvertLayoutOp op, OpAdaptor adaptor, + XeOneToNPatternRewriter &rewriter) const override { + if (op.getSource().getType().getRank() != 2) + return mlir::failure(); + + auto loc = op.getLoc(); + auto res = op.getResult(); + auto resType = mlir::dyn_cast(res.getType()); + auto elemTy = resType.getElementType(); + auto resShape = resType.getShape(); + + auto dstMapAttr = + llvm::dyn_cast_or_null(op->getAttr("wg_map_result")); + + xetile::WorkGroupMapAttr srcMapAttr; + srcMapAttr = llvm::dyn_cast_or_null(op->getAttr("wg_map_source")); + + if (!dstMapAttr) { + return mlir::failure(); + } + + if(!srcMapAttr) { + // NOTE(review): operand is dereferenced below without a null check; confirm getDefiningOp() cannot return null here. 
Get the map from operand + auto operand = op.getSource().getDefiningOp(); + srcMapAttr = llvm::dyn_cast_or_null(operand->getAttr("map")); + if (!srcMapAttr) { + return mlir::failure(); + } + } + + auto srcMapSgData = srcMapAttr.getSgData(); + auto srcSgLayout = srcMapAttr.getSgLayout(); + auto dstMapSgData = dstMapAttr.getSgData(); + auto dstSgLayout = dstMapAttr.getSgLayout(); + + auto createIndexConstant = [&](mlir::Type type, int64_t value) { + auto attr = rewriter.getIndexAttr(value); + return rewriter.create(loc, type, attr); + }; + + rewriter.setInsertionPoint(op); + // Allocate SLM + // TODO: Allocate slm as 1D array of i8, and then create the expected view on it. + auto slmTy = mlir::MemRefType::get({resShape[0], resShape[1]}, elemTy, {}, 3); + auto slm = rewriter.create(loc, slmTy); + + // Get SG id + auto sgId = rewriter.create( + loc, rewriter.getIndexType(), nullptr); + + auto indexType = rewriter.getIndexType(); + auto srcMapDimY = createIndexConstant(indexType, srcSgLayout[1]); + + // The sgID is a linear (1D) id. Convert it to 2D to get the x and y + // coordinates of sg + // row = i / cols + // col = i % cols + // x is row, y is col + // TODO: floordivs and remu are expensive. Find an alternative. 
+ auto storeSgIdX = + rewriter.create(loc, sgId, srcMapDimY); + auto storeSgIdY = + rewriter.create(loc, sgId, srcMapDimY); + + // Store to SLM using src map + auto memoryScopeAttr = mlir::IntegerAttr::get(rewriter.getIntegerType(32), 3); + auto order = mlir::DenseI32ArrayAttr::get(op.getContext(), {1, 0}); + auto attr = imex::xetile::XeTileAttr::get( + op.getContext(), nullptr /*sgMap*/, nullptr /*wgMap*/, + order /*order*/, nullptr /*innerblocks*/, memoryScopeAttr /*memoryscope*/, + nullptr /*scatterAttr*/); + xetile::TileType srcTileTy = + imex::xetile::TileType::get({srcMapSgData[0], srcMapSgData[1]}, elemTy, attr); + + auto storeOffsetX = rewriter.createOrFold( + loc, storeSgIdX, createIndexConstant(indexType, srcMapSgData[0])); + auto storeOffsetY = rewriter.createOrFold( + loc, storeSgIdY, createIndexConstant(indexType, srcMapSgData[1])); + auto storeInitTileOp = rewriter.create( + loc, srcTileTy, slm, llvm::ArrayRef({storeOffsetX, storeOffsetY})); + rewriter.create(loc, adaptor.getSource()[0], + storeInitTileOp); + + // Add barrier + rewriter.create(loc); + + // Load from SLM with result map + xetile::TileType dstTileTy = + imex::xetile::TileType::get({dstMapSgData[0], dstMapSgData[1]}, elemTy, attr); + auto newResTy = + mlir::VectorType::get({dstMapSgData[0], dstMapSgData[1]}, elemTy); + + auto dstMapDimY = createIndexConstant(indexType, dstSgLayout[1]); + auto loadSgIdX = rewriter.create(loc, sgId, dstMapDimY); + auto loadSgIdY = rewriter.create(loc, sgId, dstMapDimY); + auto loadOffsetX = rewriter.createOrFold( + loc, loadSgIdX, createIndexConstant(indexType, dstMapSgData[0])); + auto loadOffsetY = rewriter.createOrFold( + loc, loadSgIdY, createIndexConstant(indexType, dstMapSgData[1])); + auto loadInitTileOp = rewriter.create( + loc, dstTileTy, slm, llvm::ArrayRef({loadOffsetX, loadOffsetY})); + auto loadTile = rewriter.create( + loc, newResTy, loadInitTileOp, mlir::Attribute()); + + rewriter.replaceOp(op, loadTile); + return mlir::success(); + } + }; 
class WGToSGVectorBroadcast :public XeOneToNConversion { @@ -704,16 +835,14 @@ void analyzeInitTileOps(mlir::Operation *op) { } - void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter, mlir::RewritePatternSet &patterns, TileUsageAnalysis &analysis) { patterns.insert(patterns.getContext(), converter, - analysis); + WGToSGSCFYieldOpPattern, WGToSGVectorTranspose, WGToSGVectorBroadcast, + WGToSGXeTileConvertLayout>(patterns.getContext(), converter, analysis); patterns.insert, WGToSGElementWiseOpPattern, WGToSGArithConstantOpPattern>(patterns.getContext(), @@ -824,6 +953,8 @@ class XeTileWgToSgPass return false; }); + target.addIllegalOp(); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); populateXeTileWgToSgPatterns(typeConverter, patterns, analysis); diff --git a/test/Dialect/XeTile/Transforms/wg_to_sg_convert_layout.mlir b/test/Dialect/XeTile/Transforms/wg_to_sg_convert_layout.mlir new file mode 100644 index 000000000..2f984382e --- /dev/null +++ b/test/Dialect/XeTile/Transforms/wg_to_sg_convert_layout.mlir @@ -0,0 +1,36 @@ +// RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s + +gpu.module @test_convert_layout{ + //CHECK: gpu.func @test_kernel() + gpu.func @test_kernel() { + //CHECK: %[[c0:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32> + //CHECK: %[[c0_0:.*]] = arith.constant dense<0.000000e+00> : vector<8x256xf32> + //CHECK: %[[SLM:.*]] = memref.alloc() : memref<256x256xf32, 3> + //CHECK: %[[R0:.*]] = gpu.subgroup_id : index + //CHECK: %[[c4:.*]] = arith.constant 4 : index + //CHECK: %[[R1:.*]] = index.floordivs %[[R0]], %[[c4]] + //CHECK: %[[R2:.*]] = index.remu %[[R0]], %[[c4]] + //CHECK: %[[c32:.*]] = arith.constant 32 : index + //CHECK: %[[R3:.*]] = index.mul %[[R1]], %[[c32]] + //CHECK: %[[c64:.*]] = arith.constant 64 : index + //CHECK: %[[R4:.*]] = index.mul %[[R2]], %[[c64]] + //CHECK: %[[INITTILESRCMAP:.*]] = xetile.init_tile %[[SLM]][%[[R3]], %[[R4]]] : 
memref<256x256xf32, 3> -> !xetile.tile<32x64xf32, #xetile.tile_attr> + //CHECK: xetile.store_tile %[[c0]], %[[INITTILESRCMAP]] : vector<32x64xf32>, !xetile.tile<32x64xf32, #xetile.tile_attr> + //CHECK: gpu.barrier + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[R5:.*]] = index.floordivs %[[R0]], %[[c1]] + //CHECK: %[[R6:.*]] = index.remu %[[R0]], %[[c1]] + //CHECK: %[[c8:.*]] = arith.constant 8 : index + //CHECK: %[[R7:.*]] = index.mul %[[R5]], %[[c8]] + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[R8:.*]] = index.mul %[[R6]], %[[c256]] + //CHECK: %[[INITTILEDSTMAP:.*]] = xetile.init_tile %[[SLM]][%[[R7]], %[[R8]]] : memref<256x256xf32, 3> -> !xetile.tile<8x256xf32, #xetile.tile_attr> + //CHECK: %[[LOADTILE:.*]] = xetile.load_tile %[[INITTILEDSTMAP]] : !xetile.tile<8x256xf32, #xetile.tile_attr> -> vector<8x256xf32> + + %cst = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<256x256xf32> + %cst_temp = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<256x256xf32> + %convert_layout = xetile.convert_layout %cst {wg_map_result = #xetile.wg_map, wg_map_source = #xetile.wg_map} : vector<256x256xf32> + %add = arith.addf %cst_temp, %convert_layout {map = #xetile.wg_map} : vector<256x256xf32> + gpu.return + } + }