From 6f250a4e2cb0a111560e2d5a6bc662362c6b996b Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 24 Dec 2024 18:01:53 +0000 Subject: [PATCH] Add support for reduction in wg to sg transformation --- lib/Dialect/XeTile/Transforms/WgToSg.cpp | 238 +++++++++++++++--- .../XeTile/Transforms/WgToSg/btranspose.mlir | 2 +- .../Transforms/WgToSg/convert_layout.mlir | 12 +- .../XeTile/Transforms/WgToSg/gemm_1k.mlir | 2 +- .../XeTile/Transforms/WgToSg/gemm_4k.mlir | 2 +- .../XeTile/Transforms/WgToSg/reduction.mlir | 89 +++++++ .../XeTile/Transforms/WgToSg/round_robin.mlir | 2 +- .../XeTile/Transforms/WgToSg/unit_tests.mlir | 27 +- .../Dialect/XeTile/convert_layout.mlir | 83 ++++++ .../Dialect/XeTile/wg_reduction.mlir | 94 +++++++ .../Dialect/XeTile/xetile-wg-to-func-vc.pp | 6 +- 11 files changed, 515 insertions(+), 42 deletions(-) create mode 100644 test/Dialect/XeTile/Transforms/WgToSg/reduction.mlir create mode 100644 test/Integration/Dialect/XeTile/convert_layout.mlir create mode 100644 test/Integration/Dialect/XeTile/wg_reduction.mlir diff --git a/lib/Dialect/XeTile/Transforms/WgToSg.cpp b/lib/Dialect/XeTile/Transforms/WgToSg.cpp index ce225a53a..a86228b28 100644 --- a/lib/Dialect/XeTile/Transforms/WgToSg.cpp +++ b/lib/Dialect/XeTile/Transforms/WgToSg.cpp @@ -126,7 +126,7 @@ class WGToSGInitTileOpPattern : public XeOneToNConversion { // row = i / cols // col = i % cols auto sgIdY = - rewriter.create(loc, sgID, sgLayoutDimYConst); + rewriter.create(loc, sgID, sgLayoutDimYConst); auto sgIdX = rewriter.create(loc, sgID, sgLayoutDimYConst); @@ -496,7 +496,7 @@ class WGToSGArithConstantOpPattern auto valueType = mlir::dyn_cast(value.getType()); auto wgTileShape = valueType.getShape(); - if (!value || value.getType().getRank() != 2) + if (!value) return mlir::failure(); auto mapAttr = @@ -507,8 +507,21 @@ class WGToSGArithConstantOpPattern auto sgData = mapAttr.getSgData(); auto sgLayout = mapAttr.getSgLayout(); + mlir::SmallVector outputShape; + // If WG tile rank is 1, set the output shape as the + // non-unit dim of sgData + if(wgTileShape.size() == 1) { + if(sgData[0] == 1) + outputShape.push_back(sgData[1]); + else + outputShape.push_back(sgData[0]); + } else { + outputShape.push_back(sgData[0]); + outputShape.push_back(sgData[1]); + } + auto newTy = - mlir::VectorType::get({sgData[0], sgData[1]}, value.getElementType()); + mlir::VectorType::get(outputShape, value.getElementType()); llvm::SmallVector elems( value.value_begin(), @@ -522,12 +535,20 @@ class WGToSGArithConstantOpPattern auto attr = mlir::DenseElementsAttr::get(newTy, newValues); size_t numOps; - if (sgLayout[0] * sgData[0] == wgTileShape[0] && - sgLayout[1] * sgData[1] == wgTileShape[1]) - numOps = 1; // 1:1 mapping - else - numOps = (wgTileShape[0] / (sgLayout[0] * sgData[0])) + - (wgTileShape[1] / (sgLayout[1] * sgData[1])); + // If WG tile is 1D vector just support 1:1 mapping. 
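+    // Worked example (shapes taken from the reduction tests below): a
+    // wg-level 1D constant vector<128xf32> with sgLayout = [1, 32] and
+    // sgData = [1, 4] satisfies sgLayout[1] * sgData[1] == 128, so each
+    // subgroup materializes one vector<4xf32> constant (numOps = 1).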
+ // TODO: Support round robin for 1D + if(wgTileShape.size() == 1) { + if (sgLayout[0] * sgData[0] == wgTileShape[0] || + sgLayout[1] * sgData[1] == wgTileShape[0]) + numOps = 1; + else + return mlir::failure(); + } else if(sgLayout[0] * sgData[0] == wgTileShape[0] && + sgLayout[1] * sgData[1] == wgTileShape[1]) { + numOps = 1; + } else + numOps = (wgTileShape[0] / (sgLayout[0] * sgData[0])) + + (wgTileShape[1] / (sgLayout[1] * sgData[1])); llvm::SmallVector<::mlir::Value> newOps; llvm::SmallVector newResultTypes; @@ -706,9 +727,16 @@ class WGToSGXeTileConvertLayout rewriter.setInsertionPoint(op); // Allocate SLM - // TODO: Allocate slm as 1D array of i8, and then create the expected view on it. - auto slmTy = mlir::MemRefType::get({resShape[0], resShape[1]}, elemTy, {}, 3); + auto bitWidth = elemTy.getIntOrFloatBitWidth(); + auto flattenFactor = bitWidth / 8; + auto slmShape = resShape[0] * resShape[1] * flattenFactor; + auto slmTy = mlir::MemRefType::get(slmShape, rewriter.getI8Type(), {}, 3); auto slm = rewriter.create(loc, slmTy); + ValueRange sizes; + auto zero = rewriter.create(op.getLoc(), 0); + auto viewTy = mlir::MemRefType::get({resShape[0], resShape[1]}, elemTy, {}, 3); + auto viewOp = rewriter.create( + op.getLoc(), viewTy, slm, zero, sizes); // Get SG id auto sgId = rewriter.create( @@ -724,7 +752,7 @@ class WGToSGXeTileConvertLayout // x is row, y is col // TODO: Floorsdiv and Remu are expensive. Find alterate. auto storeSgIdX = - rewriter.create(loc, sgId, srcMapDimY); + rewriter.create(loc, sgId, srcMapDimY); auto storeSgIdY = rewriter.create(loc, sgId, srcMapDimY); @@ -742,7 +770,7 @@ class WGToSGXeTileConvertLayout auto storeOffsetY = rewriter.createOrFold( loc, storeSgIdY, createIndexConstant(indexType, srcMapSgData[1])); auto storeInitTileOp = rewriter.create( - loc, srcTileTy, slm, llvm::ArrayRef({storeOffsetX, storeOffsetY})); + loc, srcTileTy, viewOp, llvm::ArrayRef({storeOffsetX, storeOffsetY})); //TODO: Set up cache attributes rewriter.create(loc, adaptor.getSource()[0], storeInitTileOp, nullptr, nullptr, nullptr); @@ -757,14 +785,14 @@ class WGToSGXeTileConvertLayout mlir::VectorType::get({dstMapSgData[0], dstMapSgData[1]}, elemTy); auto dstMapDimY = createIndexConstant(indexType, dstSgLayout[1]); - auto loadSgIdX = rewriter.create(loc, sgId, dstMapDimY); + auto loadSgIdX = rewriter.create(loc, sgId, dstMapDimY); auto loadSgIdY = rewriter.create(loc, sgId, dstMapDimY); auto loadOffsetX = rewriter.createOrFold( loc, loadSgIdX, createIndexConstant(indexType, dstMapSgData[0])); auto loadOffsetY = rewriter.createOrFold( loc, loadSgIdY, createIndexConstant(indexType, dstMapSgData[1])); auto loadInitTileOp = rewriter.create( - loc, dstTileTy, slm, llvm::ArrayRef({loadOffsetX, loadOffsetY})); + loc, dstTileTy, viewOp, llvm::ArrayRef({loadOffsetX, loadOffsetY})); //TODO: Set up cache attributes auto loadTile = rewriter.create( loc, newResTy, loadInitTileOp, mlir::Attribute(), nullptr, nullptr, nullptr); @@ -834,6 +862,127 @@ class WGToSGPrefetchOpPattern : public XeOneToNConversion { + using XeOneToNConversion< + mlir::vector::MultiDimReductionOp>::XeOneToNConversion; + + mlir::LogicalResult + matchAndRewrite(mlir::vector::MultiDimReductionOp op, OpAdaptor adaptor, + XeOneToNPatternRewriter &rewriter) const override { + + auto res = op.getResult(); + auto resType = mlir::dyn_cast(res.getType()); + auto resRank = resType.getShape().size(); + + auto mapAttr = + llvm::dyn_cast_or_null(op->getAttr("map")); + + if (!mapAttr) { + return mlir::failure(); + } + + auto sgData = 
mapAttr.getSgData();
+
+    auto src = adaptor.getSource()[0];
+    auto srcType = mlir::dyn_cast<mlir::VectorType>(src.getType());
+
+    if (resRank == 2) {
+      int64_t newReduceDim = sgData[0] == 1 ? 0 : 1;
+      mlir::SmallVector<int64_t> redDims{newReduceDim};
+      auto outputShape =
+          newReduceDim == 0 ? srcType.getDimSize(1) : srcType.getDimSize(0);
+      auto newTy = mlir::VectorType::get(outputShape, srcType.getElementType());
+
+      // ShapeCast the accumulator to match the 1D shape of the new
+      // subgroup-level reduction op.
+      auto acc = rewriter.create<mlir::vector::ShapeCastOp>(
+          op->getLoc(), newTy, adaptor.getAcc()[0]);
+
+      auto newOp = rewriter.create<mlir::vector::MultiDimReductionOp>(
+          op.getLoc(), newTy, op.getKind(), src, acc, redDims);
+
+      // ShapeCast the output of the reduction back to 2D.
+      auto accumulator = adaptor.getAcc()[0];
+      auto accumulatorType =
+          mlir::dyn_cast<mlir::VectorType>(accumulator.getType());
+      auto outputVectorTy = mlir::VectorType::get(
+          accumulatorType.getShape(), accumulatorType.getElementType());
+      auto shapeCastOp = rewriter.create<mlir::vector::ShapeCastOp>(
+          op.getLoc(), outputVectorTy, newOp);
+      rewriter.replaceOp(op, shapeCastOp);
+      return mlir::success();
+    }
+    // Regular 2D vector.multi_reduction: a 2D source reduced along one
+    // dimension to a 1D result.
+    else {
+      auto reductionDims = op.getReductionDims();
+      if (reductionDims.size() != 1)
+        return mlir::failure();
+
+      int64_t reduceDim = reductionDims[0];
+      auto outputShape =
+          reduceDim == 0 ? srcType.getDimSize(1) : srcType.getDimSize(0);
+
+      mlir::SmallVector<int64_t> redDims{reduceDim};
+      auto newTy = mlir::VectorType::get(outputShape, srcType.getElementType());
+      auto newOp = rewriter.create<mlir::vector::MultiDimReductionOp>(
+          op.getLoc(), newTy, op.getKind(), adaptor.getSource()[0],
+          adaptor.getAcc()[0], redDims);
+      rewriter.replaceOp(op, newOp);
+      return mlir::success();
+    }
+  }
+};
+
+// The shape-cast pattern must support casting from 1D to 2D, since
+// vector.multi_reduction produces a 1D result that is cast back to 2D.
+class WGToSGVectorShapeCast
+    : public XeOneToNConversion<mlir::vector::ShapeCastOp> {
+  using XeOneToNConversion<mlir::vector::ShapeCastOp>::XeOneToNConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::vector::ShapeCastOp op, OpAdaptor adaptor,
+                  XeOneToNPatternRewriter &rewriter) const override {
+
+    auto res = op.getResult();
+    auto resType = mlir::dyn_cast<mlir::VectorType>(res.getType());
+    auto resShape = resType.getShape();
+
+    // A 3D shape cast is assumed to feed a partial reduction, so just
+    // replace it with the transformed source of the shape_cast.
+    if (resShape.size() == 3) {
+      for (mlir::Operation *userOp : op.getResult().getUsers()) {
+        // Bail out if any user is not a vector.multi_reduction.
+        if (!isa<mlir::vector::MultiDimReductionOp>(userOp)) {
+          return mlir::failure();
+        }
+      }
+      rewriter.replaceOp(op, adaptor.getSource()[0]);
+      return mlir::success();
+    }
+
+    // One of the dims has to be a unit dim.
+    if (resShape[0] != 1 && resShape[1] != 1)
+      return mlir::failure();
+
+    auto mapAttr =
+        llvm::dyn_cast_or_null<xetile::WorkGroupMapAttr>(op->getAttr("map"));
+
+    if (!mapAttr) {
+      return mlir::failure();
+    }
+
+    auto sgData = mapAttr.getSgData();
+    auto newTy = mlir::VectorType::get({sgData[0], sgData[1]},
+                                       resType.getElementType());
+
+    auto newOp = rewriter.create<mlir::vector::ShapeCastOp>(
+        op.getLoc(), newTy, adaptor.getSource()[0]);
+    rewriter.replaceOp(op, newOp);
+    return mlir::success();
+  }
+};
+
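+// For reference, with the shapes exercised by the reduction tests below
+// (result wg_map sg_layout = [8, 4], sg_data = [1, 32]), the pattern above
+// rewrites each subgroup's partial reduction to roughly:
+//
+//   %acc1d = vector.shape_cast %acc : vector<1x32xf32> to vector<32xf32>
+//   %red   = vector.multi_reduction <add>, %src, %acc1d [0]
+//              : vector<32x32xf32> to vector<32xf32>
+//   %res   = vector.shape_cast %red : vector<32xf32> to vector<1x32xf32>
+//
+// This is a hand-written sketch (SSA names illustrative), not output captured
+// from the pass.
+
 // Helper function to analyze the def-use chain of initTileOps. Currently we
 // pattern match the following def-use chain as a candidate for
 // load + transpose optimization.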
@@ -852,10 +1001,11 @@ void analyzeInitTileOps(mlir::Operation *op) { if (!initOp->hasOneUse()) return mlir::WalkResult::skip(); ops.push_back(initOp); - auto user = *initOp->user_begin(); + auto initOpUser = *initOp->user_begin(); // InitTileOp must be consumed by a ForOp - mlir::Operation *loadUser = nullptr, *updateOffsetUser = nullptr; - if (auto scfFor = llvm::dyn_cast_if_present(user)) { + mlir::Operation *loadUser = nullptr; + mlir::BlockArgument loopArg; + if (auto scfFor = llvm::dyn_cast_if_present(initOpUser)) { auto argument = imex::getArgForOperand(scfFor, initOp.getResult()); int userCount = 0; for (auto user : argument.getUsers()) { @@ -865,13 +1015,37 @@ void analyzeInitTileOps(mlir::Operation *op) { ops.push_back(scfFor); ops.push_back(user); } else if (llvm::isa(user)) { - updateOffsetUser = user; ops.push_back(scfFor); ops.push_back(user); } + // Nested scf.for's + // init_tile -> scf.for -> update_tile_offset + // | + // scf.for -> load_tile -> vector.transpose -> (pre-op) -> + // tile_mma + else if (auto scfFor = + llvm::dyn_cast_if_present(user)) { + for (auto iterOperand : llvm::enumerate(scfFor.getInitArgs())) { + if (iterOperand.value() == argument) { + loopArg = scfFor.getRegionIterArgs()[iterOperand.index()]; + break; + } + } + + for (auto scfForUser : loopArg.getUsers()) { + if (llvm::isa(scfForUser)) { + loadUser = scfForUser; + ops.push_back(scfFor); + ops.push_back(scfForUser); + } else if (llvm::isa( + scfForUser)) { + ops.push_back(scfFor); + ops.push_back(scfForUser); + } + } + } } - // ForOp argument should have only two users, a load and an update offset - if (userCount != 2 || !(loadUser && updateOffsetUser)) + if (!loadUser) return mlir::WalkResult::skip(); } else return mlir::WalkResult::skip(); @@ -888,16 +1062,15 @@ void analyzeInitTileOps(mlir::Operation *op) { // Check if vector.transpose is consumed by TileMMA directly or // is consumed by some pre-op and then TileMMA. - if(!llvm::isa(consumerOp)){ - if(!OpTrait::hasElementwiseMappableTraits(consumerOp) && - !(llvm::isa(consumerOp))) { + if (!llvm::isa(consumerOp)) { + if (!OpTrait::hasElementwiseMappableTraits(consumerOp) && + !(llvm::isa(consumerOp))) { return mlir::WalkResult::skip(); - } - else { + } else { if (!(consumerOp->hasOneUse() && llvm::isa(*consumerOp->user_begin()))) - return mlir::WalkResult::skip(); - } + return mlir::WalkResult::skip(); + } } // At this point, we have a candidate def-use chain for optimization. 
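+// For illustration, the nested-loop chain matched above corresponds to IR of
+// roughly this shape (types borrowed from the reduction test added below;
+// tile attributes and the unrelated tile_mma operands are elided, so this is
+// a structural sketch rather than verifiable IR):
+//
+//   %t = xetile.init_tile %arg1[%n, %c0]
+//          : memref<2048x12288xbf16> -> !xetile.tile<128x32xbf16>
+//   scf.for %i = %c0 to %c2 step %c1 iter_args(%bt = %t) -> (...) {
+//     %next_b = xetile.update_tile_offset %bt, [%c1024, %c0] : ...
+//     scf.for %k = %c0 to %c12288 step %c32 iter_args(%cur = %bt) -> (...) {
+//       %next_k = xetile.update_tile_offset %cur, [%c0, %c32] : ...
+//       %v  = xetile.load_tile %cur : !xetile.tile<128x32xbf16> -> vector<128x32xbf16>
+//       %vt = vector.transpose %v, [1, 0] : vector<128x32xbf16> to vector<32x128xbf16>
+//       %mma = xetile.tile_mma %a, %vt, %acc : ...
+//       scf.yield %next_k, ...
+//     }
+//     scf.yield %next_b, ...
+//   }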
@@ -917,8 +1090,10 @@ void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter, WGToSGSCFForOpPattern, WGToSGUpdateTileOffsetOpPattern, WGToSGSCFYieldOpPattern, WGToSGVectorTranspose, WGToSGVectorBroadcast, WGToSGXeTileConvertLayout, WGToSGPrefetchOpPattern, WGToSGArithExtFOpPattern, - WGToSGArithTruncFOpPattern>(patterns.getContext(), converter); + WGToSGArithTruncFOpPattern, WGToSGVectorShapeCast, WGToSGVectorMultiDimReductionOp + >(patterns.getContext(), converter); patterns.insert, + WGToSGElementWiseOpPattern, WGToSGElementWiseOpPattern, WGToSGArithConstantOpPattern>(patterns.getContext(), converter); } @@ -1016,9 +1191,10 @@ class XeTileWgToSgPass }); target.addDynamicallyLegalOp( + mlir::vector::BroadcastOp, mlir::vector::MultiDimReductionOp, + mlir::vector::ShapeCastOp>( [&](mlir::Operation *op) -> bool { auto mapAttr = llvm::dyn_cast_or_null( op->getAttr("map")); diff --git a/test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir b/test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir index d1bedbe3c..b65f4c62e 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir @@ -57,7 +57,7 @@ gpu.module @test_gemm_btranspose{ %10 = arith.addi %8, %9 : index %11 = xetile.init_tile %arg0[%10, %c0] : memref<16384x12288xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>> - //CHECK: %[[R7:.*]] = index.floordivs %[[R6]], %[[c8]] + //CHECK: %[[R7:.*]] = index.divu %[[R6]], %[[c8]] //CHECK: %[[R8:.*]] = index.remu %[[R6]], %[[c8]] //CHECK: %[[R9:.*]] = index.add %[[R8]], %[[c0]] //CHECK: %[[R10:.*]] = index.remu %[[R9]], %[[c4]] diff --git a/test/Dialect/XeTile/Transforms/WgToSg/convert_layout.mlir b/test/Dialect/XeTile/Transforms/WgToSg/convert_layout.mlir index 2f984382e..55b308e91 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/convert_layout.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/convert_layout.mlir @@ -5,26 +5,28 @@ gpu.module @test_convert_layout{ gpu.func @test_kernel() { //CHECK: %[[c0:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32> //CHECK: %[[c0_0:.*]] = arith.constant dense<0.000000e+00> : vector<8x256xf32> - //CHECK: %[[SLM:.*]] = memref.alloc() : memref<256x256xf32, 3> + //CHECK: %[[SLMALLOC:.*]] = memref.alloc() : memref<262144xi8, 3> + //CHECK: %[[cst_0:.*]] = arith.constant 0 : index + //CHECK: %[[SLMVIEW:.*]] = memref.view %[[SLMALLOC]][%[[cst_0]]][] : memref<262144xi8, 3> to memref<256x256xf32, 3> //CHECK: %[[R0:.*]] = gpu.subgroup_id : index //CHECK: %[[c4:.*]] = arith.constant 4 : index - //CHECK: %[[R1:.*]] = index.floordivs %[[R0]], %[[c4]] + //CHECK: %[[R1:.*]] = index.divu %[[R0]], %[[c4]] //CHECK: %[[R2:.*]] = index.remu %[[R0]], %[[c4]] //CHECK: %[[c32:.*]] = arith.constant 32 : index //CHECK: %[[R3:.*]] = index.mul %[[R1]], %[[c32]] //CHECK: %[[c64:.*]] = arith.constant 64 : index //CHECK: %[[R4:.*]] = index.mul %[[R2]], %[[c64]] - //CHECK: %[[INITTILESRCMAP:.*]] = xetile.init_tile %[[SLM]][%[[R3]], %[[R4]]] : memref<256x256xf32, 3> -> !xetile.tile<32x64xf32, #xetile.tile_attr> + //CHECK: %[[INITTILESRCMAP:.*]] = xetile.init_tile %[[SLMVIEW]][%[[R3]], %[[R4]]] : memref<256x256xf32, 3> -> !xetile.tile<32x64xf32, #xetile.tile_attr> //CHECK: xetile.store_tile %[[c0]], %[[INITTILESRCMAP]] : vector<32x64xf32>, !xetile.tile<32x64xf32, #xetile.tile_attr> //CHECK: gpu.barrier //CHECK: %[[c1:.*]] = arith.constant 1 : index - //CHECK: %[[R5:.*]] = index.floordivs %[[R0]], %[[c1]] + //CHECK: %[[R5:.*]] = index.divu %[[R0]], %[[c1]] //CHECK: 
%[[R6:.*]] = index.remu %[[R0]], %[[c1]] //CHECK: %[[c8:.*]] = arith.constant 8 : index //CHECK: %[[R7:.*]] = index.mul %[[R5]], %[[c8]] //CHECK: %[[c256:.*]] = arith.constant 256 : index //CHECK: %[[R8:.*]] = index.mul %[[R6]], %[[c256]] - //CHECK: %[[INITTILEDSTMAP:.*]] = xetile.init_tile %[[SLM]][%[[R7]], %[[R8]]] : memref<256x256xf32, 3> -> !xetile.tile<8x256xf32, #xetile.tile_attr> + //CHECK: %[[INITTILEDSTMAP:.*]] = xetile.init_tile %[[SLMVIEW]][%[[R7]], %[[R8]]] : memref<256x256xf32, 3> -> !xetile.tile<8x256xf32, #xetile.tile_attr> //CHECK: %[[LOADTILE:.*]] = xetile.load_tile %[[INITTILEDSTMAP]] : !xetile.tile<8x256xf32, #xetile.tile_attr> -> vector<8x256xf32> %cst = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<256x256xf32> diff --git a/test/Dialect/XeTile/Transforms/WgToSg/gemm_1k.mlir b/test/Dialect/XeTile/Transforms/WgToSg/gemm_1k.mlir index d676d62af..e993c3871 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/gemm_1k.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/gemm_1k.mlir @@ -28,7 +28,7 @@ gpu.module @test_wg_to_sg { //CHECK: %[[R4:.*]] = gpu.subgroup_id : index //CHECK: %[[c4:.*]] = arith.constant 4 : index //CHECK: %[[c32:.*]] = arith.constant 32 : index - //CHECK: %[[R5:.*]] = index.floordivs %[[R4]], %[[c4]] + //CHECK: %[[R5:.*]] = index.divu %[[R4]], %[[c4]] //CHECK: %[[R6:.*]] = index.remu %[[R4]], %[[c4]] //CHECK: %[[R7:.*]] = index.add %[[R5]], %[[c0]] //CHECK: %[[R8:.*]] = index.remu %[[R7]], %[[c4]] diff --git a/test/Dialect/XeTile/Transforms/WgToSg/gemm_4k.mlir b/test/Dialect/XeTile/Transforms/WgToSg/gemm_4k.mlir index be3c18ea5..675386d7e 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/gemm_4k.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/gemm_4k.mlir @@ -30,7 +30,7 @@ gpu.module @test_wg_to_sg_4k { //CHECK: %[[R4:.*]] = gpu.subgroup_id : index //CHECK: %[[c4:.*]] = arith.constant 4 : index //CHECK: %[[c64:.*]] = arith.constant 64 : index - //CHECK: %[[R5:.*]] = index.floordivs %[[R4]], %[[c4]] + //CHECK: %[[R5:.*]] = index.divu %[[R4]], %[[c4]] //CHECK: %[[R6:.*]] = index.remu %[[R4]], %[[c4]] //CHECK: %[[R7:.*]] = index.add %[[R5]], %[[c0]] //CHECK: %[[R8:.*]] = index.remu %[[R7]], %[[c4]] diff --git a/test/Dialect/XeTile/Transforms/WgToSg/reduction.mlir b/test/Dialect/XeTile/Transforms/WgToSg/reduction.mlir new file mode 100644 index 000000000..2944cd12e --- /dev/null +++ b/test/Dialect/XeTile/Transforms/WgToSg/reduction.mlir @@ -0,0 +1,89 @@ +// RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s + +#map = affine_map<() -> (0)> +#map1 = affine_map<() -> (12288)> +#map2 = affine_map<() -> (2)> +module attributes {gpu.container_module} { + func.func @postop_reduce_m_entry(%arg0: memref<16384x12288xbf16>, %arg1: memref<2048x12288xbf16>, %arg2: memref<32x2048xf32>) attributes {gemm_tiles_b = 1 : i64, gemm_tiles_x = dense<[8, 2, 4, 8]> : vector<4xi64>, gemm_tiles_y = dense<[1, 2, 8, 4]> : vector<4xi64>, physical_nd_range = dense<[8, 32]> : vector<2xi64>, region_partition = 0 : i64, region_size = 32 : i64, syn.fusion_successful, syn.tensor_signature = (tensor<16384x12288xbf16>, tensor<2048x12288xbf16>) -> tensor<32x2048xf32>, synFusionGenOps = 6 : i64, synFusionRequiredBeamSize = 1 : i64, synFusionTotalCost = 1003595802.6 : f64} { + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + gpu.launch_func @postop_reduce_m::@postop_reduce_m blocks in (%c8, %c32, %c1) threads in (%c8, %c4, %c1) args(%arg0 : 
memref<16384x12288xbf16>, %arg1 : memref<2048x12288xbf16>, %arg2 : memref<32x2048xf32>) + return + } + gpu.module @postop_reduce_m attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @postop_reduce_m(%arg0: memref<16384x12288xbf16>, %arg1: memref<2048x12288xbf16>, %arg2: memref<32x2048xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c12288 = arith.constant 12288 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c256 = arith.constant 256 : index + %c2048 = arith.constant 2048 : index + %c128 = arith.constant 128 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %cst = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<256x128xf32> + //CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32> + %cst_0 = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<8x128xf32> + //CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> + %cst_1 = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<128xf32> + %cst_2 = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<1x128xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.divsi %block_id_y, %c8 : index + %1 = arith.remsi %block_id_y, %c8 : index + %2 = arith.muli %block_id_x, %c4 : index + %3 = arith.addi %2, %0 : index + %4 = arith.muli %1, %c128 : index + %5 = xetile.init_tile %arg2[%3, %4] : memref<32x2048xf32> -> !xetile.tile<1x128xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %6 = arith.muli %block_id_x, %c2048 : index + %7 = arith.muli %0, %c256 : index + %8 = arith.addi %6, %7 : index + %9 = xetile.init_tile %arg0[%8, %c0] : memref<16384x12288xbf16> -> !xetile.tile<256x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %10 = xetile.init_tile %arg1[%4, %c0] : memref<2048x12288xbf16> -> !xetile.tile<128x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %11:2 = scf.for %arg3 = %c0 to %c2 step %c1 iter_args(%arg4 = %5, %arg5 = %10) -> (!xetile.tile<1x128xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>) { + %12 = xetile.update_tile_offset %arg5, [%c1024, %c0] : !xetile.tile<128x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %13 = xetile.update_tile_offset %arg4, [%c0, %c1024] : !xetile.tile<1x128xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %14:2 = scf.for %arg6 = %c0 to %c2 step %c1 iter_args(%arg7 = %cst_2, %arg8 = %9) -> (vector<1x128xf32>, !xetile.tile<256x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>) { + %16 = xetile.update_tile_offset %arg8, [%c1024, %c0] : !xetile.tile<256x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %17:3 = scf.for %arg9 = %c0 to %c12288 step %c32 iter_args(%arg10 = %cst, %arg11 = %arg8, %arg12 = %arg5) -> (vector<256x128xf32>, !xetile.tile<256x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>) { + %27 = xetile.update_tile_offset %arg12, [%c0, %c32] : !xetile.tile<128x32xbf16, 
#xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %28 = xetile.update_tile_offset %arg11, [%c0, %c32] : !xetile.tile<256x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %29 = xetile.load_tile %arg11 : !xetile.tile<256x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> -> vector<256x32xbf16> + %30 = math.exp %29 {map = #xetile.wg_map} : vector<256x32xbf16> + %31 = xetile.load_tile %arg12 : !xetile.tile<128x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> -> vector<128x32xbf16> + %32 = vector.transpose %31, [1, 0] {map = #xetile.wg_map} : vector<128x32xbf16> to vector<32x128xbf16> + xegpu.compile_hint + %33 = xetile.tile_mma %30, %32, %arg10 {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<256x32xbf16>, vector<32x128xbf16>, vector<256x128xf32> -> vector<256x128xf32> + xegpu.compile_hint + scf.yield %33, %28, %27 : vector<256x128xf32>, !xetile.tile<256x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + } {lowerBoundMap = #map, operandSegmentSizes = array, step = 32 : index, upperBoundMap = #map1} + //CHECK: %[[EXP:.*]] = math.exp {{%.*}} : vector<32x32xf32> + //CHECK: %[[SHAPECAST_0:.*]] = vector.shape_cast %[[CST_0]] : vector<1x32xf32> to vector<32xf32> + //CHECK: %[[REDUCTION_0:.*]] = vector.multi_reduction , %[[EXP]], %[[SHAPECAST_0]] [0] : vector<32x32xf32> to vector<32xf32> + //CHECK: %[[SHAPECAST_1:.*]] = vector.shape_cast %[[REDUCTION_0]] : vector<32xf32> to vector<1x32xf32> + %18 = math.exp %17#0 {map = #xetile.wg_map} : vector<256x128xf32> + %19 = vector.shape_cast %18 {map = #xetile.wg_map} : vector<256x128xf32> to vector<8x32x128xf32> + %20 = vector.multi_reduction , %19, %cst_0 {map = #xetile.wg_map} [1] : vector<8x32x128xf32> to vector<8x128xf32> + //CHECK: xetile.store_tile {{%.*}}, {{%.*}} : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr> + //CHECK: gpu.barrier + //CHECK: %[[LOADTILE_SLM:.*]] = xetile.load_tile {{%.*}} : !xetile.tile<8x4xf32, #xetile.tile_attr> -> vector<8x4xf32> + //CHECK: %[[REDUCTION_1:.*]] = vector.multi_reduction , %[[LOADTILE_SLM]], %[[CST_1]] [0] : vector<8x4xf32> to vector<4xf32> + //CHECK: %[[SHAPECAST_2:.*]] = vector.shape_cast %[[REDUCTION_1]] : vector<4xf32> to vector<1x4xf32> + //CHECK: %[[ADDF:.*]] = arith.addf %[[SHAPECAST_2]], {{%.*}} : vector<1x4xf32> + %21 = xetile.convert_layout %20 {wg_map_result = #xetile.wg_map} : vector<8x128xf32> + %22 = vector.multi_reduction , %21, %cst_1 {map = #xetile.wg_map} [0] : vector<8x128xf32> to vector<128xf32> + %23 = vector.shape_cast %22 {map = #xetile.wg_map} : vector<128xf32> to vector<1x128xf32> + %24 = arith.addf %23, %arg7 {map = #xetile.wg_map} : vector<1x128xf32> + scf.yield %24, %16 : vector<1x128xf32>, !xetile.tile<256x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + } {lowerBoundMap = #map, operandSegmentSizes = array, step = 1 : index, upperBoundMap = #map2} + xetile.store_tile %14#0, %arg4 : vector<1x128xf32>, !xetile.tile<1x128xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + scf.yield %13, %12 : !xetile.tile<1x128xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xbf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + } {lowerBoundMap = #map, operandSegmentSizes = array, step = 1 : index, upperBoundMap = #map2} + gpu.return + } + } +} diff --git 
a/test/Dialect/XeTile/Transforms/WgToSg/round_robin.mlir b/test/Dialect/XeTile/Transforms/WgToSg/round_robin.mlir index f46f2f40f..f1b08c8ce 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/round_robin.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/round_robin.mlir @@ -28,7 +28,7 @@ gpu.module @test_wg_to_sg_rr { //CHECK: %[[R4:.*]] = gpu.subgroup_id : index //CHECK: %[[c2:.*]] = arith.constant 2 : index //CHECK: %[[c32:.*]] = arith.constant 32 : index - //CHECK: %[[R5:.*]] = index.floordivs %[[R4]], %[[c2]] + //CHECK: %[[R5:.*]] = index.divu %[[R4]], %[[c2]] //CHECK: %[[R6:.*]] = index.remu %[[R4]], %[[c2]] //CHECK: %[[R7:.*]] = index.add %[[R5]], %[[c0]] //CHECK: %[[c4:.*]] = arith.constant 4 : index diff --git a/test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir b/test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir index 34734b196..c892a2eca 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir @@ -9,6 +9,31 @@ gpu.module @test_arith_extf { //CHECK: arith.truncf {{%.*}} : vector<32x32xf32> to vector<32x32xf16> %extf = arith.extf %load_tile {map = #xetile.wg_map} : vector<128x32xf16> to vector<128x32xf32> %trucf = arith.truncf %extf {map = #xetile.wg_map} : vector<128x32xf32> to vector<128x32xf16> - gpu.return + gpu.return + } + + gpu.func @test_reduction_and_shape_cast(%arg0 : vector<256x128xf32>) { + //CHECK: %[[CST_0:.*]] = arith.constant dense<-1.000000e+00> : vector<32x32xf32> + //CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32> + //CHECK: %[[SQRT:.*]] = math.sqrt %[[CST_0]] : vector<32x32xf32> + //CHECK: %[[SHAPECAST_0:.*]] = vector.shape_cast %[[CST_1]] : vector<1x32xf32> to vector<32xf32> + //CHECK: %[[REDUCTION_0:.*]] = vector.multi_reduction , %[[SQRT]], %[[SHAPECAST_0]] [0] : vector<32x32xf32> to vector<32xf32> + //CHECK: %[[SHAPECAST_1:.*]] = vector.shape_cast %[[REDUCTION_0]] : vector<32xf32> to vector<1x32xf32> + %cst = arith.constant {map = #xetile.wg_map} dense<-1.0> : vector<256x128xf32> + %cst_0 = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<8x128xf32> + %sqrt = math.sqrt %cst {map = #xetile.wg_map} : vector<256x128xf32> + %reshape = vector.shape_cast %sqrt {map = #xetile.wg_map} : vector<256x128xf32> to vector<8x32x128xf32> + %reduction = vector.multi_reduction , %reshape, %cst_0 {map = #xetile.wg_map} [1] : vector<8x32x128xf32> to vector<8x128xf32> + //CHECK: xetile.store_tile {{%.*}}, {{%.*}} : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr> + //CHECK: gpu.barrier + //CHECK: %[[LOADTILE_SLM:.*]] = xetile.load_tile {{%.*}} : !xetile.tile<8x4xf32, #xetile.tile_attr> -> vector<8x4xf32> + //CHECK: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> + //CHECK: %[[REDUCTION_1:.*]] = vector.multi_reduction , %[[LOADTILE_SLM]], %[[CST_2]] [0] : vector<8x4xf32> to vector<4xf32> + //CHECK: %[[SHAPECAST_2:.*]] = vector.shape_cast %[[REDUCTION_1]] : vector<4xf32> to vector<1x4xf32> + %conv_layout = xetile.convert_layout %reduction {wg_map_result = #xetile.wg_map} : vector<8x128xf32> + %cst_1 = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<128xf32> + %reduce = vector.multi_reduction , %conv_layout, %cst_1 {map = #xetile.wg_map} [0] : vector<8x128xf32> to vector<128xf32> + %shape_cast = vector.shape_cast %reduce {map = #xetile.wg_map} : vector<128xf32> to vector<1x128xf32> + gpu.return } } diff --git a/test/Integration/Dialect/XeTile/convert_layout.mlir 
b/test/Integration/Dialect/XeTile/convert_layout.mlir new file mode 100644 index 000000000..3c011a7d0 --- /dev/null +++ b/test/Integration/Dialect/XeTile/convert_layout.mlir @@ -0,0 +1,83 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +module @conv_layout attributes {gpu.container_module} { + func.func @convert_layout(%a: memref<64x64xf32>, %b: memref<64x64xf32>) -> memref<64x64xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + + %a_gpu = gpu.alloc host_shared () : memref<64x64xf32> + memref.copy %a, %a_gpu : memref<64x64xf32> to memref<64x64xf32> + %b_gpu = gpu.alloc host_shared () : memref<64x64xf32> + memref.copy %b, %b_gpu : memref<64x64xf32> to memref<64x64xf32> + %c_gpu = gpu.alloc host_shared () : memref<64x64xf32> + + gpu.launch_func @kernel::@test_convert_layout blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%a_gpu : memref<64x64xf32>, %b_gpu : memref<64x64xf32>, %c_gpu : memref<64x64xf32>) + + gpu.dealloc %a_gpu : memref<64x64xf32> + gpu.dealloc %b_gpu : memref<64x64xf32> + return %c_gpu : memref<64x64xf32> + } + +gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_convert_layout(%arg0 : memref<64x64xf32>, %arg1 : memref<64x64xf32>, %arg2 : memref<64x64xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c1 = arith.constant 1 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c1 : index + %n = arith.muli %block_id_y, %c1 : index + %init_tile_1 = xetile.init_tile %arg0[%m, %n] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> + %load_tile_1 = xetile.load_tile %init_tile_1: !xetile.tile<64x64xf32, #xetile.tile_attr>> -> vector<64x64xf32> + %init_tile_2 = xetile.init_tile %arg1[%m, %n] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> + %load_tile_2 = xetile.load_tile %init_tile_2: !xetile.tile<64x64xf32, #xetile.tile_attr>> -> vector<64x64xf32> + %convert_layout = xetile.convert_layout %load_tile_1 {wg_map_result = #xetile.wg_map, wg_map_source = #xetile.wg_map} : vector<64x64xf32> + %add = arith.addf %load_tile_2, %convert_layout {map = #xetile.wg_map} : vector<64x64xf32> + %init_store_tile = xetile.init_tile %arg2[%m, %n] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> + xetile.store_tile %add, %init_store_tile : vector<64x64xf32>, !xetile.tile<64x64xf32, #xetile.tile_attr>> + gpu.return + } +} + +func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1_f32 = arith.constant 1.0 : f32 + %c2_f32 = arith.constant 2.0 : f32 + %c100_f32 = arith.constant 100.0 : f32 + %a = memref.alloc() : memref<64x64xf32> + %b = memref.alloc() : 
memref<64x64xf32> + %c_ref = memref.alloc() : memref<64x64xf32> + + + // intialize matrix A, B ; A[i, j] = 1 + scf.for %i = %c0 to %c64 step %c1 { + scf.for %j = %c0 to %c64 step %c1 { + memref.store %c1_f32, %a[%i, %j] : memref<64x64xf32> + memref.store %c1_f32, %b[%i, %j] : memref<64x64xf32> + memref.store %c2_f32, %c_ref[%i, %j] : memref<64x64xf32> + } + } + + %c = call @convert_layout(%a, %b) : (memref<64x64xf32>, memref<64x64xf32>) -> memref<64x64xf32> + %cast_c = memref.cast %c : memref<64x64xf32> to memref<*xf32> + %cast_c_ref = memref.cast %c_ref :memref<64x64xf32> to memref<*xf32> + //call @printMemrefF32(%cast_c): (memref<*xf32>) -> () + //call @printMemrefF32(%cast_c_ref): (memref<*xf32>) -> () + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_c, %cast_c_ref) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %a : memref<64x64xf32> + memref.dealloc %b : memref<64x64xf32> + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/wg_reduction.mlir b/test/Integration/Dialect/XeTile/wg_reduction.mlir new file mode 100644 index 000000000..f511c69ab --- /dev/null +++ b/test/Integration/Dialect/XeTile/wg_reduction.mlir @@ -0,0 +1,94 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +module @reduction attributes {gpu.container_module} { + func.func @reduce_test(%a: memref<256x1024xf32>) -> memref<1x1024xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + + %a_gpu = gpu.alloc host_shared () : memref<256x1024xf32> + memref.copy %a, %a_gpu : memref<256x1024xf32> to memref<256x1024xf32> + %b_gpu = gpu.alloc host_shared () : memref<1x1024xf32> + + gpu.launch_func @kernel::@test_reduction blocks in (%c1, %c8, %c1) threads in (%c8, %c4, %c1) args(%a_gpu : memref<256x1024xf32>, %b_gpu : memref<1x1024xf32>) + + gpu.dealloc %a_gpu : memref<256x1024xf32> + return %b_gpu : memref<1x1024xf32> + } + +gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_reduction(%arg0 : memref<256x1024xf32>, %arg1 : memref<1x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c128 = arith.constant 128 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c256 : index + %n = arith.muli %block_id_y, %c128 : index + %init_tile = xetile.init_tile %arg0[%m, %n] : memref<256x1024xf32> -> !xetile.tile<256x128xf32, #xetile.tile_attr>> + %load_tile = xetile.load_tile %init_tile: !xetile.tile<256x128xf32, #xetile.tile_attr>> -> vector<256x128xf32> + %cst_0 = arith.constant {map = #xetile.wg_map} dense<0.0> : 
vector<8x128xf32> + %reshape = vector.shape_cast %load_tile {map = #xetile.wg_map} : vector<256x128xf32> to vector<8x32x128xf32> + %reduction = vector.multi_reduction , %reshape, %cst_0 {map = #xetile.wg_map} [1] : vector<8x32x128xf32> to vector<8x128xf32> + %conv_layout = xetile.convert_layout %reduction {wg_map_result = #xetile.wg_map} : vector<8x128xf32> + %cst_1 = arith.constant {map = #xetile.wg_map} dense<0.0> : vector<128xf32> + %reduce = vector.multi_reduction , %conv_layout, %cst_1 {map = #xetile.wg_map} [0] : vector<8x128xf32> to vector<128xf32> + %shape_cast = vector.shape_cast %reduce {map = #xetile.wg_map} : vector<128xf32> to vector<1x128xf32> + %init_store_tile = xetile.init_tile %arg1[%c0, %n] : memref<1x1024xf32> -> !xetile.tile<1x128xf32, #xetile.tile_attr>> + xetile.store_tile %shape_cast, %init_store_tile : vector<1x128xf32>, !xetile.tile<1x128xf32, #xetile.tile_attr>> + gpu.return + } +} + +func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c1024 = arith.constant 1024 : index + %c32 = arith.constant 32 : index + %c0_f32 = arith.constant 0.0 : f32 + %c32_f32 = arith.constant 32.0 : f32 + %c1_f32 = arith.constant 1.0 : f32 + %c100_f32 = arith.constant 100.0 : f32 + %a = memref.alloc() : memref<256x1024xf32> + %b_ref = memref.alloc() : memref<1024xf32> + + + // intialize matrix A ; A[i, j] = 1 + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + memref.store %c1_f32, %a[%i, %j] : memref<256x1024xf32> + } + } + + scf.for %j = %c0 to %c1024 step %c1 { + %sum = scf.for %i = %c0 to %c256 step %c1 iter_args(%arg = %c0_f32) -> (f32) { + %val = memref.load %a[%i, %j] : memref<256x1024xf32> + %2 = arith.addf %arg, %val : f32 + scf.yield %2 : f32 + } + memref.store %sum, %b_ref[%j] : memref<1024xf32> + } + + %b = call @reduce_test(%a) : (memref<256x1024xf32>) -> memref<1x1024xf32> + %cast_b = memref.cast %b : memref<1x1024xf32> to memref<*xf32> + %cast_b_ref = memref.cast %b_ref : memref<1024xf32> to memref<*xf32> + //call @printMemrefF32(%cast_b): (memref<*xf32>) -> () + //call @printMemrefF32(%cast_b_ref): (memref<*xf32>) -> () + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_b, %cast_b_ref) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %a : memref<256x1024xf32> + memref.dealloc %b_ref : memref<1024xf32> + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp b/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp index 9037120f3..24d17e617 100644 --- a/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp +++ b/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp @@ -4,8 +4,9 @@ cse xetile-init-duplicate xetile-canonicalization + xetile-blockop-fallback xetile-blocking - canonicalize + cse convert-xetile-to-xegpu cse imex-xegpu-hoist-transpose @@ -13,6 +14,9 @@ imex-xegpu-optimize-transpose) cse imex-vector-linearize + cse + imex-remove-single-elem-vector + cse gpu.module(convert-xegpu-to-vc) reconcile-unrealized-casts bf16-to-gpu
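
Taken together, the wg_reduction integration test reduces each workgroup's 256x128
slab over rows in two stages: each of the 32 subgroups (an 8x4 grid) first reduces
its private 32x32 tile into a 1x32 partial result; the resulting 8x128 grid of
partials is redistributed through SLM by xetile.convert_layout (the
store_tile/gpu.barrier/load_tile sequence visible in the CHECK lines above) so that
32 subgroups each own an 8x4 slice; a second reduction then folds the remaining
eight rows, with each subgroup producing 4 of the 128 final sums. A condensed
sketch of the workgroup-level IR follows; the #xetile.wg_map parameter values are
reconstructions from the CHECK lines, not copied from the sources:

  // Stage 1: per-subgroup partial reduction over the split-out dimension.
  %part = vector.multi_reduction <add>, %src3d, %zero_a
            {map = #xetile.wg_map<sg_layout = [8, 4], sg_data = [1, 32]>} [1]
            : vector<8x32x128xf32> to vector<8x128xf32>
  // Redistribute through SLM: 32 subgroups now each hold an 8x4 slice.
  %cl = xetile.convert_layout %part
          {wg_map_result = #xetile.wg_map<sg_layout = [1, 32], sg_data = [8, 4]>}
          : vector<8x128xf32>
  // Stage 2: fold the eight partial rows into the final 1x128 row vector.
  %red = vector.multi_reduction <add>, %cl, %zero_b
           {map = #xetile.wg_map<sg_layout = [1, 32], sg_data = [1, 4]>} [0]
           : vector<8x128xf32> to vector<128xf32>
  %res = vector.shape_cast %red
           {map = #xetile.wg_map<sg_layout = [1, 32], sg_data = [1, 4]>}
           : vector<128xf32> to vector<1x128xf32>

The SLM round trip is what makes the cross-subgroup step legal: subgroups cannot
read each other's registers, so the partial sums must pass through shared memory
with a barrier in between.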