From 81f331e9873610e3981c110dd52884b3c4614d2d Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Mon, 16 Dec 2024 18:45:57 +0000 Subject: [PATCH] Remove old implementation of XeTileToXeGPU. --- include/imex/Conversion/Passes.td | 5 +- .../Conversion/XeTileToXeGPU/XeTileToXeGPU.h | 9 +- .../Transforms/XeTileOneToNConversion.h} | 4 +- lib/Conversion/XeGPUToVC/XeGPUToVC.cpp | 16 +- .../XeTileToXeGPU/ArithOpConversion.cpp | 407 ------ .../XeTileToXeGPU/ArithOpConversion.h | 29 - lib/Conversion/XeTileToXeGPU/CMakeLists.txt | 4 - .../XeTileToXeGPU/SCFOpConversion.cpp | 124 -- .../XeTileToXeGPU/SCFOpConversion.h | 29 - .../XeTileToXeGPU/XeTileOpConversion.cpp | 1226 ----------------- .../XeTileToXeGPU/XeTileOpConversion.h | 30 - .../XeTileToXeGPU/XeTileToXeGPU.cpp | 256 ++-- lib/Dialect/XeTile/Transforms/Blocking.cpp | 41 +- .../XeTile/Transforms/BlockingAnalysis.cpp | 26 +- lib/Dialect/XeTile/Transforms/CMakeLists.txt | 1 + lib/Dialect/XeTile/Transforms/WgToSg.cpp | 5 +- .../Transforms/XeTileOneToNConversion.cpp} | 19 +- test/Conversion/XeTileToXeGPU/addf.mlir | 214 --- .../XeTileToXeGPU/array_length_load.mlir | 1 - .../Conversion/XeTileToXeGPU/create_mask.mlir | 59 - .../XeTileToXeGPU/elementwise_ops.mlir | 297 ---- test/Conversion/XeTileToXeGPU/gemm_preop.mlir | 2 +- test/Conversion/XeTileToXeGPU/lit.local.cfg | 10 + .../XeTileToXeGPU/non_pow2_stacking.mlir | 24 - test/Conversion/XeTileToXeGPU/prefetch.mlir | 29 - test/Conversion/XeTileToXeGPU/reduction.mlir | 273 ---- .../sg_gemm_1k_1k_1k_f16_f32.mlir | 52 +- .../sg_gemm_1k_1k_1k_f16_f32_slm.mlir | 2 +- .../sg_gemm_1k_1k_1k_i8_i32.mlir | 28 +- .../sg_gemm_1k_1k_1k_tf32_tf32.mlir | 30 +- .../XeTileToXeGPU/sg_init_tile.mlir | 4 +- .../XeTileToXeGPU/sg_load_tile.mlir | 5 +- .../XeTileToXeGPU/sg_mixed_scf.mlir | 2 +- .../XeTileToXeGPU/sg_scattered_ops.mlir | 78 +- test/Conversion/XeTileToXeGPU/sg_scf_for.mlir | 44 +- test/Conversion/XeTileToXeGPU/sg_softmax.mlir | 169 +-- .../XeTileToXeGPU/sg_store_tile.mlir | 5 +- 
.../Conversion/XeTileToXeGPU/sg_tile_mma.mlir | 21 +- .../XeTileToXeGPU/sg_tiled_broadcast.mlir | 95 -- .../XeTileToXeGPU/sg_tiled_load_tile.mlir | 20 - .../XeTileToXeGPU/sg_tiled_scattered_ops.mlir | 69 - .../XeTileToXeGPU/sg_tiled_scf_for.mlir | 57 - .../XeTileToXeGPU/sg_tiled_softmax.mlir | 346 ----- .../XeTileToXeGPU/sg_tiled_store_tile.mlir | 58 - .../XeTileToXeGPU/sg_tiled_tile_mma.mlir | 91 -- test/Conversion/XeTileToXeGPU/test_order.mlir | 19 +- test/Conversion/XeTileToXeGPU/unit_tests.mlir | 75 + .../Conversion/XeTileToXeGPU/unpack_pack.mlir | 45 - .../Blocking/unit_tests_transform.mlir | 911 +++++------- .../Dialect/XeTile/xetile-to-func-vc.pp | 2 +- .../Dialect/XeTile/xetile-wg-to-func-vc.pp | 2 +- 51 files changed, 758 insertions(+), 4612 deletions(-) rename include/imex/{Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h => Dialect/XeTile/Transforms/XeTileOneToNConversion.h} (97%) delete mode 100644 lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp delete mode 100644 lib/Conversion/XeTileToXeGPU/ArithOpConversion.h delete mode 100644 lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp delete mode 100644 lib/Conversion/XeTileToXeGPU/SCFOpConversion.h delete mode 100644 lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp delete mode 100644 lib/Conversion/XeTileToXeGPU/XeTileOpConversion.h rename lib/{Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp => Dialect/XeTile/Transforms/XeTileOneToNConversion.cpp} (93%) delete mode 100644 test/Conversion/XeTileToXeGPU/addf.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/create_mask.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/elementwise_ops.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/non_pow2_stacking.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/prefetch.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/reduction.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/sg_tiled_broadcast.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/sg_tiled_load_tile.mlir delete mode 
100644 test/Conversion/XeTileToXeGPU/sg_tiled_scattered_ops.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/sg_tiled_scf_for.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/sg_tiled_softmax.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/sg_tiled_store_tile.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/sg_tiled_tile_mma.mlir create mode 100644 test/Conversion/XeTileToXeGPU/unit_tests.mlir delete mode 100644 test/Conversion/XeTileToXeGPU/unpack_pack.mlir diff --git a/include/imex/Conversion/Passes.td b/include/imex/Conversion/Passes.td index 35d6d2ff5..87d3ba251 100644 --- a/include/imex/Conversion/Passes.td +++ b/include/imex/Conversion/Passes.td @@ -394,10 +394,7 @@ def ConvertXeTileToXeGPU: Pass<"convert-xetile-to-xegpu", "::mlir::gpu::GPUModul let options = [ Option<"device", "device", "std::string", /*default=*/"\"pvc\"", - "gpu platform architecture where these ops are running">, - Option<"EnableTransform", "enable-2d-transform", "bool", - /*default=*/"false", - "Using 2D transform or 4D Conversion."> + "gpu platform architecture where these ops are running"> ]; } diff --git a/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h b/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h index 7868a81ff..276dd8481 100644 --- a/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h +++ b/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h @@ -21,8 +21,6 @@ #include #include -#include "XeTileToXeGPUConversion.h" - namespace mlir { class MLIRContext; class ModuleOp; @@ -37,12 +35,9 @@ namespace imex { #define GEN_PASS_DECL_CONVERTXETILETOXEGPU #include "imex/Conversion/Passes.h.inc" -class XeOneToNTypeConverter; - /// Populate the given list with patterns rewrite XeTile Ops -void populateXeTileToXeGPUConversionPatterns(XeOneToNTypeConverter &converter, - mlir::RewritePatternSet &patterns, - imex::TileUsageAnalysis &analysis); +void populateXeTileToXeGPUConversionPatterns(mlir::TypeConverter &converter, + mlir::RewritePatternSet &patterns); 
/// Create a pass to convert the XeTile dialect to the XeGPU dialect. std::unique_ptr> diff --git a/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h b/include/imex/Dialect/XeTile/Transforms/XeTileOneToNConversion.h similarity index 97% rename from include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h rename to include/imex/Dialect/XeTile/Transforms/XeTileOneToNConversion.h index d33a26b72..71d04f4ce 100644 --- a/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h +++ b/include/imex/Dialect/XeTile/Transforms/XeTileOneToNConversion.h @@ -1,4 +1,4 @@ -//===- TypeConverter.h - XeTileToXeGPU conversion -------*- C++ -*-===// +//===- XeTileOneToNConversion.h --- XeTileOneToNConversion -----*- C++ -*-===// // // Copyright 2022 Intel Corporation // Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. @@ -9,7 +9,7 @@ /// /// \file /// This file defines the XeOneToNConversion, the base class for -/// XeTileToXeGPU conversion, XeOneToNTypeConverter, converting types used in +/// doing OneToN conversion, XeOneToNTypeConverter, converting types used in /// XeTile dialect to types used in XeGPU dialect, XeOneToNPatternRewriter a /// wrapper around ConversionPatterRewriter providng interface for supporting /// OneToN replace. 
diff --git a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp index 550a2257a..7ce8422cc 100644 --- a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp +++ b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp @@ -603,6 +603,7 @@ class VectorShapeCastPattern : public OpConversionPattern { if (!dstType) return failure(); + if (dstType == adaptor.getSource().getType() || shapeCastOp.getResultVectorType().getNumElements() == 1) { rewriter.replaceOp(shapeCastOp, adaptor.getSource()); @@ -760,7 +761,10 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase { target.addDynamicallyLegalDialect( [&](Operation *op) { return isLegalXeGPUSCFOp(op, typeConverter); }); - target.addIllegalOp(); + target.addDynamicallyLegalOp([&](ShapeCastOp op) { + return typeConverter.isLegal(op.getType()) && + typeConverter.isLegal(op.getSource().getType()); + }); // TODO: can we change it to addDynamicLegalOp? target.addLegalOp(); @@ -786,8 +790,9 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase { }); typeConverter.addConversion([&](VectorType type) -> Type { - // TODO: it looks like needs some improvement for matching upstream - // passes + // TODO: I don't think we need to convert 2D VectorType to + // 1D VectorType. 
It needs to removed after we move vector + // linearization after this pass unsigned rank = type.getRank(); auto elemType = type.getElementType(); @@ -795,6 +800,11 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase { if (rank < 1) return elemType; + // TODO: a temporary fix to avoid do type conversion + // for create_mask result + if (elemType.isInteger(1)) + return type; + unsigned sum = 1; for (unsigned i = 0; i < rank; i++) { sum *= type.getShape()[i]; diff --git a/lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp deleted file mode 100644 index 08aa74a6b..000000000 --- a/lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp +++ /dev/null @@ -1,407 +0,0 @@ -//===- ArithOpConversion.cpp - XeTileToXeGPU conversion -------*- C++ -*-===// -// -// Copyright 2022 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the ArithOpConversionPattern, used in XeTileToXeGPU -/// conversion, converting the Arith Ops. -/// -//===----------------------------------------------------------------------===// - -#include "ArithOpConversion.h" - -namespace imex { - -static mlir::Value createBinOp(mlir::vector::CombiningKind kind, - mlir::Value lhs, mlir::Value rhs, - mlir::Type elemTy, mlir::Location &loc, - XeOneToNPatternRewriter &rewriter) { - - // ADD and MUL are defined for both Integers and Floats, - // need to generate code based on element data type. 
- if (kind == mlir::vector::CombiningKind::ADD) { - if (mlir::isa(elemTy)) { - return rewriter.create(loc, lhs, rhs); - } - if (mlir::isa(elemTy)) { - return rewriter.create(loc, lhs, rhs); - } - } - - if (kind == mlir::vector::CombiningKind::MUL) { - if (mlir::isa(elemTy)) { - return rewriter.create(loc, lhs, rhs); - } - if (mlir::isa(elemTy)) { - return rewriter.create(loc, lhs, rhs); - } - } - - switch (kind) { - // the following are for ints only - case mlir::vector::CombiningKind::MINUI: - return rewriter.create(loc, lhs, rhs); - case mlir::vector::CombiningKind::MINSI: - return rewriter.create(loc, lhs, rhs); - case mlir::vector::CombiningKind::MAXUI: - return rewriter.create(loc, lhs, rhs); - case mlir::vector::CombiningKind::MAXSI: - return rewriter.create(loc, lhs, rhs); - case mlir::vector::CombiningKind::AND: - return rewriter.create(loc, lhs, rhs); - case mlir::vector::CombiningKind::OR: - return rewriter.create(loc, lhs, rhs); - case mlir::vector::CombiningKind::XOR: - return rewriter.create(loc, lhs, rhs); - // the following are for floats only - case mlir::vector::CombiningKind::MINNUMF: - return rewriter.create(loc, lhs, rhs); - case mlir::vector::CombiningKind::MAXNUMF: - return rewriter.create(loc, lhs, rhs); - case mlir::vector::CombiningKind::MINIMUMF: - return rewriter.create(loc, lhs, rhs); - case mlir::vector::CombiningKind::MAXIMUMF: - return rewriter.create(loc, lhs, rhs); - default: - llvm_unreachable("Unexpected CombiningKind."); - return lhs; - } -} - -llvm::SmallVector -lowerOuterReduction(mlir::ValueRange sources, llvm::ArrayRef shape, - mlir::vector::CombiningKind kind, mlir::Location loc, - mlir::Type elemTy, XeOneToNPatternRewriter &rewriter) { - assert(shape.size() == 4 && "shape should be 4D."); - llvm::SmallVector intermediates; - for (auto j = 0; j < shape[1]; j++) { - auto combiningVal = sources[j]; - for (auto i = 1; i < shape[0]; i++) { - combiningVal = createBinOp(kind, combiningVal, sources[i * shape[1] + j], - elemTy, loc, 
rewriter); - } - { - // TODO: After blocking If the first dimension of the small block is not - // 1, the combiningVal is now in shape as, e.g., vector<4x16xf16> instead - // of vector<1x16xf16> then more reductions are needed in dim0, to make it - // as vector<1x16xf16>. Currently, this is not implemented, since we are - // now restricted blocking pass to set it as 1 now. It may cannot achieve - // peak performance in some cases. - assert(shape[2] == 1 && - "more reductions is needed in dim0, but not supported."); - } - intermediates.push_back(combiningVal); - } - return intermediates; -} - -// expected input is type of vector, where i and n is power of 2 -// and the third dim is always 1, which should be set by the blocking pass. -// For a vector of vector<32x64xf16> with reduction on dim 1, it will blocked -// into a vector<32x4x1x16> with reduction on dim 1 and dim 3. -// lowerInnerReductionWithIntraVectorShuffles performs the reduction with -// arithmetic operations on vector<16xf16>. To perform reduction on dim 1, -// simple vector arithmetic operations are issued, we will get 32 vectors of -// vector<16xf16>, each vector<16xf16> represents the partial reduction result -// of each row. To perform redcution on dim 3, it uses two vector shuffles -/// to shuffle values from two conjuction rows. For example, given -// row1 = [a0, a1, ..., a15], and row2 = [b0, b1, ..., b15]. It will shuffle -// the vector into row1' = [a0, .., a7, b0, ..., b7], -// row2' = [a8, ..., a15, b8, ..., b15], and then perform the vector arith op -// on row1' and row2', geting the result: c = [c0, ..., c7, c8, ..., c15]. -// here, c0, ..., c7 are the partial reduction results of row1 and c8, ..., c15 -// are the partial results of row2. This process will be repeated until get the -// final result, such that each element in c represents a final reduction result -// of a row. 
-llvm::SmallVector lowerInnerReductionWithIntraVectorShuffles( - mlir::ValueRange sources, llvm::ArrayRef shape, - mlir::vector::CombiningKind kind, mlir::Location loc, mlir::Type elemTy, - XeOneToNPatternRewriter &rewriter) { - - assert(shape.size() == 4 && "shape should be 4D."); - - auto isPowerOfTwo = [](auto n) { return (n & (n - 1)) == 0; }; - - // make sure the dim0 of the block is 1 in blocking pass - // different from outer reduction, this is strictly required - // for this method. - assert(shape[2] == 1 && "dim0 of the block has to be 1."); - assert(isPowerOfTwo(shape[0]) && isPowerOfTwo(shape[3]) && - "sizes of dim1 and dim4 should be power of 2."); - - auto genShuffleMasks = [&](int blkSize, int vecSize) { - llvm::SmallVector mask1; - llvm::SmallVector mask2; - auto s1 = 0, s2 = blkSize; - for (auto i = 0; i < vecSize; i++) { - if (i && i % blkSize == 0) { - s1 += blkSize; - s2 += blkSize; - } - - mask1.push_back(s1); - mask2.push_back(s2); - s1++; - s2++; - } - return std::make_pair(mask1, mask2); - }; - - // Stage 1: vector equals to a grid of ixj of vector<1xnxf16> - // after lowering to xegpu. This stage performs j-1 reduction operations on - // j dim of the grid, the result is a vector of vector. - llvm::SmallVector intermediates(shape[0]); - for (auto i = 0; i < shape[0]; i++) { - auto combiningVal = sources[i * shape[1]]; - for (auto j = 1; j < shape[1]; j++) { - combiningVal = createBinOp(kind, combiningVal, sources[i * shape[1] + j], - elemTy, loc, rewriter); - } - // cast the result of e.g., vector<1x16xf16> into vector<16xf16> - auto targetTy = mlir::VectorType::get({shape[3]}, elemTy); - combiningVal = - rewriter.create(loc, targetTy, combiningVal); - intermediates[i] = combiningVal; - } - - // Stage 2: doing intra vector reduction with shuffle Ops. - // Each vector in the result of stage 1 can be viewed as a row - // each row has e.g., 32 elements: - // v1 = [a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 ... 
a31] - // v2 = [b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 ... b31] - // ... - // vn = [p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 ... p31] - // To reduce it, we repeatedly shuffle halves of two consecutive vectors. - // One can view it as: transpose halves of two partial aggregates, reduce - // vertically, get 1 vector with reduced halves of two vectors. For example, - // for v1 and v2, we get: - // nv1 = [a0, .., a15, b0, .., b15] - // nv2 = [a16, .., a31, b16, .., b31] - // nv_reduced = reductionOp(nv1,nv2) - // such that the left half of the vector contains the partial reduction - // of v1, and the right half contains the partial reduction of v2. - // and the the number of vectors is reduced by half after one iteration. - // and we reduce the block size by half, and repeat the process until - // the block size is 1. - // The intermediate result of this stage is an array of vectors with - // type, e.g., vector, array size is `i/n`. And these vectors - // will be merged into a single vector with type vector. - - // each row should not have > 1 partial aggregate at the end - auto partialRowAggSize{shape[3]}; - auto numVecsLeft{shape[0]}; - while (partialRowAggSize != 1 && numVecsLeft != 1) { - partialRowAggSize /= 2; - auto workList = intermediates; - intermediates.clear(); - assert(workList.size() % 2 == 0 && "The size should be divisible by 2."); - auto masks = genShuffleMasks(partialRowAggSize, shape[3]); - for (size_t i = 0; i < workList.size(); i += 2) { - auto v1 = workList[i]; - auto v2 = workList[i + 1]; - auto shuffleOp1 = - rewriter.create(loc, v1, v2, masks.first); - auto shuffleOp2 = - rewriter.create(loc, v1, v2, masks.second); - auto reductionVal = - createBinOp(kind, shuffleOp1, shuffleOp2, elemTy, loc, rewriter); - intermediates.push_back(reductionVal); - } - numVecsLeft /= 2; - } - - if (partialRowAggSize > 1) { - assert(intermediates.size() == 1 && - "We must have ONE row with non-finalized aggregates."); - auto toFinalize = intermediates.back(); - intermediates.clear(); 
- uint32_t currentAggVecSize = shape[3]; - do { - currentAggVecSize /= 2; - partialRowAggSize /= 2; - auto [vecUpperMask, vecLowerMask] = - genShuffleMasks(partialRowAggSize, currentAggVecSize); - auto shuffleOp1 = rewriter.create( - loc, toFinalize, toFinalize, vecUpperMask); - auto shuffleOp2 = rewriter.create( - loc, toFinalize, toFinalize, vecLowerMask); - toFinalize = - createBinOp(kind, shuffleOp1, shuffleOp2, elemTy, loc, rewriter); - } while (partialRowAggSize != 1); - intermediates.push_back(toFinalize); - } - return intermediates; -} - -// TODO: Debug the IGC crash on this path. Currently, the upstream lows -// vector.reduction into a spirv.CL.mul operation. But the generated -// code caused a crash in IGC. -llvm::SmallVector lowerInnerReductionWithVectorReduction( - mlir::ValueRange sources, llvm::ArrayRef shape, - mlir::vector::CombiningKind kind, mlir::Location loc, mlir::Type elemTy, - XeOneToNPatternRewriter &rewriter) { - - assert(shape.size() == 4 && "shape should be 4D."); - // vector equals to a grid of ixj of vector<1xnxf16> - // this stage will use vector.shapecast to cast vector<1xnxf16> into 1D and - // use vector.reduction firstly to perform the reduction over each vector, - // and then use arith opertors to perform the reduction over the - // aforementioned results for a row. 
- llvm::SmallVector results; - for (auto i = 0; i < shape[0]; i++) { - llvm::SmallVector reductions; - // perform reduction over each vector in a row - for (auto j = 0; j < shape[1]; j++) { - auto targetTy = mlir::VectorType::get({shape[2] * shape[3]}, elemTy); - auto cast = rewriter.create( - loc, targetTy, sources[i * shape[1] + j]); - auto value = rewriter.create(loc, kind, cast); - reductions.push_back(value); - } - auto reductionVal = reductions[0]; - // perform reduction over the results of each vector in a row - for (auto j = 1; j < shape[1]; j++) { - reductionVal = - createBinOp(kind, reductionVal, reductions[j], elemTy, loc, rewriter); - } - results.push_back(reductionVal); - } - return results; -} - -class SgVectorMultiDimReductionOpPattern - : public XeOneToNConversion { - using XeOneToNConversion< - mlir::vector::MultiDimReductionOp>::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(mlir::vector::MultiDimReductionOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto srcTy = op.getSource().getType(); - auto elemTy = srcTy.getElementType(); - auto dims = op.getReductionDims(); - // its input should be a 4D vector, and has 2 reduction dims, - // otherwise run the blocking pass first. - if (dims.size() != 2 || srcTy.getRank() != 4) - return mlir::failure(); - - auto loc = op.getLoc(); - auto shape = srcTy.getShape(); - auto sources = adaptor.getSource(); - - rewriter.setInsertionPoint(op); - // doing reduction on outer dimension - if (dims[0] == 0 && dims[1] == 2) { - auto intermediates = lowerOuterReduction(sources, shape, op.getKind(), - loc, elemTy, rewriter); - { - // TODO: need a better way to represent the result (align with - // unpack/pack logic). currently we just shuffle them and cast it to the - // type/shape in xetile program. 
- auto reducedVal = packVectorsWith(intermediates, concat, loc, rewriter); - auto targetTy = mlir::VectorType::get({shape[1], shape[3]}, elemTy); - auto newOp = rewriter.create(loc, targetTy, - reducedVal); - rewriter.replaceOp(op, newOp); - } - return mlir::success(); - } - - // doing reduction on inner dimension - if (dims[0] == 1 && dims[1] == 3) { - auto intermediates = lowerInnerReductionWithIntraVectorShuffles( - sources, shape, op.getKind(), loc, elemTy, rewriter); - - { // TODO: need a better way to represent the result (align with - // unpack/pack logic). - // currently we just shuffle them and cast it to the type/shape in - // xetile program. - auto reductionVal = - packVectorsWith(intermediates, concat, loc, rewriter); - auto targetTy = mlir::VectorType::get({shape[0], shape[2]}, elemTy); - auto newOp = rewriter.create(loc, targetTy, - reductionVal); - rewriter.replaceOp(op, newOp); - } - return mlir::success(); - } - - // something is wrong - return op.emitError("unsupported reduction operation."); - } -}; - -class SgArithConstantOpPattern - : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(mlir::arith::ConstantOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto value = llvm::dyn_cast(op.getValue()); - - // We only interesting 4D vectors - if (!value || value.getType().getRank() != 4) - return mlir::failure(); - - llvm::SmallVector elems( - value.value_begin(), - value.value_end()); - - auto shape = value.getType().getShape(); - auto elemTy = value.getElementType(); - auto vecTy = mlir::VectorType::get({shape[2], shape[3]}, elemTy); - - // slice a block of (shape[2], shape[3]) from elems. 
- auto slice = [&](int i, int j) { - llvm::SmallVector block; - auto width = shape[1] * shape[3]; - i = i * shape[2]; - j = j * shape[3]; - for (int64_t r = 0; r < shape[2]; r++) - for (int64_t c = 0; c < shape[3]; c++) - block.push_back(elems[(i + r) * width + j + c]); - return block; - }; - - rewriter.setInsertionPoint(op); - llvm::SmallVector newOps; - for (auto i = 0; i < shape[0]; i++) { - for (auto j = 0; j < shape[1]; j++) { - auto values = slice(i, j); - auto attr = mlir::DenseElementsAttr::get(vecTy, values); - auto newOp = rewriter.create(loc, attr); - newOps.push_back(newOp); - } - } - - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -bool isLegalArithOp(mlir::Operation *op) { - if (llvm::isa(op)) { - auto constOp = llvm::cast(op); - auto resultTy = constOp.getResult().getType(); - if (mlir::isa(resultTy) && - mlir::cast(resultTy).getRank() == 4) - return false; - } - return true; -} - -void populateArithOpConversionPatterns(imex::XeOneToNTypeConverter &converter, - mlir::RewritePatternSet &patterns, - TileUsageAnalysis &analysis) { - patterns.add( - patterns.getContext(), converter, analysis); -} - -} // namespace imex diff --git a/lib/Conversion/XeTileToXeGPU/ArithOpConversion.h b/lib/Conversion/XeTileToXeGPU/ArithOpConversion.h deleted file mode 100644 index acdfa424f..000000000 --- a/lib/Conversion/XeTileToXeGPU/ArithOpConversion.h +++ /dev/null @@ -1,29 +0,0 @@ -//===- ArithOpConversion.h - XeTileToXeGPU conversion -------*- C++ -*-===// -// -// Copyright 2022 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines the ArithOpConversionPattern, used in XeTileToXeGPU -/// conversion, converting the Arith Ops. 
-/// -//===----------------------------------------------------------------------===// -#ifndef _ArithOpConversion_H_INCLUDED_ -#define _ArithOpConversion_H_INCLUDED_ - -#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h" -#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h" - -namespace imex { -bool isLegalArithOp(mlir::Operation *op); - -void populateArithOpConversionPatterns(imex::XeOneToNTypeConverter &converter, - mlir::RewritePatternSet &patterns, - TileUsageAnalysis &analysis); - -} // namespace imex -#endif diff --git a/lib/Conversion/XeTileToXeGPU/CMakeLists.txt b/lib/Conversion/XeTileToXeGPU/CMakeLists.txt index 7920c9a1a..8712d39a1 100644 --- a/lib/Conversion/XeTileToXeGPU/CMakeLists.txt +++ b/lib/Conversion/XeTileToXeGPU/CMakeLists.txt @@ -1,9 +1,5 @@ add_imex_conversion_library(IMEXXeTileToXeGPU - ArithOpConversion.cpp - SCFOpConversion.cpp XeTileToXeGPU.cpp - XeTileOpConversion.cpp - XeTileToXeGPUConversion.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/imex/Conversion/XeTileToXeGPU diff --git a/lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp deleted file mode 100644 index f9f628f15..000000000 --- a/lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp +++ /dev/null @@ -1,124 +0,0 @@ -//===- SCFOpConversion.cpp - XeTileToXeGPU conversion -------*- C++ -*-===// -// -// Copyright 2022 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the Conversion Patter for SCFOps. 
-/// -//===----------------------------------------------------------------------===// - -#include - -#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h" - -namespace imex { - -struct SgSCFForOpBlockPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor, - imex::XeOneToNPatternRewriter &rewriter) const override { - // OpAdaptor is defined with ValueRange, so it contains results after - // One-to-N mapping - llvm::SmallVector convertedArgs; - for (auto &values : adaptor.getInitArgs()) - convertedArgs.append(values.begin(), values.end()); - - auto newOp = rewriter.create( - op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), - convertedArgs); - - // compute the type mapping (from origial to convereted) between - // orginal args and converted args. The standard typeconverter - // way doesnot work because array_length value is set in-the-fly based on - // whether tile is create for load or not, thus a TileType could be - // lowered into many different types of TensorDescType (due to different - // setting of array_length). But typeconverter has no knowledge about when - // to use array_lenght and when not. 
- auto typeConverter = getTypeConverter(); - auto argTys = op.getRegion().getArgumentTypes(); - mlir::OneToNTypeMapping argumentMapping(argTys); // vectorty - llvm::ArrayRef args(op.getRegion().getArguments().begin(), - op.getRegion().getArguments().end()); - llvm::ArrayRef newArgs( - newOp.getRegion().getArguments().begin(), - newOp.getRegion().getArguments().end()); - auto status = - typeConverter.computeTypeMapping(args, newArgs, argumentMapping); - if (mlir::failed(status)) { - llvm_unreachable("It is an unexpected failure of computing " - "type mapping for SCF::ForOp arguments."); - } - - // apply the signature convertion for SCFFor body arguments, an - // UnrealizedConversionCastOp will be inserted by typeConverter - rewriter.applySignatureConversion(&op.getRegion().getBlocks().front(), - argumentMapping); - - if (newOp.getBody()) - rewriter.eraseBlock(newOp.getBody()); - rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), - newOp.getRegion().end()); - - rewriter.replaceOp(op, newOp.getResults()); - return mlir::success(); - } -}; - -struct SgSCFYieldOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor, - imex::XeOneToNPatternRewriter &rewriter) const override { - llvm::SmallVector convertedResults; - for (auto &values : adaptor.getResults()) - convertedResults.append(values.begin(), values.end()); - - auto newOp = - rewriter.create(op.getLoc(), convertedResults); - - rewriter.replaceOp(op, newOp); - return mlir::success(); - } -}; - -bool isLegalSCFOp(mlir::Operation *op) { - bool result = true; - if (llvm::isa(op)) { - auto forOp = llvm::cast(op); - for (const auto &arg : forOp.getInitArgs()) { - auto type = arg.getType(); - result &= !mlir::isa(type); - - if (mlir::isa(type)) - result &= (mlir::cast(type).getRank() != 4); - } - } - - if (llvm::isa(op)) { - auto yieldOp = llvm::cast(op); - for (const auto &arg : 
yieldOp.getResults()) { - auto type = arg.getType(); - result &= !mlir::isa(type); - if (mlir::isa(type)) - result &= (mlir::cast(type).getRank() != 4); - } - } - return result; -} - -void populateSCFOpConversionPatterns(imex::XeOneToNTypeConverter &converter, - mlir::RewritePatternSet &patterns, - TileUsageAnalysis &analysis) { - patterns.add( - patterns.getContext(), converter, analysis); -} - -} // namespace imex diff --git a/lib/Conversion/XeTileToXeGPU/SCFOpConversion.h b/lib/Conversion/XeTileToXeGPU/SCFOpConversion.h deleted file mode 100644 index e540e7d62..000000000 --- a/lib/Conversion/XeTileToXeGPU/SCFOpConversion.h +++ /dev/null @@ -1,29 +0,0 @@ -//===- ArithOpConversion.h - XeTileToXeGPU conversion -------*- C++ -*-===// -// -// Copyright 2022 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines the ArithOpConversionPattern, used in XeTileToXeGPU -/// conversion, converting the Arith Ops. 
-/// -//===----------------------------------------------------------------------===// -#ifndef _SCFOpConversion_H_INCLUDED_ -#define _SCFOpConversion_H_INCLUDED_ - -#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h" -#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h" - -namespace imex { -bool isLegalSCFOp(mlir::Operation *op); - -void populateSCFOpConversionPatterns(imex::XeOneToNTypeConverter &converter, - mlir::RewritePatternSet &patterns, - TileUsageAnalysis &analysis); - -} // namespace imex -#endif diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp deleted file mode 100644 index 80d6b1fcd..000000000 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ /dev/null @@ -1,1226 +0,0 @@ -//===- XeTileOpConversion.h - XeTileToXeGPU conversion -------*- C++ -*-===// -// -// Copyright 2022 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements ConversionPatterns for XeTileOps, used in XeTileToXeGPU -/// conversion, converting the XeTile dialect to the XeGPU dialect. 
-/// -//===----------------------------------------------------------------------===// - -#include "XeTileOpConversion.h" -#include "imex/Utils/XeArch.h" -#include "imex/Utils/XeCommon.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "llvm/ADT/SmallVector.h" -#include -#include -#include -#include -#include -#include -#include - -namespace imex { -using namespace mlir; -using mlir::vector::CreateMaskOp; -using mlir::vector::ExtractOp; -using mlir::vector::ExtractStridedSliceOp; -using mlir::vector::ShapeCastOp; -using mlir::vector::ShuffleOp; -using mlir::vector::SplatOp; - -// Check that lowerUnpackOrPack will be able to evenly combine/split the input -// grid into the output grid. -static bool isUnpackPackCompatible(xetile::TileUnpackOp unpackOp, - xetile::TilePackOp packOp) { - auto inTy = unpackOp.getInVec().getType(); - auto inGrids = inTy.getShape().take_front(2); - auto inBlkSizes = unpackOp.getInnerBlocksAttr(); - - auto outTy = packOp.getOutVec().getType(); - llvm::ArrayRef outGrids = outTy.getShape().take_front(2); - mlir::DenseI64ArrayAttr outBlkSizes = packOp.getInnerBlocksAttr(); - - if (inBlkSizes[0] < outBlkSizes[0] && inGrids[0] % outGrids[0] != 0) - return false; - if (inBlkSizes[0] > outBlkSizes[0] && outGrids[0] % inGrids[0] != 0) - return false; - if (inBlkSizes[1] < outBlkSizes[1] && inGrids[1] % outGrids[1] != 0) - return false; - if (inBlkSizes[1] > outBlkSizes[1] && outGrids[1] % inGrids[1] != 0) - return false; - - return true; -} - -// a unified function lowering Unpack and Pack ops. 
-static llvm::SmallVector -lowerUnpackOrPack(mlir::PatternRewriter &rewriter, mlir::Location loc, - mlir::ValueRange inputs, mlir::DenseI64ArrayAttr inBlkSizes, - mlir::DenseI64ArrayAttr outBlkSizes, - llvm::ArrayRef inGrids, - llvm::ArrayRef outGrids) { - - // handle based on the dim0, and save results into intermediates - llvm::SmallVector intermediates(outGrids[0] * inGrids[1]); - if (inBlkSizes[0] == outBlkSizes[0]) { // do nothing - intermediates = inputs; - } else if (inBlkSizes[0] < outBlkSizes[0]) { // stack on dim 0 - // `nums` small vectors will be stacked into one big vector - auto nums = inGrids[0] / outGrids[0]; - llvm::SmallVector valSet; - for (auto j = 0; j < inGrids[1]; j++) { - for (auto i = 0; i < inGrids[0]; i++) { - auto idx = i * inGrids[1] + j; - valSet.push_back(inputs[idx]); - if (valSet.size() == static_cast(nums)) { - auto newOp = packVectorsWith(valSet, stack, loc, rewriter); - intermediates[i / nums * inGrids[1] + j] = newOp; - valSet.clear(); - } - } - } - } else { - // do extract on dim0 using vector::ExtractStridedSliceOp - // intermediates.resize(outGrids[0] * inGrids[1]); - llvm::SmallVector blkSizes({outBlkSizes[0], inBlkSizes[1]}); - - // each vector will be horizonally cut into `nums` subvectors - auto nums = outGrids[0] / inGrids[0]; - llvm::SmallVector strides({1, 1}); - for (auto i = 0; i < inGrids[0]; i++) { - for (auto j = 0; j < inGrids[1]; j++) { - auto startPos = i * nums * inGrids[1] + j; - auto v = inputs[i * inGrids[1] + j]; - for (auto k = 0; k < nums; k++) { - llvm::SmallVector offsets({k * blkSizes[0], 0}); - auto newOp = rewriter.create( - loc, v, offsets, blkSizes, strides); - auto idx = startPos + k * inGrids[1]; - intermediates[idx] = newOp; - } - } - } - } - - // handle intermediates based on the dim1, and save results into newOps - llvm::SmallVector newOps; - llvm::SmallVector interGrids = {outGrids[0], inGrids[1]}; - if (inBlkSizes[1] == outBlkSizes[1]) { - // do nothing since they have the same size - 
newOps = intermediates; - } else if (inBlkSizes[1] < outBlkSizes[1]) { - // doing concat since blkSZ of input vector is smaller - // `nums` of small vectors will be concated into a big one - size_t nums = inGrids[1] / outGrids[1]; - llvm::SmallVector valSet; - for (auto i = 0; i < interGrids[0]; i++) { - for (auto j = 0; j < interGrids[1]; j++) { - valSet.push_back(intermediates[i * interGrids[1] + j]); - if (valSet.size() == nums) { - auto newOp = packVectorsWith(valSet, concat, loc, rewriter); - newOps.push_back(newOp); - valSet.clear(); - } - } - } - } else { // doing extract on dim 1 - llvm::SmallVector blkSizes({outBlkSizes[0], outBlkSizes[1]}); - llvm::SmallVector strides({1, 1}); - auto nums = outGrids[1] / interGrids[1]; - for (auto i = 0; i < interGrids[0]; i++) { - for (auto j = 0; j < interGrids[1]; j++) { - auto v = intermediates[i * interGrids[1] + j]; - for (int64_t k = 0; k < nums; k++) { - llvm::SmallVector offsets({0, k * blkSizes[1]}); - auto newOp = rewriter.create( - loc, v, offsets, blkSizes, strides); - newOps.push_back(newOp); - } - } - } - } - - return newOps; -} - -// It lowers a pair of Unpack and Pack operators at a time. -// the pattern first matchs TileUnpackOp, and finds its TilePackOp -// user. It can avoid some vector shuffle and extract ops by -// looking at the target block size (innerBlock from TilePackOp) -// directly. It requires 1-1 mapping of UnpackOp and PackOp, which -// should be enforced by a separate pass. 
-class SgTileUnpackOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(xetile::TileUnpackOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - - auto inputs = adaptor.getInVec(); - auto inTy = op.getInVec().getType(); - auto inGrids = inTy.getShape().take_front(2); - auto inBlkSizes = op.getInnerBlocksAttr(); - - // the default grids used as outGrids when unpack is not paired with a pack - int64_t defautlOutGrids[2] = {1, 1}; - llvm::ArrayRef outGrids; - mlir::DenseI64ArrayAttr outBlkSizes; - auto packOp = llvm::dyn_cast(*(op->user_begin())); - if (op->hasOneUse() && packOp && isUnpackPackCompatible(op, packOp)) { - // lower the Unpack and Pack pair - auto outTy = packOp.getOutVec().getType(); - outGrids = outTy.getShape().take_front(2); - outBlkSizes = packOp.getInnerBlocksAttr(); - } else { // lower the Unpack only - auto outTy = op.getOutVec().getType(); - outGrids = llvm::ArrayRef(defautlOutGrids, 2); - auto ctx = op.getContext(); - outBlkSizes = mlir::DenseI64ArrayAttr::get(ctx, outTy.getShape()); - } - - rewriter.setInsertionPoint(op); - auto newOps = lowerUnpackOrPack(rewriter, op->getLoc(), inputs, inBlkSizes, - outBlkSizes, inGrids, outGrids); - - if (op->hasOneUse() && packOp && isUnpackPackCompatible(op, packOp)) { - // lowered Unpack and Pack as pair - rewriter.replaceOp(packOp, newOps); - rewriter.eraseOp(op); - } else { // lowering unpack only - rewriter.replaceOp(op, newOps); - } - return mlir::success(); - } -}; - -class SgTilePackOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(xetile::TilePackOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto input = op.getInVec(); - auto defOp = input.getDefiningOp(); - // Unpack and Pack appeared as a pair, it should be handled - // by UnpackOpPattern in this case. 
- if (defOp && defOp->hasOneUse() && isUnpackPackCompatible(defOp, op)) - return mlir::failure(); - - auto inTy = op.getInVec().getType(); - auto inGrids = llvm::SmallVector({1, 1}); - auto inBlkSizes = - mlir::DenseI64ArrayAttr::get(op.getContext(), inTy.getShape()); - - auto outTy = op.getOutVec().getType(); - auto outGrids = outTy.getShape().take_front(2); - auto outBlkSizes = op.getInnerBlocksAttr(); - - rewriter.setInsertionPoint(op); - auto newOps = lowerUnpackOrPack(rewriter, op->getLoc(), {input}, inBlkSizes, - outBlkSizes, inGrids, outGrids); - - // it is simple one-to-one mapping - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -// A helper to compute the right array length given the inner block width, -// and the tile width, as well as the element type. Both inner block width -// and tile width are in number of elements. It is computed based on hardware -// constraints (on PVC): array_length * inner_block_width * sizeof(elemTy) <= -// 256 bits. So, if tile width is larger than 256/sizeof(elemTy), the maximum -// supported array_length will be used. -// When array_length > 1 is specified, sub-GRF sized blocks are loaded into -// separate GRFs. We do not handle that yet, and we may not really "want" to: -// We would waste GRFs. If multiple blocks (e.g., <1x16xf16, array_length=2>) -// fit into one GRF, let them. -int getBlockArrayLength(mlir::Operation *op, mlir::Type elemTy, int innerHeight, - int inner_block_width, int tile_width) { - auto uArch = std::make_shared(); - auto elemBits = elemTy.getIntOrFloatBitWidth(); - auto params = uArch->get2DLoadConfig(op, elemBits, false, false); - assert(mlir::succeeded(params) && "Invalid Config Params"); - // Do not let an inner block get array_length'ed to blocks finer than one GRF. 
- if (innerHeight * inner_block_width * elemBits <= - uArch->getOneGRFSizeBits()) { - return 1; - } - llvm::SmallVector supportedArrLen = params->array_length; - const int maxBlockWidth = std::min(params->restriction, tile_width); - - int result = 1; - for (auto len : supportedArrLen) { - if (len * inner_block_width <= maxBlockWidth) - result = len; - } - return result; -} - -// It rewrites a XeTile::init_tile into one or more mlir::xegpu::create_nd_desc -// It is one of start points of generating 1:N values. -class SgInitTileOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(xetile::InitTileOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - mlir::Value source = op.getSource(); - auto tileTy = op.getType(); - auto innerBlocks = tileTy.getInnerBlocks(); - auto shape = llvm::to_vector(tileTy.getShape()); - auto indexType = rewriter.getIndexType(); - - auto MemorySpace = op.getSourceMemorySpaceAsInt() == 3 - ? mlir::xegpu::MemorySpace::SLM - : mlir::xegpu::MemorySpace::Global; - - if (tileTy.getRank() != 2) - return op.emitOpError("The tile shape should be 2D."); - - if (!innerBlocks || innerBlocks.size() != 2) - return op.emitOpError("Missing valid innerBlock for the tile in op."); - - // Need to make a copy, so we can swap values. - auto innerBlk = llvm::to_vector(innerBlocks.asArrayRef()); - - // using array_length for load if dim1 of innerBlocks is smaller than - // dim1 of shape. - auto elemTy = tileTy.getElementType(); - - llvm::SmallVector xegpuOps; - // scattered tiles are lowered into create_tdesc ops with chunk_size = 1. - if (tileTy.getScatterAttr() == mlir::BoolAttr::get(op.getContext(), true)) { - llvm::SmallVector grids( - {shape[0] / innerBlk[0], shape[1] / innerBlk[1]}); - auto elems = innerBlk[0] * innerBlk[1]; - // TODO: get this from uArch. 32 is the max number of SIMD lanes. 
- assert(elems <= 32 && "Scattered tile size should be <= 32"); - mlir::xegpu::SGMapAttr sgMap = nullptr; - if (auto attr = tileTy.getSgMap()) { - llvm::SmallVector layout( - attr.getWiLayout().asArrayRef().begin(), - attr.getWiLayout().asArrayRef().end()); - llvm::SmallVector data(attr.getWiData().asArrayRef().begin(), - attr.getWiData().asArrayRef().end()); - sgMap = mlir::xegpu::SGMapAttr::get(op.getContext(), layout, data); - } - auto tdescTy = mlir::xegpu::TensorDescType::get( - elems, elemTy, 1 /* chunk_size */, MemorySpace, sgMap); - auto indiceTy = mlir::VectorType::get(elems, indexType); - auto indices = adaptor.getIndices(); - for (int64_t i = 0; i < grids[0]; i++) { - for (int64_t j = 0; j < grids[1]; j++) { - auto indice = indices[i * grids[1] + j]; - indice = rewriter.create(loc, indiceTy, indice); - auto createOp = rewriter.create( - loc, tdescTy, source, indice); - xegpuOps.push_back(createOp); - } - } - } else { - auto array_length = isForLoad(op) && shape[1] > innerBlk[1] - ? getBlockArrayLength(op, elemTy, innerBlk[0], - innerBlk[1], shape[1]) - : 1; - auto width = array_length * innerBlk[1]; - - llvm::SmallVector blocks( - {shape[0] / innerBlk[0], shape[1] / width}); - - llvm::SmallVector offsets; - auto staticOffsets = op.getStaticOffsets(); - auto dynamicOffsets = op.getOffsets(); - for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) { - if (mlir::ShapedType::isDynamic(staticOffsets[i])) { - offsets.push_back(dynamicOffsets[j++]); - } else { - offsets.push_back(rewriter.create( - op.getLoc(), rewriter.getIndexAttr(staticOffsets[i]))); - } - } - - // For col-major memref initial offsets need to be swapped. 
- auto offsetsY = offsets.pop_back_val(); - auto offsetsX = offsets.pop_back_val(); - - auto tDescTy = mlir::xegpu::TensorDescType::get( - innerBlk, elemTy, array_length, true /*boundary_check*/, MemorySpace); - - auto createIndexConstant = [&](mlir::Type type, int64_t value) { - auto attr = rewriter.getIndexAttr(value); - return rewriter.create(loc, type, attr); - }; - - rewriter.setInsertionPoint(op); - xegpuOps.resize(blocks[0] * blocks[1]); - for (int i = 0; i < blocks[0]; i++) { - for (int j = 0; j < blocks[1]; j++) { - auto subOffX = createIndexConstant(indexType, (innerBlk[0] * i)); - auto subOffY = createIndexConstant(indexType, (width * j)); - auto tDescOffsetX = rewriter.createOrFold( - loc, subOffX, offsetsX); - auto tDescOffsetY = rewriter.createOrFold( - loc, subOffY, offsetsY); - mlir::SmallVector tDescOffsets = - llvm::to_vector<4>(llvm::map_range( - offsets, - [](mlir::Value v) -> mlir::OpFoldResult { return v; })); - tDescOffsets.push_back(tDescOffsetX); - tDescOffsets.push_back(tDescOffsetY); - - // Handle memref source. - if (auto MemRefTypedSource = - mlir::dyn_cast>(source)) { - // Hnadle the case where the shape is static. - if (MemRefTypedSource.getType().hasStaticShape()) { - auto createNdOp = rewriter.create( - op.getLoc(), tDescTy /*resultTy*/, - MemRefTypedSource - /*source*/, - tDescOffsets /*offsets*/); - - xegpuOps[blocks[1] * i + j] = createNdOp; - } else { - // Handle the case where the shape is dynamic. - auto createNdOp = rewriter.create( - loc, tDescTy, MemRefTypedSource, tDescOffsets, - op.getMixedSizes(), op.getMixedStrides()); - xegpuOps[blocks[1] * i + j] = createNdOp; - } - } else if (auto intSourceType = - mlir::dyn_cast>( - source)) { - // Handle the case where the source is an integer. 
- auto createNdOp = rewriter.create( - loc, tDescTy, intSourceType, tDescOffsets, op.getMixedSizes(), - op.getMixedStrides()); - xegpuOps[blocks[1] * i + j] = createNdOp; - } else { - return mlir::failure(); - } - } - } - } - - rewriter.replaceOp(op, xegpuOps); - return mlir::success(); - } -}; - -static mlir::xegpu::CachePolicy -translateCachePolicy(imex::xetile::CachePolicyAttr val, - mlir::xegpu::CachePolicy defaultVal) { - if (!val) - return defaultVal; - - switch (val.getValue()) { - case imex::xetile::CachePolicy::CACHED: - return mlir::xegpu::CachePolicy::CACHED; - case imex::xetile::CachePolicy::UNCACHED: - return mlir::xegpu::CachePolicy::UNCACHED; - case imex::xetile::CachePolicy::STREAMING: - return mlir::xegpu::CachePolicy::STREAMING; - case imex::xetile::CachePolicy::READ_INVALIDATE: - return mlir::xegpu::CachePolicy::READ_INVALIDATE; - case imex::xetile::CachePolicy::WRITE_BACK: - return mlir::xegpu::CachePolicy::WRITE_BACK; - case imex::xetile::CachePolicy::WRITE_THROUGH: - return mlir::xegpu::CachePolicy::WRITE_THROUGH; - } - llvm_unreachable("Invalid CachePolicy value"); -} - -template -static auto getCachePolicy(OpTy op, mlir::xegpu::CachePolicy defaultVal = - mlir::xegpu::CachePolicy::CACHED) { - - auto getCachePolicyAttr = [&](imex::xetile::CachePolicyAttr val) { - return mlir::xegpu::CachePolicyAttr::get( - op.getContext(), translateCachePolicy(val, defaultVal)); - }; - - auto L1 = getCachePolicyAttr(op.getL1HintAttr()); - auto L2 = getCachePolicyAttr(op.getL2HintAttr()); - auto L3 = getCachePolicyAttr(op.getL3HintAttr()); - - return std::make_tuple(L1, L2, L3); -} - -// It lowers a XeTile::prefetch_tile into one or more mlir::xegpu::prefetch_2d. -// The adaptor will provide the set of xegpu.create_nd_desc lowered for -// its input tile. 
-struct SgPrefetchTileOpPattern - : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - ::mlir::LogicalResult - matchAndRewrite(xetile::PrefetchTileOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto tileTy = op.getTile().getType(); - auto tiles = adaptor.getTile(); - auto innerBlocks = tileTy.getInnerBlocks(); - - if (tileTy.getRank() != 2) - return mlir::failure(); - - if (!innerBlocks || innerBlocks.size() != 2) - return mlir::failure(); - - auto shape = tileTy.getShape(); - auto expectedNumTensorDescs = - (shape[0] / innerBlocks[0]) * (shape[1] / innerBlocks[1]); - if (expectedNumTensorDescs != static_cast(tiles.size())) { - op.emitOpError("Failed to lower LoadTileOp because shape[0] * shape[1] " - "!= sources.size()."); - return mlir::failure(); - } - - auto [L1, L2, L3] = getCachePolicy(op); - - for (auto tile : tiles) { - rewriter.create(op.getLoc(), tile, L1, L2, L3); - } - - rewriter.eraseOp(op); - - return mlir::success(); - } -}; - -// It lowers XeTile::load_tile into one or more mlir::xegpu::load_2d -// The adaptor will provide the set of xegpu.create_nd_desc lowered for -// its input tile. -struct SgLoadTileOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(xetile::LoadTileOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto tileTy = op.getSource().getType(); - auto blockSZ = tileTy.getInnerBlocks(); - - // It expects the tile has been tiled using blocking pass - if (!blockSZ) - return mlir::failure(); - - auto elemTy = tileTy.getElementType(); - auto sources = adaptor.getSource(); - - auto [L1, L2, L3] = getCachePolicy(op); - - // The tile is in col-major order, which should be canonicalized to - // row-major in canonicalization pass. 
- auto srcOrder = tileTy.getOrder(); - if (srcOrder.asArrayRef() != mlir::ArrayRef({1, 0})) - return mlir::failure(); - - rewriter.setInsertionPoint(op); - llvm::SmallVector<::mlir::Value> xegpuOps; - for (auto src : sources) { - auto tdescTy = llvm::dyn_cast(src.getType()); - assert(tdescTy && "Expecting a TensorDescType value for load_tile."); - auto shape = tdescTy.getShape().vec(); - auto array_length = tdescTy.getArrayLength(); - - if (array_length != 1) - shape.insert(shape.begin(), array_length); - - auto vectorTy = mlir::VectorType::get(shape, elemTy); - auto ldOp = rewriter.create( - op.getLoc(), vectorTy, src, nullptr, nullptr, nullptr, L1, L2, L3); - if (array_length == 1) { - xegpuOps.push_back(ldOp); - } else { - for (auto i = 0; i < array_length; i++) { - auto extractOp = rewriter.create(op.getLoc(), ldOp, i); - xegpuOps.push_back(extractOp); - } - } - } - - rewriter.replaceOp(op, xegpuOps); - return mlir::success(); - } -}; - -// It lowers XeTile::load into one ore more mlir::xegpu::load with chunk_size=1. -// since xetile::load typically works on 2D representation of the tile, while -// mlir::xegpu::load works on 1D representation, shapecast is used to convert -// vector type operands to 1D representation. 
-struct SgLoadGatherOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(xetile::LoadGatherOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto tiles = adaptor.getTile(); - auto masks = adaptor.getMask(); - auto tileTy = op.getTile().getType(); - auto innerBlk = tileTy.getInnerBlocks(); - auto resTy = - mlir::VectorType::get(innerBlk.asArrayRef(), tileTy.getElementType()); - auto vecTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1], - tileTy.getElementType()); - auto maskTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1], - rewriter.getIntegerType(1)); - llvm::SmallVector xegpuOps; - auto transposeAttr = mlir::UnitAttr(); - auto [L1, L2, L3] = getCachePolicy(op); - for (auto [t, m] : llvm::zip(tiles, masks)) { - m = rewriter.create(op.getLoc(), maskTy, m); - auto ldOp = rewriter.create( - op.getLoc(), vecTy, t, m, transposeAttr, L1, L2, L3); - auto v = rewriter.create(op.getLoc(), resTy, ldOp); - xegpuOps.push_back(v); - } - rewriter.replaceOp(op, xegpuOps); - return mlir::success(); - } -}; - -// It lowers a XeTile::store_tile into one or more mlir::xegpu::store_2d -// The adaptor will provide the set of xegpu.create_nd_desc lowered for -// its input tile, and similar to its input vector value. -struct SgStoreTileOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - ::mlir::LogicalResult - matchAndRewrite(xetile::StoreTileOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto tiles = adaptor.getTile(); - auto values = adaptor.getValue(); - - if (tiles.size() != values.size()) { - return op.emitOpError("[Failed to lower the StoreOp]") - << "tile and value size doesn't match." 
- << "tiles: " << tiles.size() << ", " - << "values: " << values.size() << "\n"; - } - - auto [L1, L2, L3] = - getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK); - - for (size_t i = 0; i < tiles.size(); i++) - rewriter.create(op.getLoc(), values[i], tiles[i], - L1, L2, L3); - - rewriter.eraseOp(op); - return ::mlir::success(); - } -}; - -// It lowers XeTile::store into one ore more mlir::xegpu::store with -// chunk_size=1. Similar to xetile::load, shapecast is used to convert vector -// type operands to 1D representation. -struct SgStoreScatterOpPattern - : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(xetile::StoreScatterOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto values = adaptor.getValue(); - auto tdescs = adaptor.getTile(); - auto masks = adaptor.getMask(); - - auto tileTy = op.getTile().getType(); - auto innerBlk = tileTy.getInnerBlocks(); - auto vecTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1], - tileTy.getElementType()); - auto maskTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1], - rewriter.getIntegerType(1)); - auto transposeAttr = mlir::UnitAttr(); - auto [L1, L2, L3] = - getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK); - for (auto [v, t, m] : llvm::zip(values, tdescs, masks)) { - m = rewriter.create(op.getLoc(), maskTy, m); - v = rewriter.create(op.getLoc(), vecTy, v); - rewriter.create(op.getLoc(), v, t, m, - transposeAttr, L1, L2, L3); - } - rewriter.eraseOp(op); - return mlir::success(); - } -}; - -// It lowers a XeTile::tile_mma into one or more mlir::xegpu::dpas -// The adaptor provides new inputs for each old input. 
-struct SgTileMMAOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - ::mlir::LogicalResult - matchAndRewrite(xetile::TileMMAOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - - auto aShape = op.getAType().getShape(); - auto bShape = op.getBType().getShape(); - - if (aShape.size() != 4 || bShape.size() != 4) { - op.emitOpError() << "Operand A and B for mma should be 4d.\n"; - return mlir::failure(); - } - - if (aShape[3] != bShape[2] || aShape[1] != bShape[0]) { - op.emitOpError() << "A and B size doesn't match. A should be m x k, and " - "B should be k x n"; - return mlir::failure(); - } - - uint64_t M = aShape[0]; - uint64_t K = aShape[1]; - uint64_t N = bShape[1]; - - auto loc = op.getLoc(); - auto AValues = adaptor.getA(); - auto BValues = adaptor.getB(); - auto CValues = adaptor.getC(); - - auto elemTy = op.getOutput().getType().getElementType(); - auto subCTy = mlir::VectorType::get({aShape[2], bShape[3]}, elemTy); - - mlir::SmallVector xegpuOps; - for (uint64_t i = 0; i < M; i++) { - for (uint64_t j = 0; j < N; j++) { - mlir::Value tmpC; - if (op.getC()) - tmpC = CValues[i * N + j]; // init with acc - for (uint64_t k = 0; k < K; k++) { - auto aVec = AValues[i * K + k]; - auto bVec = BValues[k * N + j]; - tmpC = rewriter.create( - loc, subCTy /*result*/, aVec /*lhs*/, bVec /*rhs*/, tmpC /*acc*/); - } - xegpuOps.push_back(tmpC); - } - } - rewriter.replaceOp(op, xegpuOps); - return mlir::success(); - } -}; - -struct SgUpdateTileOffsetOpPattern - : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(xetile::UpdateTileOffsetOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto tileTy = op.getTile().getType(); - auto tdescs = adaptor.getTile(); - llvm::SmallVector newOps; - if (tileTy.getScatterAttr() == mlir::BoolAttr::get(op.getContext(), true)) { - auto indices = adaptor.getIndices(); - for 
(auto [tdesc, idx] : llvm::zip_equal(tdescs, indices)) { - auto type = mlir::cast(idx.getType()); - auto flatTy = - mlir::VectorType::get(type.getNumElements(), type.getElementType()); - idx = rewriter.create(op.getLoc(), flatTy, idx); - auto xegpuTile = rewriter.create( - op.getLoc(), tdesc.getType(), tdesc, idx); - newOps.push_back(xegpuTile); - } - } else { - auto offsetX = op.getOffsetX(); - auto offsetY = op.getOffsetY(); - int64_t kDynamics[2] = {mlir::ShapedType::kDynamic, - mlir::ShapedType::kDynamic}; - for (const auto &tdesc : tdescs) { - // if the traversal is col-major, we need to reverse the offsets at - // XeGPU level because only row-major traversal is supported. - auto xegpuTile = rewriter.create( - op.getLoc(), tdesc.getType(), tdesc, - mlir::ValueRange({offsetX, offsetY}), - llvm::ArrayRef(kDynamics, 2)); - newOps.push_back(xegpuTile); - } - } - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -extern llvm::SmallVector -lowerOuterReduction(mlir::ValueRange sources, llvm::ArrayRef shape, - mlir::vector::CombiningKind kind, mlir::Location loc, - mlir::Type elemTy, XeOneToNPatternRewriter &rewriter); - -extern llvm::SmallVector -lowerInnerReductionWithIntraVectorShuffles(mlir::ValueRange sources, - llvm::ArrayRef shape, - mlir::vector::CombiningKind kind, - mlir::Location loc, - mlir::Type elemTy, - XeOneToNPatternRewriter &rewriter); - -extern llvm::SmallVector lowerInnerReductionWithVectorReduction( - mlir::ValueRange sources, llvm::ArrayRef shape, - mlir::vector::CombiningKind kind, mlir::Location loc, mlir::Type elemTy, - XeOneToNPatternRewriter &rewriter); - -struct SgTileReductionOpPattern - : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(xetile::ReductionOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto srcTy = op.getSource().getType(); - auto elemTy = srcTy.getElementType(); - auto dims = op.getReductionDims(); - // 
its input should be a 4D vector, and has 2 reduction dims, - // otherwise run the blocking pass first. - if (dims.size() != 2 || srcTy.getRank() != 4) - return mlir::failure(); - - auto loc = op.getLoc(); - auto shape = srcTy.getShape(); - auto sources = adaptor.getSource(); - - rewriter.setInsertionPoint(op); - // doing reduction on outer dimension - if (dims[0] == 0 && dims[1] == 2) { - auto intermediates = lowerOuterReduction(sources, shape, op.getKind(), - loc, elemTy, rewriter); - rewriter.replaceOp(op, intermediates); - return mlir::success(); - } - - // doing reduction on inner dimension, otherwise it is not supported. - assert(dims[0] == 1 && dims[1] == 3 && "unsupported reduction operation."); - - auto intermediates = lowerInnerReductionWithIntraVectorShuffles( - sources, shape, op.getKind(), loc, elemTy, rewriter); - llvm::SmallVector newOps; - { - // intermediate is a vector of values with type of vector - // (where n is max of min(shape[0]/2,16) and 1), - // each element is the reduced value for a row. For example, - // for vector<32x4x1x16> with reduction on dim 1 and dim 3. the - // intermediate values will be two values of vector<16xf16>. The values - // in the first vector represents the reduction result of the first 16 - // rows. Here we will extract each value and splat it to a vector<1x1xf16> - // as results to their consumers. 
- for (auto v : intermediates) { - auto targetTy = mlir::VectorType::get({1, 1}, elemTy); - auto vecTy = mlir::dyn_cast(v.getType()); - assert(vecTy && "expect vector type"); - for (auto i = 0; i < vecTy.getShape()[0]; i++) { - auto pos = rewriter.create( - op.getLoc(), rewriter.getI32IntegerAttr(i)); - auto extractOp = - rewriter.create(loc, v, pos); - auto splatOp = rewriter.create( - op.getLoc(), targetTy, extractOp); - newOps.push_back(splatOp); - } - } - } - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -// A transpose op for a larger vector will be lowered into multiple -// explicit transpose ops for smaller vectors and the order/use of -// these these new transpose ops are transposed too. For example: -// xetile.transpose %1, [1, 0]: vector<16x48> -> vector<48x16> will -// be lowered into 6 transpose ops on vector<8x16> assuming the smaller -// vector shape is 8x16. So it will from: -// |--------------|--------------|--------------| -// | 0: 8x16 | 1: 8x16 | 2: 8x16 | -// |--------------|--------------|--------------| -// | 3: 8x16 | 4: 8x16 | 5: 8x16 | -// |--------------|--------------|--------------| -// -// to: -// -// |--------------|--------------| -// | 0: 16x8 | 3: 16x8 | -// |--------------|--------------| -// | 1: 16x8 | 4: 16x8 | -// |--------------|--------------| -// | 2: 16x8 | 5: 16x8 | -// |--------------|--------------| -// (the number before `:` is the id of the block) - -template -struct SgTransposeOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - using RangeT = llvm::ArrayRef; - using OpAdaptor = typename OpTy::template GenericAdaptor; - - mlir::LogicalResult - matchAndRewrite(OpTy op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto resType = op.getResult().getType(); - if (resType.getRank() != 4) - return ((mlir::PatternRewriter &)rewriter) - .notifyMatchFailure(op, "Expected a 4D vector"); - - auto srcVectors = adaptor.getVector(); - auto shape 
= resType.getShape(); - if (shape[0] * shape[1] != static_cast(srcVectors.size())) - return ((mlir::PatternRewriter &)rewriter) - .notifyMatchFailure(op, "Invalid shape"); - - auto permutation = op.getPermutation(); - auto outerPerm = permutation.take_front(2); - int64_t innerPerm[2] = {permutation[2] - 2, permutation[3] - 2}; - - auto newResType = - mlir::VectorType::get(shape.take_back(2), resType.getElementType()); - - mlir::Location loc = op.getLoc(); - llvm::SmallVector results; - for (auto i : llvm::seq(0, shape[0])) { - for (auto j : llvm::seq(0, shape[1])) { - size_t ij[2] = {i, j}; - auto idx = ij[outerPerm[1]] + shape[outerPerm[1]] * ij[outerPerm[0]]; - mlir::Value arg = srcVectors[idx]; - mlir::Value res = rewriter.create( - loc, newResType, arg, innerPerm); - results.emplace_back(res); - } - } - rewriter.replaceOp(op, results); - return mlir::success(); - } -}; - -bool isLegalElementWiseOp(mlir::Operation *op) { - // Check that all results are of vector type and has rank > 2. - auto numResults = op->getNumResults(); - for (unsigned i = 0; i < numResults; i++) { - auto res = op->getResult(i); - auto resType = mlir::dyn_cast(res.getType()); - if (!resType || resType.getRank() <= 2) - return true; - } - return false; -} - -// Convert a llvm::ArrayRef of operands range, where each range consists of a -// list of same operand, To a llvm::ArrayRef of operand range, where the range -// is created from element from each list of operand. 
- -llvm::SmallVector> -verticalToHorizontalToValueRange(llvm::ArrayRef operands) { - auto numBlocks = operands[0].size(); - llvm::SmallVector> values(numBlocks); - for (auto operand : operands) { - for (unsigned i = 0; i < operand.size(); i++) { - values[i].push_back(operand[i]); - } - } - return values; -} - -template -struct ElementWiseOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - using RangeT = llvm::ArrayRef; - using OpAdaptor = typename Op::template GenericAdaptor; - - mlir::LogicalResult - matchAndRewrite(Op op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto numResults = op.getOperation()->getNumResults(); - - llvm::SmallVector newResultTypes; - for (unsigned i = 0; i < numResults; i++) { - mlir::Value result = ::llvm::cast<::mlir::TypedValue<::mlir::Type>>( - *op.getODSResults(i).begin()); - auto resultType = mlir::dyn_cast(result.getType()); - // Check if the result types are 4D vectors, if any of the result type is - // not a 4D vector, return failure. - if (!resultType || resultType.getRank() != 4) - return mlir::failure(); - - auto shape = resultType.getShape(); - // Get the new result type, this is the type of the result of the new - // blocked op that works on the 2-D vector. - auto vecTy = mlir::VectorType::get({shape[2], shape[3]}, - resultType.getElementType()); - newResultTypes.push_back(vecTy); - } - - // Get the operands - auto operands = adaptor.getOperands(); - - // The operands are in the form of llvm::ArrayRef, where - // each ValueRange consists of a list of same operand. However, to use the - // operands in the new op, the operands of the same block should together in - // a ValueRange (Vector of operands of the each block should be in a - // vector). 
- auto horizontalOperands = verticalToHorizontalToValueRange(operands); - // Get the attributes - auto attributes = op.getOperation()->getAttrs(); - Op newOp; - llvm::SmallVector newOps; - for (auto newOperands : horizontalOperands) { - // We are using the generic builder that is supported by all ops. - // static void build(::mlir::OpBuilder &, ::mlir::OperationState - // &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, - // ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); - newOp = rewriter.create(op.getLoc(), newResultTypes, newOperands, - attributes); - for (unsigned i = 0; i < numResults; i++) { - mlir::Value result = ::llvm::cast<::mlir::TypedValue<::mlir::Type>>( - *newOp.getODSResults(i).begin()); - newOps.push_back(result); - } - } - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -struct SgBroadcastOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(xetile::BroadcastOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto resultTy = op.getResult().getType(); - auto resultShape = resultTy.getShape(); - auto dstType = mlir::VectorType::get(resultShape.take_back(2), - resultTy.getElementType()); - - auto dim = op.getBroadcastDim(); - assert(dim.size() == 2 && "Expecting 2D broadcast dim."); - - llvm::SmallVector newOps; - if (dim[0] == 0 && dim[1] == 2) { - // clang-format off - // broadcast along the first dim, we simply need to replicate the source. 
- // For example, for - // xetile.broadcast %src [0]: vector<1x64xf16> -> vector<32x64xf16> - // After blocking (assuming block size = [1, 16]) and lowering to xegpu, - // its input values (source) will be a vector of values with type <1x16xf16> - // and size = 4, which can be viewed as: - // | vector<1x16xf16> | vector<1x16xf16> | vector<1x16xf16> | vector<1x16xf16> | - // so we need to replicate it 32 times (resultShape[0]) to get final results: - // 0: | vector<1x16xf16> | vector<1x16xf16> | vector<1x16xf16> | vector<1x16xf16> | - // ...... - // 31: | vector<1x16xf16> | vector<1x16xf16> | vector<1x16xf16> | vector<1x16xf16> | - // clang-format on - for (auto i = 0; i < resultShape[0]; i++) - newOps.append(adaptor.getSource().begin(), adaptor.getSource().end()); - } else if (dim[0] == 1 && dim[1] == 3) { - // clang-format off - // broadcast along the second dim, we use both splatOp and replicates. - // For example: xetile.broadcast %src [1]: vector<32x1xf16> -> - // vector<32x64xf16>. After blocking (assuming block size = [1, 16]) and - // lowering to xegpu, the input value (source) will be a vector of values - // with type <1x1xf16> and size = 32, which can be viewed as: - // 0: | vector<1x1xf16> | - // ... - // 31: | vector<1x1xf16> | - // first, splatOp is used to broadcast the value of vector<1x1xf16> to - // vector<1x16xf16> - // 0: | vector<1x16xf16> | - // ... - // 31: | vector<1x16xf16> | - // and then we replicate the splatOp 4 times (resultShape[1]) to get the - // final results: - // 0: | vector<1x16xf16> | vector<1x16xf16> | vector<1x16xf16> | vector<1x16xf16> | - // ... 
- // 31: | vector<1x16xf16> | vector<1x16xf16> | vector<1x16xf16> | vector<1x16xf16> | - // clang-format on - for (auto src : adaptor.getSource()) { - auto ty = mlir::dyn_cast(src.getType()); - assert(ty && ty.getNumElements() == 1 && - "Expecting a <1x1xelemty> vector type."); - auto ext = rewriter.create( - op.getLoc(), src, llvm::ArrayRef({0, 0})); - auto splatOp = - rewriter.create(op.getLoc(), dstType, ext); - newOps.append(resultShape[1], splatOp); - } - } else { - return mlir::failure(); - } - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -struct SgVectorCreateMaskOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(CreateMaskOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto res = op.getResult(); - auto resType = mlir::dyn_cast(res.getType()); - // 4D vector ops only. - if (resType.getRank() != 4) { - op.emitOpError() << "type is not 4D vector"; - return mlir::failure(); - } - - mlir::Location loc = op->getLoc(); - auto shape = resType.getShape(); - if (shape[2] != 1) { - op.emitOpError() << "Unsupported inner block sizes"; - return mlir::failure(); - } - auto newTy = - mlir::VectorType::get({shape[2], shape[3]}, resType.getElementType()); - llvm::SmallVector newOps; - mlir::Value ub0 = adaptor.getOperands()[0][0]; - auto constDef = ub0.getDefiningOp(); - if (constDef && constDef.value() == shape[0]) { - // Case 1: all rows are enabled. - // See assumptions about the supported create_mask op in - // VectorCreateMaskOpPattern in xetile blocking pass. The second and - // forth operands are the same. This value is the mask of the inner - // dimension of the original shape. Different masks are created based on - // the new inner dimension size. 
- auto one = rewriter.create(loc, 1); - llvm::SmallVector> newOperands; - mlir::Value mask = adaptor.getOperands()[3][0]; - auto innerDimSize = - rewriter.create(loc, shape[3]); - for (int j = 0; j < shape[1]; ++j) { - newOperands.push_back({one, mask}); - mask = rewriter.create(loc, mask, innerDimSize); - } - - for (int i = 0; i < shape[0]; ++i) { - for (int j = 0; j < shape[1]; ++j) { - auto newOp = - rewriter.create(op.getLoc(), newTy, newOperands[j]); - newOps.push_back(newOp); - } - } - - } else { - // Case 2: all columns are enabled. - for (int i = 0; i < shape[0]; ++i) { - auto elemIndex = rewriter.create(loc, i); - auto cmp = rewriter.create( - loc, mlir::arith::CmpIPredicate::slt, elemIndex, ub0); - auto bcast = rewriter.create(loc, newTy, cmp); - for (int j = 0; j < shape[1]; ++j) - newOps.push_back(bcast); - } - } - - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -struct SgVectorSplatOpPattern : public XeOneToNConversion { - using XeOneToNConversion::XeOneToNConversion; - - mlir::LogicalResult - matchAndRewrite(SplatOp op, OpAdaptor adaptor, - XeOneToNPatternRewriter &rewriter) const override { - auto type = op.getAggregate().getType(); - if (type.getRank() != 4) - return mlir::failure(); - auto shape = type.getShape(); - auto newType = - mlir::VectorType::get(shape.take_back(2), type.getElementType()); - auto newOp = rewriter.create(op.getLoc(), op.getInput(), newType); - llvm::SmallVector newOps(shape[0] * shape[1], newOp); - rewriter.replaceOp(op, newOps); - return mlir::success(); - } -}; - -void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter, - mlir::RewritePatternSet &patterns, - TileUsageAnalysis &analysis) { - patterns - .add, - SgTransposeOpPattern, SgBroadcastOpPattern, - SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>( - patterns.getContext(), converter, analysis); - - // Element-wise math operations - patterns.add, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - 
ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern>(patterns.getContext(), - converter, analysis); - - // Arithmetic operations - patterns.add, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern>(patterns.getContext(), - converter, analysis); - - // Typecast operations - patterns.add, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern, - ElementWiseOpPattern>( - patterns.getContext(), converter, analysis); -} - -} // namespace imex diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.h b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.h deleted file mode 100644 index dc8ec3a15..000000000 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.h +++ /dev/null @@ -1,30 +0,0 @@ -//===- XeTileOpConversion.h - XeTileToXeGPU conversion -------*- C++ -*-===// -// -// Copyright 2022 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines ConversionPatterns for XeTileOps, used in XeTileToXeGPU -/// conversion, converting the XeTile dialect to the XeGPU dialect. -/// -//===----------------------------------------------------------------------===// -#ifndef _XeTileOpConversion_H_INCLUDED_ -#define _XeTileOpConversion_H_INCLUDED_ - -#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h" -#include "imex/Utils/XeArch.h" -namespace imex { - -bool isLegalElementWiseOp(mlir::Operation *op); - -void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter, - mlir::RewritePatternSet &patterns, - TileUsageAnalysis &analysis); - -} // namespace imex - -#endif diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp index a939beae9..cf7e0363e 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp @@ -23,9 +23,6 @@ #include #include -#include "ArithOpConversion.h" -#include "SCFOpConversion.h" -#include "XeTileOpConversion.h" #include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h" #include "imex/Utils/XeArch.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -38,9 +35,6 @@ namespace imex { #include namespace imex { -// TODO: clean up this after consolidation -static bool Enable2DTransform = false; - // Converts an Attribute representing memory space to xegpu::MemorySpaceAttr. // It currently only supports memory space represented as integer attribute. 
// TODO: improve it to support other types of memory space attributes, e.g., @@ -508,82 +502,6 @@ class XeTileConversionTarget : public mlir::ConversionTarget { } return true; }); - - if (!Enable2DTransform) { - addLegalOp(); - addLegalOp(); - addLegalOp(); - addLegalOp(); - addLegalOp(); - addLegalOp(); - addLegalOp(); - - addDynamicallyLegalDialect( - [&](mlir::Operation *op) { return isLegalArithOp(op); }); - - addDynamicallyLegalDialect( - [&](mlir::Operation *op) { return isLegalSCFOp(op); }); - - // Arith ops, since we support all the arith ops, we can dynamically make - // the whole dialect legal. - addDynamicallyLegalDialect( - [&](mlir::Operation *op) -> std::optional { - return isLegalElementWiseOp(op); - }); - - // Math Ops - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { - return isLegalElementWiseOp(op); - }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { - return isLegalElementWiseOp(op); - }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { - return isLegalElementWiseOp(op); - }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { - return isLegalElementWiseOp(op); - }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { - return isLegalElementWiseOp(op); - }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { - return isLegalElementWiseOp(op); - }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { - return isLegalElementWiseOp(op); - }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { - return isLegalElementWiseOp(op); - }); - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { - return isLegalElementWiseOp(op); - }); - - addDynamicallyLegalOp( - [&](mlir::Operation *op) -> bool { - return isLegalElementWiseOp(op); - }); - - addDynamicallyLegalOp( - [](mlir::vector::TransposeOp op) { - return op.getResult().getType().getRank() == 2; - }); - - addDynamicallyLegalOp( - [&](mlir::vector::SplatOp op) { - return 
op.getAggregate().getType().getRank() != 4; - }); - } } private: @@ -615,8 +533,6 @@ struct ConvertXeTileToXeGPUPass // convert XeTile to XeGPU uArchInterface = std::make_shared(); else return errorHandler(llvm::Twine("Invalid device: ") + device); - // TODO: cleanup - Enable2DTransform = EnableTransform; return mlir::success(); } @@ -639,94 +555,79 @@ struct ConvertXeTileToXeGPUPass // convert XeTile to XeGPU XeTileConversionTarget target(context, uArchInterface); mlir::RewritePatternSet patterns(&context); - if (!Enable2DTransform) { - auto &analysis = getAnalysis(); - XeOneToNTypeConverter typeConverter(context); - populateXeTileToXeGPUConversionPatterns(typeConverter, patterns, - analysis); - if (mlir::failed( - mlir::applyPartialConversion(mod, target, std::move(patterns)))) - return signalPassFailure(); - } else { - mlir::TypeConverter typeConverter; - - typeConverter.addConversion( - [&](mlir::Type type) -> mlir::Type { return type; }); - - typeConverter.addConversion( - [&](xetile::TileType type) -> mlir::xegpu::TensorDescType { - auto context = type.getContext(); - auto elemTy = type.getElementType(); - auto scatterAttr = type.getScatterAttr(); - bool isScattered = scatterAttr ? scatterAttr.getValue() : false; - - mlir::xegpu::SGMapAttr sgMap = nullptr; - if (auto attr = type.getSgMap()) { - auto layout = - llvm::to_vector_of(attr.getWiLayout().asArrayRef()); - auto data = - llvm::to_vector_of(attr.getWiData().asArrayRef()); - sgMap = mlir::xegpu::SGMapAttr::get(context, layout, data); - } - - auto memSpaceAttr = convertMemorySpace(type.getMemorySpace()); - auto memSpace = memSpaceAttr ? memSpaceAttr.getValue() - : mlir::xegpu::MemorySpace::Global; - - mlir::Attribute encoding; - llvm::SmallVector shape; - if (isScattered) { - // Scattered tile is lowered to scattered tensor_desc with chunk - // size 1. It supports both global memory and shared memory. while - // scattered tile can support 2D shape, scattered tensor_desc only - // support 1D shape. 
- auto chunkSizeAttr = mlir::IntegerAttr::get( - mlir::IntegerType::get(context, 64), 1); - encoding = mlir::xegpu::ScatterTensorDescAttr::get( - context, memSpaceAttr, chunkSizeAttr); - shape.push_back(type.getNumElements()); - } else if (memSpace == mlir::xegpu::MemorySpace::Global) { - // Blocked tile on global memory is lowered to blocked tensor_desc - // with the same shape. - // TODO: update TileType with array_length and use it here. - auto arrayLenAttr = mlir::IntegerAttr::get( - mlir::IntegerType::get(context, 64), 1); - auto boundaryCheckAttr = mlir::BoolAttr::get(context, true); - encoding = mlir::xegpu::BlockTensorDescAttr::get( - context, memSpaceAttr, arrayLenAttr, boundaryCheckAttr); - shape = llvm::to_vector(type.getShape()); - } else { - // TODO: Lowering strategy for blocked tiles on SLM is not - // finalized yet. - assert(0 && "SLM space for blocked tile is not supported yet."); - } - return mlir::xegpu::TensorDescType::get(context, shape, elemTy, - encoding, sgMap); - }); - - auto materializeWithCast = [&](mlir::OpBuilder &builder, mlir::Type type, - mlir::ValueRange inputs, - mlir::Location loc) -> mlir::Value { - assert(inputs.size() == 1 && "Expecting single input"); - return builder - .create(loc, type, inputs) - .getResult(0); - }; - - typeConverter.addArgumentMaterialization(materializeWithCast); - typeConverter.addTargetMaterialization(materializeWithCast); - typeConverter.addSourceMaterialization(materializeWithCast); - - patterns - .add(typeConverter, - patterns.getContext()); - if (mlir::failed( - mlir::applyPartialConversion(mod, target, std::move(patterns)))) - return signalPassFailure(); - } + mlir::TypeConverter typeConverter; + + typeConverter.addConversion( + [&](mlir::Type type) -> mlir::Type { return type; }); + + typeConverter.addConversion( + [&](xetile::TileType type) -> mlir::xegpu::TensorDescType { + auto context = type.getContext(); + auto elemTy = type.getElementType(); + auto scatterAttr = type.getScatterAttr(); + bool 
isScattered = scatterAttr ? scatterAttr.getValue() : false; + + mlir::xegpu::SGMapAttr sgMap = nullptr; + if (auto attr = type.getSgMap()) { + auto layout = + llvm::to_vector_of(attr.getWiLayout().asArrayRef()); + auto data = + llvm::to_vector_of(attr.getWiData().asArrayRef()); + sgMap = mlir::xegpu::SGMapAttr::get(context, layout, data); + } + + auto memSpaceAttr = convertMemorySpace(type.getMemorySpace()); + auto memSpace = memSpaceAttr ? memSpaceAttr.getValue() + : mlir::xegpu::MemorySpace::Global; + + mlir::Attribute encoding; + llvm::SmallVector shape; + if (isScattered) { + // Scattered tile is lowered to scattered tensor_desc with chunk + // size 1. It supports both global memory and shared memory. while + // scattered tile can support 2D shape, scattered tensor_desc only + // support 1D shape. + auto chunkSizeAttr = + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), 1); + encoding = mlir::xegpu::ScatterTensorDescAttr::get( + context, memSpaceAttr, chunkSizeAttr); + shape.push_back(type.getNumElements()); + } else if (memSpace == mlir::xegpu::MemorySpace::Global) { + // Blocked tile on global memory is lowered to blocked tensor_desc + // with the same shape. + // TODO: update TileType with array_length and use it here. + auto arrayLenAttr = + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), 1); + auto boundaryCheckAttr = mlir::BoolAttr::get(context, true); + encoding = mlir::xegpu::BlockTensorDescAttr::get( + context, memSpaceAttr, arrayLenAttr, boundaryCheckAttr); + shape = llvm::to_vector(type.getShape()); + } else { + // TODO: Lowering strategy for blocked tiles on SLM is not + // finalized yet. 
+ assert(0 && "SLM space for blocked tile is not supported yet."); + } + return mlir::xegpu::TensorDescType::get(context, shape, elemTy, + encoding, sgMap); + }); + + auto materializeWithCast = [&](mlir::OpBuilder &builder, mlir::Type type, + mlir::ValueRange inputs, + mlir::Location loc) -> mlir::Value { + assert(inputs.size() == 1 && "Expecting single input"); + return builder.create(loc, type, inputs) + .getResult(0); + }; + + typeConverter.addArgumentMaterialization(materializeWithCast); + typeConverter.addTargetMaterialization(materializeWithCast); + typeConverter.addSourceMaterialization(materializeWithCast); + + populateXeTileToXeGPUConversionPatterns(typeConverter, patterns); + + if (mlir::failed( + mlir::applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); } private: @@ -735,11 +636,12 @@ struct ConvertXeTileToXeGPUPass // convert XeTile to XeGPU /// Populate the given list with patterns that convert XeTile to XeGPU void populateXeTileToXeGPUConversionPatterns( - imex::XeOneToNTypeConverter &converter, mlir::RewritePatternSet &patterns, - TileUsageAnalysis &analysis) { - populateSCFOpConversionPatterns(converter, patterns, analysis); - populateArithOpConversionPatterns(converter, patterns, analysis); - populateXeTileOpConversionPatterns(converter, patterns, analysis); + mlir::TypeConverter &converter, mlir::RewritePatternSet &patterns) { + patterns.add(converter, + patterns.getContext()); } /// Create a pass that convert XeTile to XeGPU diff --git a/lib/Dialect/XeTile/Transforms/Blocking.cpp b/lib/Dialect/XeTile/Transforms/Blocking.cpp index 8e66042f2..2d36cd00a 100644 --- a/lib/Dialect/XeTile/Transforms/Blocking.cpp +++ b/lib/Dialect/XeTile/Transforms/Blocking.cpp @@ -611,7 +611,6 @@ class RewriteInitTileOp return mlir::failure(); llvm::SmallVector newOps; - // handle scattered tiles. 
if (tileTy.getScatterAttr() == mlir::BoolAttr::get(ctx, true)) { auto indices = op.getIndices(); @@ -636,7 +635,7 @@ class RewriteInitTileOp auto width = blockSize[1]; llvm::SmallVector grids( {shape[0] / blockSize[0], shape[1] / width}); - llvm::SmallVector offsets = op.getMixedOffsets(); + auto mixedOffsets = op.getMixedOffsets(); auto addi = [&](mlir::OpFoldResult a, int64_t b) -> mlir::Value { if (mlir::isa(a)) { @@ -652,16 +651,22 @@ class RewriteInitTileOp } }; + // For n-D memrefs where n > 2, we need to handle the last two + // dimensions, and keep the first n-2 dimensions as is. + int64_t x = mixedOffsets.size() - 2; + int64_t y = mixedOffsets.size() - 1; + mlir::OpFoldResult oldX = mixedOffsets[x]; + mlir::OpFoldResult oldY = mixedOffsets[y]; + for (int64_t i = 0; i < grids[0]; i++) { for (int64_t j = 0; j < grids[1]; j++) { auto subOffX = blockSize[0] * i; auto subOffY = width * j; - auto X = addi(offsets[0], subOffX); - auto Y = addi(offsets[1], subOffY); - llvm::SmallVector ofrs({X, Y}); + mixedOffsets[x] = addi(oldX, subOffX); + mixedOffsets[y] = addi(oldY, subOffY); llvm::SmallVector offsets; llvm::SmallVector constOffsets; - mlir::dispatchIndexOpFoldResults(ofrs, offsets, constOffsets); + mlir::dispatchIndexOpFoldResults(mixedOffsets, offsets, constOffsets); auto constOffsetsAttr = rewriter.getDenseI64ArrayAttr(constOffsets); auto newOp = rewriter.create( loc, newTileTy, op.getSource(), offsets, op.getSizes(), @@ -1197,8 +1202,10 @@ class RewriteTileBroadcastOp } else { return mlir::failure(); } - - return mlir::failure(); + auto castOp = unpackWithUnrealizedCastOp( + newOps, resTy, resBlkSize.asArrayRef(), loc, rewriter); + rewriter.replaceOp(op, castOp); + return mlir::success(); } }; @@ -1349,21 +1356,11 @@ class RewriteSCFForOp auto results = op.getResults(); llvm::SmallVector blockSZs; - // verify the block size of region args and results. They should be the - // same if the result is used outside of the loop. 
Also, if the type is - // TileType, the init value should have the same block size as the region - // arg, since there is no unpack/pack op for TileType. - for (auto [init, arg, res] : - llvm::zip_equal(initArgs, regionArgs, results)) { - auto initBlock = analysis.getDefBlockSize(init); + // We use region args as anchor. PackOps will be inserted if ncessary + // for each init args, and UnpackOps will be inserted for each argument + // and result. + for (auto arg : regionArgs) { auto argBlock = analysis.getDefBlockSize(arg); - auto resBlock = analysis.getDefBlockSize(res); - - if (mlir::isa(arg.getType()) && initBlock != argBlock) - return rewriter.notifyMatchFailure(op, "Incompatiable blocking size."); - - if (res.isUsedOutsideOfBlock(op.getBody()) && argBlock != resBlock) - return rewriter.notifyMatchFailure(op, "Incompatiable blocking size."); blockSZs.push_back(argBlock); } diff --git a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp index 53d2a0f80..a169c4ac3 100644 --- a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp +++ b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp @@ -587,11 +587,11 @@ void BlockingAnalysisImpl::visitBroadcastOp( if (!lattice.isInitialized()) return; + auto req = lattice.getRequests()[0]; auto dim = dims[0]; Block blockSize; if (dim == 0) { - auto req = lattice.getRequests()[0]; blockSize = Block(1, req[1]); } else if (dim == 1) { blockSize = Block(1, 1); @@ -600,6 +600,10 @@ void BlockingAnalysisImpl::visitBroadcastOp( } auto blockingRequest = BlockingRequests(blockSize, op->getOpOperand(0)); propagateIfChanged(operands[0], operands[0]->join(blockingRequest)); + + // update the def block size for the result value + BlockingRequests &def = getLatticeElement(op.getResult())->getValue(); + def.updateDefBlock(Block(1, req[1])); } void BlockingAnalysisImpl::visitTransposeOp( @@ -689,14 +693,13 @@ void BlockingAnalysisImpl::visitVectorizableOp( // well supported by IGC yet. 
Using default size (same as CreateMask) // could help to avoid this. Remove it when lowering of create_mask // and IGC get matured. - if (mlir::isa(op) && !Enable2DBlockingTransform) { + if (mlir::isa(op)) { block = Block(1, block[1]); } // elementwise operations are not sensitive to the block size. // It will use the block size requested by its users, except SelectOp - if (lattice.isInitialized() && - (Enable2DBlockingTransform || !mlir::isa(op))) { + if (lattice.isInitialized() && !mlir::isa(op)) { block[0] = 0; for (auto &req : lattice.getRequests()) { block[0] = std::max(block[0], req[0]); @@ -746,12 +749,15 @@ void BlockingAnalysisImpl::visitCreateMaskOp( // [1, subgroupSize] for CreateMaskOp if 2D transform is not enabled. // If 2D transform is enabled, it will aligned with its users. Block block = getInnerBlockSize(op, elemTy, shape); - if (Enable2DBlockingTransform) { - for (auto &req : lattice.getRequests()) { - block[0] = std::max(block[0], req[0]); - block[1] = std::min(block[1], req[1]); - } - } + + // TODO: need to enable the following code after 2D lowering in + // GPUToSPIRV is enabled. 
+ // if (Enable2DBlockingTransform) { + // for (auto &req : lattice.getRequests()) { + // block[0] = std::max(block[0], req[0]); + // block[1] = std::min(block[1], req[1]); + // } + // } def.updateDefBlock(block); } diff --git a/lib/Dialect/XeTile/Transforms/CMakeLists.txt b/lib/Dialect/XeTile/Transforms/CMakeLists.txt index 1200e28fc..5d74fdda2 100644 --- a/lib/Dialect/XeTile/Transforms/CMakeLists.txt +++ b/lib/Dialect/XeTile/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_imex_dialect_library(IMEXXeTileTransforms InitDuplicate.cpp Canonicalization.cpp WgToSg.cpp + XeTileOneToNConversion.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/imex/Dialect/XeTile diff --git a/lib/Dialect/XeTile/Transforms/WgToSg.cpp b/lib/Dialect/XeTile/Transforms/WgToSg.cpp index 8880077ab..6441b1813 100644 --- a/lib/Dialect/XeTile/Transforms/WgToSg.cpp +++ b/lib/Dialect/XeTile/Transforms/WgToSg.cpp @@ -38,9 +38,8 @@ #include -#include "imex/Dialect/XeTile/Transforms/Passes.h" -#include -#include +#include +#include using namespace mlir; using namespace imex; diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp b/lib/Dialect/XeTile/Transforms/XeTileOneToNConversion.cpp similarity index 93% rename from lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp rename to lib/Dialect/XeTile/Transforms/XeTileOneToNConversion.cpp index 25b438430..e185156ce 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp +++ b/lib/Dialect/XeTile/Transforms/XeTileOneToNConversion.cpp @@ -1,5 +1,4 @@ -//===- XeTileToXeGPUConversion.cpp - XeTileToXeGPU conversion -------*- C++ -//-*-===// +//===- XeTileOneToNConversion.cpp -- XeTileOneToNConversion ----*- C++ -*-===// // // Copyright 2022 Intel Corporation // Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. @@ -16,18 +15,19 @@ /// OneToN replace. 
/// //===----------------------------------------------------------------------===// -#include #include +#include #include #include +#include #include #include #include #include -#include #include #include +#include #include @@ -131,17 +131,6 @@ std::optional XeOneToNTypeConverter::convertVectorType( return std::nullopt; } -// mlir::LogicalResult XeOneToNTypeConverter::computeTypeMapping( -// mlir::ValueRange originalVals, -// llvm::ArrayRef convertedVals, -// mlir::OneToNTypeMapping &resultMap) { -// for (auto [i, val] : llvm::enumerate(convertedVals)) { -// llvm::SmallVector convertedTypes(val.getTypes()); -// resultMap.addInputs(i, convertedTypes); -// } -// return mlir::success(); -// } - // It computes the mapping between types orginal values and // converted values. The standard type conversion method doesn't // work here because a TileType could have multiple decomposions diff --git a/test/Conversion/XeTileToXeGPU/addf.mlir b/test/Conversion/XeTileToXeGPU/addf.mlir deleted file mode 100644 index 9e2455a6b..000000000 --- a/test/Conversion/XeTileToXeGPU/addf.mlir +++ /dev/null @@ -1,214 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu %s -verify-diagnostics -o -| FileCheck %s - gpu.module @test_kernel { - gpu.func @arith_binary_ops() { - //CHECK: %[[c0:.*]] = arith.constant dense<8.999020e-01> : vector<16x16xf16> - //CHECK: %[[c1:.*]] = arith.constant dense<8.999020e-01> : vector<16x16xf16> - //CHECK: %[[c2:.*]] = arith.constant dense<8.999020e-01> : vector<16x16xf16> - //CHECK: %[[c3:.*]] = arith.constant dense<8.999020e-01> : vector<16x16xf16> - //CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[c0]], %[[c1]], %[[c2]], %[[c3]] : vector<16x16xf16>, vector<16x16xf16>, vector<16x16xf16>, vector<16x16xf16> to vector<2x2x16x16xf16> - - %0 = arith.constant dense<0.9>: vector<2x2x16x16xf16> - - //CHECK: %[[c4:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c5:.*]] = arith.constant dense<2.300780e+00> : 
vector<1x16xf16> - //CHECK: %[[c6:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c7:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c8:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c9:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c10:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c11:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c12:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c13:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c14:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c15:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c16:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c17:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c18:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c19:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c20:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c21:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c22:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c23:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c24:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c25:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c26:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c27:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c28:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c29:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c30:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: 
%[[c31:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c32:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c33:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c34:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c35:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c36:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c37:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c38:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c39:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c40:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c41:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c42:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c43:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c44:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c45:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c46:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c47:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c48:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c49:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c50:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c51:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c52:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c53:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c54:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c55:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c56:.*]] = 
arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c57:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c58:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c59:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c60:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c61:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c62:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c63:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c64:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c65:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c66:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - //CHECK: %[[c67:.*]] = arith.constant dense<2.300780e+00> : vector<1x16xf16> - - %1 = arith.constant dense<2.3>: vector<32x2x1x16xf16> - - //CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[c4]], %[[c5]], %[[c6]], %[[c7]], %[[c8]], %[[c9]], %[[c10]], %[[c11]], %[[c12]], %[[c13]], %[[c14]], %[[c15]], %[[c16]], %[[c17]], %[[c18]], %[[c19]], %[[c20]], %[[c21]], %[[c22]], %[[c23]], %[[c24]], %[[c25]], %[[c26]], %[[c27]], %[[c28]], %[[c29]], %[[c30]], %[[c31]], %[[c32]], %[[c33]], %[[c34]], %[[c35]], %[[c36]], %[[c37]], %[[c38]], %[[c39]], %[[c40]], %[[c41]], %[[c42]], %[[c43]], %[[c44]], %[[c45]], %[[c46]], %[[c47]], %[[c48]], %[[c49]], %[[c50]], %[[c51]], %[[c52]], %[[c53]], %[[c54]], %[[c55]], %[[c56]], %[[c57]], %[[c58]], %[[c59]], %[[c60]], %[[c61]], %[[c62]], %[[c63]], %[[c64]], %[[c65]], %[[c66]], %[[c67]] : vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, 
vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16>, vector<1x16xf16> to vector<32x2x1x16xf16> - //CHECK: %[[SLICE1:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE2:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [1, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE3:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [2, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE4:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [3, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE5:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [4, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE6:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [5, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE7:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [6, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: 
%[[SLICE8:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [7, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE9:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [8, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE10:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [9, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE11:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [10, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE12:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [11, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE13:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [12, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE14:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [13, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE15:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [14, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE16:.*]] = vector.extract_strided_slice %[[c0]] {offsets = [15, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE17:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE18:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [1, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE19:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [2, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE20:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [3, 0], 
sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE21:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [4, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE22:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [5, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE23:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [6, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE24:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [7, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE25:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [8, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE26:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [9, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE27:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [10, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE28:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [11, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE29:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [12, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE30:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [13, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE31:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [14, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE32:.*]] = vector.extract_strided_slice %[[c1]] {offsets = [15, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to 
vector<1x16xf16> - //CHECK: %[[SLICE33:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE34:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [1, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE35:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [2, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE36:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [3, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE37:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [4, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE38:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [5, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE39:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [6, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE40:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [7, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE41:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [8, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE42:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [9, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE43:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [10, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE44:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [11, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE45:.*]] = vector.extract_strided_slice 
%[[c2]] {offsets = [12, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE46:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [13, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE47:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [14, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE48:.*]] = vector.extract_strided_slice %[[c2]] {offsets = [15, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE49:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE50:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [1, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE51:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [2, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE52:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [3, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE53:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [4, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE54:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [5, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE55:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [6, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE56:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [7, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE57:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [8, 0], sizes = [1, 16], strides = [1, 1]} : 
vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE58:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [9, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE59:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [10, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE60:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [11, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE61:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [12, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE62:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [13, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE63:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [14, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16> - //CHECK: %[[SLICE64:.*]] = vector.extract_strided_slice %[[c3]] {offsets = [15, 0], sizes = [1, 16], strides = [1, 1]} : vector<16x16xf16> to vector<1x16xf16 - %2 = xetile.tile_unpack %0 {inner_blocks = array}: vector<2x2x16x16xf16> -> vector<32x32xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x32xf16> -> vector<32x2x1x16xf16> - //CHECK: %[[ADD1:.*]] = arith.addf %[[SLICE1]], %[[c4]] : vector<1x16xf16> - //CHECK: %[[ADD2:.*]] = arith.addf %[[SLICE17]], %[[c5]] : vector<1x16xf16> - //CHECK: %[[ADD3:.*]] = arith.addf %[[SLICE2]], %[[c6]] : vector<1x16xf16> - //CHECK: %[[ADD4:.*]] = arith.addf %[[SLICE18]], %[[c7]] : vector<1x16xf16> - //CHECK: %[[ADD5:.*]] = arith.addf %[[SLICE3]], %[[c8]] : vector<1x16xf16> - //CHECK: %[[ADD6:.*]] = arith.addf %[[SLICE19]], %[[c9]] : vector<1x16xf16> - //CHECK: %[[ADD7:.*]] = arith.addf %[[SLICE4]], %[[c10]] : vector<1x16xf16> - //CHECK: %[[ADD8:.*]] = arith.addf %[[SLICE20]], %[[c11]] : vector<1x16xf16> - //CHECK: %[[ADD9:.*]] 
= arith.addf %[[SLICE5]], %[[c12]] : vector<1x16xf16> - //CHECK: %[[ADD10:.*]] = arith.addf %[[SLICE21]], %[[c13]] : vector<1x16xf16> - //CHECK: %[[ADD11:.*]] = arith.addf %[[SLICE6]], %[[c14]] : vector<1x16xf16> - //CHECK: %[[ADD12:.*]] = arith.addf %[[SLICE22]], %[[c15]] : vector<1x16xf16> - //CHECK: %[[ADD13:.*]] = arith.addf %[[SLICE7]], %[[c16]] : vector<1x16xf16> - //CHECK: %[[ADD14:.*]] = arith.addf %[[SLICE23]], %[[c17]] : vector<1x16xf16> - //CHECK: %[[ADD15:.*]] = arith.addf %[[SLICE8]], %[[c18]] : vector<1x16xf16> - //CHECK: %[[ADD16:.*]] = arith.addf %[[SLICE24]], %[[c19]] : vector<1x16xf16> - //CHECK: %[[ADD17:.*]] = arith.addf %[[SLICE9]], %[[c20]] : vector<1x16xf16> - //CHECK: %[[ADD18:.*]] = arith.addf %[[SLICE25]], %[[c21]] : vector<1x16xf16> - //CHECK: %[[ADD19:.*]] = arith.addf %[[SLICE10]], %[[c22]] : vector<1x16xf16> - //CHECK: %[[ADD20:.*]] = arith.addf %[[SLICE26]], %[[c23]] : vector<1x16xf16> - //CHECK: %[[ADD21:.*]] = arith.addf %[[SLICE11]], %[[c24]] : vector<1x16xf16> - //CHECK: %[[ADD22:.*]] = arith.addf %[[SLICE27]], %[[c25]] : vector<1x16xf16> - //CHECK: %[[ADD23:.*]] = arith.addf %[[SLICE12]], %[[c26]] : vector<1x16xf16> - //CHECK: %[[ADD24:.*]] = arith.addf %[[SLICE28]], %[[c27]] : vector<1x16xf16> - //CHECK: %[[ADD25:.*]] = arith.addf %[[SLICE13]], %[[c28]] : vector<1x16xf16> - //CHECK: %[[ADD26:.*]] = arith.addf %[[SLICE29]], %[[c29]] : vector<1x16xf16> - //CHECK: %[[ADD27:.*]] = arith.addf %[[SLICE14]], %[[c30]] : vector<1x16xf16> - //CHECK: %[[ADD28:.*]] = arith.addf %[[SLICE30]], %[[c31]] : vector<1x16xf16> - //CHECK: %[[ADD29:.*]] = arith.addf %[[SLICE15]], %[[c32]] : vector<1x16xf16> - //CHECK: %[[ADD30:.*]] = arith.addf %[[SLICE31]], %[[c33]] : vector<1x16xf16> - //CHECK: %[[ADD31:.*]] = arith.addf %[[SLICE16]], %[[c34]] : vector<1x16xf16> - //CHECK: %[[ADD32:.*]] = arith.addf %[[SLICE32]], %[[c35]] : vector<1x16xf16> - //CHECK: %[[ADD33:.*]] = arith.addf %[[SLICE33]], %[[c36]] : vector<1x16xf16> - //CHECK: %[[ADD34:.*]] = 
arith.addf %[[SLICE49]], %[[c37]] : vector<1x16xf16> - //CHECK: %[[ADD35:.*]] = arith.addf %[[SLICE34]], %[[c38]] : vector<1x16xf16> - //CHECK: %[[ADD36:.*]] = arith.addf %[[SLICE50]], %[[c39]] : vector<1x16xf16> - //CHECK: %[[ADD37:.*]] = arith.addf %[[SLICE35]], %[[c40]] : vector<1x16xf16> - //CHECK: %[[ADD38:.*]] = arith.addf %[[SLICE51]], %[[c41]] : vector<1x16xf16> - //CHECK: %[[ADD39:.*]] = arith.addf %[[SLICE36]], %[[c42]] : vector<1x16xf16> - //CHECK: %[[ADD40:.*]] = arith.addf %[[SLICE52]], %[[c43]] : vector<1x16xf16> - //CHECK: %[[ADD41:.*]] = arith.addf %[[SLICE37]], %[[c44]] : vector<1x16xf16> - //CHECK: %[[ADD42:.*]] = arith.addf %[[SLICE53]], %[[c45]] : vector<1x16xf16> - //CHECK: %[[ADD43:.*]] = arith.addf %[[SLICE38]], %[[c46]] : vector<1x16xf16> - //CHECK: %[[ADD44:.*]] = arith.addf %[[SLICE54]], %[[c47]] : vector<1x16xf16> - //CHECK: %[[ADD45:.*]] = arith.addf %[[SLICE39]], %[[c48]] : vector<1x16xf16> - //CHECK: %[[ADD46:.*]] = arith.addf %[[SLICE55]], %[[c49]] : vector<1x16xf16> - //CHECK: %[[ADD47:.*]] = arith.addf %[[SLICE40]], %[[c50]] : vector<1x16xf16> - //CHECK: %[[ADD48:.*]] = arith.addf %[[SLICE56]], %[[c51]] : vector<1x16xf16> - //CHECK: %[[ADD49:.*]] = arith.addf %[[SLICE41]], %[[c52]] : vector<1x16xf16> - //CHECK: %[[ADD50:.*]] = arith.addf %[[SLICE57]], %[[c53]] : vector<1x16xf16> - //CHECK: %[[ADD51:.*]] = arith.addf %[[SLICE42]], %[[c54]] : vector<1x16xf16> - //CHECK: %[[ADD52:.*]] = arith.addf %[[SLICE58]], %[[c55]] : vector<1x16xf16> - //CHECK: %[[ADD53:.*]] = arith.addf %[[SLICE43]], %[[c56]] : vector<1x16xf16> - //CHECK: %[[ADD54:.*]] = arith.addf %[[SLICE59]], %[[c57]] : vector<1x16xf16> - //CHECK: %[[ADD55:.*]] = arith.addf %[[SLICE44]], %[[c58]] : vector<1x16xf16> - //CHECK: %[[ADD56:.*]] = arith.addf %[[SLICE60]], %[[c59]] : vector<1x16xf16> - //CHECK: %[[ADD57:.*]] = arith.addf %[[SLICE45]], %[[c60]] : vector<1x16xf16> - //CHECK: %[[ADD58:.*]] = arith.addf %[[SLICE61]], %[[c61]] : vector<1x16xf16> - //CHECK: %[[ADD59:.*]] = 
arith.addf %[[SLICE46]], %[[c62]] : vector<1x16xf16> - //CHECK: %[[ADD60:.*]] = arith.addf %[[SLICE62]], %[[c63]] : vector<1x16xf16> - //CHECK: %[[ADD61:.*]] = arith.addf %[[SLICE47]], %[[c64]] : vector<1x16xf16> - //CHECK: %[[ADD62:.*]] = arith.addf %[[SLICE63]], %[[c65]] : vector<1x16xf16> - //CHECK: %[[ADD63:.*]] = arith.addf %[[SLICE48]], %[[c66]] : vector<1x16xf16> - //CHECK: %[[ADD64:.*]] = arith.addf %[[SLICE64]], %[[c67]] : vector<1x16xf16> - - %result = arith.addf %3, %1 : vector<32x2x1x16xf16> - gpu.return - } - } diff --git a/test/Conversion/XeTileToXeGPU/array_length_load.mlir b/test/Conversion/XeTileToXeGPU/array_length_load.mlir index e3fa432f1..09a12cffa 100644 --- a/test/Conversion/XeTileToXeGPU/array_length_load.mlir +++ b/test/Conversion/XeTileToXeGPU/array_length_load.mlir @@ -7,7 +7,6 @@ gpu.module @test_kernel { %a_loaded = xetile.load_tile %a_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> // Do not let XeGPU do one load with multiple blocks (array_length > 1), where each block is finer than one GRF. 
- //CHECK: xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<1x32xf16> -> !xegpu.tensor_desc<1x32xf16, #xegpu.block_tdesc_attr> %b_tile = xetile.init_tile %b[%c0, %c0] : memref<1x32xf16> -> !xetile.tile<1x32xf16> %b_loaded = xetile.load_tile %b_tile : !xetile.tile<1x32xf16> -> vector<1x32xf16> diff --git a/test/Conversion/XeTileToXeGPU/create_mask.mlir b/test/Conversion/XeTileToXeGPU/create_mask.mlir deleted file mode 100644 index ca92e65e1..000000000 --- a/test/Conversion/XeTileToXeGPU/create_mask.mlir +++ /dev/null @@ -1,59 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --cse -verify-diagnostics %s -o -| FileCheck %s - -gpu.module @test_kernel { - gpu.func @create_mask(%arg0: vector<32x32xf16>, %arg1: vector<32x32xf16>, %arg2: memref<32x32xf16>) { - %c32 = arith.constant 32 : index - %c20 = arith.constant 20 : index - %0 = vector.create_mask %c32, %c20, %c32, %c20 : vector<32x2x1x16xi1> - %1 = xetile.tile_pack %arg0 {inner_blocks = array}: vector<32x32xf16> -> vector<32x2x1x16xf16> - %2 = xetile.tile_pack %arg1 {inner_blocks = array}: vector<32x32xf16> -> vector<32x2x1x16xf16> - %3 = arith.select %0, %1, %2 : vector<32x2x1x16xi1>, vector<32x2x1x16xf16> - %4 = xetile.tile_unpack %3 {inner_blocks = array}: vector<32x2x1x16xf16> -> vector<32x32xf16> - %5 = xetile.init_tile %arg2[0, 0] : memref<32x32xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - %6 = xetile.tile_pack %4 {inner_blocks = array}: vector<32x32xf16> -> vector<4x1x8x32xf16> - xetile.store_tile %6, %5 : vector<4x1x8x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr> - gpu.return - } -} - -// convert-xetile-to-xegpu generates 64 create_mask ops. But since all the rows -// have the same masks, cse leaves only two masks corresponding to a row. 
- -//CHECK: gpu.func @create_mask -//CHECK-DAG: %[[C1:.*]] = arith.constant 1 -//CHECK-DAG: %[[C16:.*]] = arith.constant 16 -//CHECK-DAG: %[[MASK1_VAL:.*]] = arith.constant 20 -//CHECK: %[[MASK2_VAL:.*]] = arith.subi %[[MASK1_VAL]], %c16 -//CHECK: %[[MASK1:.*]] = vector.create_mask %[[C1]], %[[MASK1_VAL]] : vector<1x16xi1> -//CHECK: %[[MASK2:.*]] = vector.create_mask %[[C1]], %[[MASK2_VAL]] : vector<1x16xi1> -//CHECK-NOT: %[[MASK2:.*]] = vector.create_mask -//CHECK: arith.select %[[MASK1]], {{.*}} : vector<1x16xi1>, vector<1x16xf16> -//CHECK: arith.select %[[MASK2]], {{.*}} : vector<1x16xi1>, vector<1x16xf16> - -// ----- - -gpu.module @test_kernel_2 { - gpu.func @create_mask_2(%arg0: vector<32x32xf16>, %arg1: vector<32x32xf16>, %arg2: memref<32x32xf16>) { - %c32 = arith.constant 32 : index - %c20 = arith.constant 20 : index - %0 = vector.create_mask %c20, %c32, %c20, %c32 : vector<32x2x1x16xi1> - %1 = xetile.tile_pack %arg0 {inner_blocks = array} : vector<32x32xf16> -> vector<32x2x1x16xf16> - %2 = xetile.tile_pack %arg1 {inner_blocks = array} : vector<32x32xf16> -> vector<32x2x1x16xf16> - %3 = arith.select %0, %1, %2 : vector<32x2x1x16xi1>, vector<32x2x1x16xf16> - %4 = xetile.tile_unpack %3 {inner_blocks = array} : vector<32x2x1x16xf16> -> vector<32x32xf16> - %5 = xetile.init_tile %arg2[0, 0] : memref<32x32xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - %6 = xetile.tile_pack %4 {inner_blocks = array} : vector<32x32xf16> -> vector<4x1x8x32xf16> - xetile.store_tile %6, %5 : vector<4x1x8x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr> - gpu.return - } -} - -//CHECK: gpu.func @create_mask_2 -//CHECK: %[[UB:.*]] = arith.constant 20 -//CHECK: %[[C0:.*]] = arith.constant 0 -//CHECK: %[[CMP0:.*]] = arith.cmpi slt, %[[C0]], %[[UB]] : index -//CHECK: %[[SPLAT0:.*]] = vector.splat %[[CMP0]] : vector<1x16xi1> -//CHECK-COUNT-31: arith.cmpi -//CHECK: arith.select %[[SPLAT0]], {{.*}} : vector<1x16xi1>, vector<1x16xf16> -//CHECK: arith.select %[[SPLAT0]], {{.*}} : 
vector<1x16xi1>, vector<1x16xf16> -//CHECK-COUNT-62: arith.select diff --git a/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir b/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir deleted file mode 100644 index 429a64f33..000000000 --- a/test/Conversion/XeTileToXeGPU/elementwise_ops.mlir +++ /dev/null @@ -1,297 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu %s -verify-diagnostics -o -| FileCheck %s - gpu.module @test_kernel { - gpu.func @arith_binary_ops() { - %0 = arith.constant dense<0.9>: vector<4x4x16x16xf16> - %1 = arith.constant dense<2.3>: vector<64x4x1x16xf16> - %2 = xetile.tile_unpack %0 {inner_blocks = array}: vector<4x4x16x16xf16> -> vector<64x64xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<64x64xf16> -> vector<64x4x1x16xf16> - // CHECK-COUNT-256: arith.addf {{.*}}, {{.*}} fastmath : vector<1x16xf16> - // CHECK-COUNT-256: arith.sub - // CHECK-COUNT-256: arith.mulf - // CHECK-COUNT-256: arith.maximumf - // CHECK-COUNT-256: arith.minimumf - // CHECK-COUNT-256: arith.divf - // CHECK-COUNT-256: arith.remf - // CHECK-COUNT-256: arith.cmpf - %result = arith.addf %3, %1 fastmath : vector<64x4x1x16xf16> - %subf_result = arith.subf %result, %1 : vector<64x4x1x16xf16> - %mulf_result = arith.mulf %subf_result, %1 : vector<64x4x1x16xf16> - %maxf_result = arith.maximumf %mulf_result, %1 : vector<64x4x1x16xf16> - %minf_result = arith.minimumf %maxf_result, %mulf_result : vector<64x4x1x16xf16> - %divf_result = arith.divf %minf_result, %1 : vector<64x4x1x16xf16> - %remf_result = arith.remf %minf_result, %divf_result : vector<64x4x1x16xf16> - %cmpf_result = arith.cmpf ult, %remf_result, %divf_result : vector<64x4x1x16xf16> - gpu.return - } - - gpu.func @arith_binary_ops_int() { - %0 = arith.constant dense<1>: vector<4x4x16x16xi16> - %1 = arith.constant dense<2>: vector<64x4x1x16xi16> - %2 = xetile.tile_unpack %0 {inner_blocks = array}: vector<4x4x16x16xi16> -> vector<64x64xi16> - %3 = xetile.tile_pack %2 {inner_blocks = 
array}: vector<64x64xi16> -> vector<64x4x1x16xi16> - // CHECK-COUNT-256: arith.addi {{.*}}, {{.*}} : vector<1x16xi16> - // CHECK-COUNT-256: arith.subi - // CHECK-COUNT-256: arith.muli - // CHECK-COUNT-256: arith.maxsi - // CHECK-COUNT-256: arith.maxui - // CHECK-COUNT-256: arith.minsi - // CHECK-COUNT-256: arith.minui - // CHECK-COUNT-256: arith.divsi - // CHECK-COUNT-256: arith.divui - // CHECK-COUNT-256: arith.remsi - // CHECK-COUNT-256: arith.remui - // CHECK-COUNT-256: arith.andi - // CHECK-COUNT-256: arith.addui_extended - %result = arith.addi %3, %1 : vector<64x4x1x16xi16> - %subi_result = arith.subi %3, %1 : vector<64x4x1x16xi16> - %muli_result = arith.muli %subi_result, %1 : vector<64x4x1x16xi16> - %maxsi_result = arith.maxsi %muli_result, %1 : vector<64x4x1x16xi16> - %maxui_result = arith.maxui %muli_result, %1 : vector<64x4x1x16xi16> - %minsi_result = arith.minsi %maxsi_result, %muli_result : vector<64x4x1x16xi16> - %minui_result = arith.minui %maxui_result, %muli_result : vector<64x4x1x16xi16> - %divsi_result = arith.divsi %minui_result, %1 : vector<64x4x1x16xi16> - %divui_result = arith.divui %minui_result, %1 : vector<64x4x1x16xi16> - %remsi_result = arith.remsi %minsi_result, %divsi_result : vector<64x4x1x16xi16> - %remui_result = arith.remui %minui_result, %divui_result : vector<64x4x1x16xi16> - %and_result = arith.andi %remsi_result, %remui_result : vector<64x4x1x16xi16> - %addui_sum, %addui_overflow = arith.addui_extended %3, %1 : vector<64x4x1x16xi16>, vector<64x4x1x16xi1> - %addui_extented_result:2 = arith.addui_extended %3, %1 : vector<64x4x1x16xi16>, vector<64x4x1x16xi1> - gpu.return - } - - gpu.func @arith_xori_ops() { - %0 = arith.constant dense<1>: vector<4x4x16x16xi16> - %1 = arith.constant dense<2>: vector<64x4x1x16xi16> - %2 = xetile.tile_unpack %0 {inner_blocks = array}: vector<4x4x16x16xi16> -> vector<64x64xi16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<64x64xi16> -> vector<64x4x1x16xi16> - // CHECK-COUNT-256: arith.xori 
{{.*}}, {{.*}} : vector<1x16xi16> - %xori_result = arith.xori %3, %1 : vector<64x4x1x16xi16> - gpu.return - } - - gpu.func @arith_unary_ops() { - %0 = arith.constant dense<0.9>: vector<4x4x16x16xf16> - %1 = arith.constant dense<2.3>: vector<64x4x1x16xf16> - %2 = xetile.tile_unpack %0 {inner_blocks = array}: vector<4x4x16x16xf16> -> vector<64x64xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<64x64xf16> -> vector<64x4x1x16xf16> - // CHECK-COUNT-256: arith.addf - // CHECK-COUNT-256: arith.negf {{.*}} : vector<1x16xf16> - %result = arith.addf %3, %1 : vector<64x4x1x16xf16> - %negf_result = arith.negf %result : vector<64x4x1x16xf16> - gpu.return - } - - - gpu.func @math_binary_ops() { - %0 = arith.constant dense<0.9>: vector<4x4x16x16xf16> - %1 = arith.constant dense<2.3>: vector<64x4x1x16xf16> - %2 = xetile.tile_unpack %0 {inner_blocks = array}: vector<4x4x16x16xf16> -> vector<64x64xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<64x64xf16> -> vector<64x4x1x16xf16> - // CHECK-COUNT-256: math.powf {{.*}}, {{.*}} : vector<1x16xf16> - %result = math.powf %3, %1 : vector<64x4x1x16xf16> - gpu.return - } - - gpu.func @math_unary_ops() { - %0 = arith.constant dense<0.9>: vector<4x4x16x16xf16> - %1 = arith.constant dense<2.3>: vector<64x4x1x16xf16> - %2 = xetile.tile_unpack %0 {inner_blocks = array}: vector<4x4x16x16xf16> -> vector<64x64xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<64x64xf16> -> vector<64x4x1x16xf16> - // CHECK-COUNT-256: math.exp {{.*}} : vector<1x16xf16> - // CHECK-COUNT-256: math.sin - // CHECK-COUNT-256: math.cos - // CHECK-COUNT-256: math.tanh - // CHECK-COUNT-256: math.sqrt - // CHECK-COUNT-256: math.log - // CHECK-COUNT-256: math.rsqrt - // CHECK-COUNT-256: math.erf - %result = arith.addf %3, %1 : vector<64x4x1x16xf16> - %exp_result = math.exp %result : vector<64x4x1x16xf16> - %sin_result = math.sin %exp_result : vector<64x4x1x16xf16> - %cos_result = math.cos %sin_result : vector<64x4x1x16xf16> - 
%tan_result = math.tanh %cos_result : vector<64x4x1x16xf16> - %sqrt_result = math.sqrt %tan_result : vector<64x4x1x16xf16> - %log_result = math.log %sqrt_result : vector<64x4x1x16xf16> - %rsqrt_result = math.rsqrt %log_result : vector<64x4x1x16xf16> - %erf_result = math.erf %rsqrt_result : vector<64x4x1x16xf16> - gpu.return - } - - gpu.func @sglevel_type_cast(%arg0: memref<1024x1024xf16>) { - %0 = xetile.init_tile %arg0[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - %1 = xetile.load_tile %0 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<1x1x32x32xf16> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x1x32x32xf16> -> vector<32x32xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x32xf16> -> vector<32x2x1x16xf16> - - //CHECK-COUNT-64: arith.extf {{.*}} : vector<1x16xf16> to vector<1x16xf32> - %4 = arith.extf %3 : vector<32x2x1x16xf16> to vector<32x2x1x16xf32> - - //CHECK-COUNT-64: math.exp {{.*}} : vector<1x16xf32> - %5 = math.exp %4 : vector<32x2x1x16xf32> - - //CHECK-COUNT-64: arith.truncf {{.*}} : vector<1x16xf32> to vector<1x16xf16> - %6 = arith.truncf %5 : vector<32x2x1x16xf32> to vector<32x2x1x16xf16> - - %7 = xetile.tile_unpack %6 {inner_blocks = array}: vector<32x2x1x16xf16> -> vector<32x32xf16> - %8 = xetile.init_tile %arg0[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - %9 = xetile.tile_pack %7 {inner_blocks = array}: vector<32x32xf16> -> vector<4x1x8x32xf16> - xetile.store_tile %9, %8 : vector<4x1x8x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr> - gpu.return - } - - gpu.func @sglevel_extsi_test(%arg0: memref<32x32xi16>, %arg1: memref<32x32xi32>) { - %0 = xetile.init_tile %arg0[0, 0] : memref<32x32xi16> -> !xetile.tile<32x32xi16, #xetile.tile_attr> - %1 = xetile.load_tile %0 { padding = 0 : i32 } : !xetile.tile<32x32xi16, #xetile.tile_attr> -> vector<1x1x32x32xi16> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : 
vector<1x1x32x32xi16> -> vector<32x32xi16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x32xi16> -> vector<32x2x1x16xi16> - - //CHECK-COUNT-64: arith.extsi {{.*}} : vector<1x16xi16> to vector<1x16xi32> - %4 = arith.extsi %3 : vector<32x2x1x16xi16> to vector<32x2x1x16xi32> - - %5 = xetile.tile_unpack %4 {inner_blocks = array}: vector<32x2x1x16xi32> -> vector<32x32xi32> - %6 = xetile.init_tile %arg1[0, 0] : memref<32x32xi32> -> !xetile.tile<32x32xi32, #xetile.tile_attr> - %7 = xetile.tile_pack %5 {inner_blocks = array} : vector<32x32xi32> -> vector<4x2x8x16xi32> - xetile.store_tile %7, %6 : vector<4x2x8x16xi32>, !xetile.tile<32x32xi32, #xetile.tile_attr> - gpu.return - } - - gpu.func @sglevel_extui_test(%arg0: memref<32x32xi16>, %arg1: memref<32x32xi32>) { - %0 = xetile.init_tile %arg0[0, 0] : memref<32x32xi16> -> !xetile.tile<32x32xi16, #xetile.tile_attr> - %1 = xetile.load_tile %0 { padding = 0 : i32 } : !xetile.tile<32x32xi16, #xetile.tile_attr> -> vector<1x1x32x32xi16> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x1x32x32xi16> -> vector<32x32xi16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x32xi16> -> vector<32x2x1x16xi16> - - //CHECK-COUNT-64: arith.extui {{.*}} : vector<1x16xi16> to vector<1x16xi32> - %4 = arith.extui %3 : vector<32x2x1x16xi16> to vector<32x2x1x16xi32> - - %5 = xetile.tile_unpack %4 {inner_blocks = array}: vector<32x2x1x16xi32> -> vector<32x32xi32> - %6 = xetile.init_tile %arg1[0, 0] : memref<32x32xi32> -> !xetile.tile<32x32xi32, #xetile.tile_attr> - %7 = xetile.tile_pack %5 {inner_blocks = array} : vector<32x32xi32> -> vector<4x2x8x16xi32> - xetile.store_tile %7, %6 : vector<4x2x8x16xi32>, !xetile.tile<32x32xi32, #xetile.tile_attr> - gpu.return - } - - gpu.func @sglevel_fptosi_test(%arg0: memref<32x32xf16>, %arg1: memref<32x32xi32>) { - %0 = xetile.init_tile %arg0[0, 0] : memref<32x32xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - %1 = xetile.load_tile %0 { padding = 0.0 : f32 } : 
!xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<1x1x32x32xf16> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x1x32x32xf16> -> vector<32x32xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x32xf16> -> vector<32x2x1x16xf16> - - //CHECK-COUNT-64: arith.fptosi {{.*}} : vector<1x16xf16> to vector<1x16xi32> - %4 = arith.fptosi %3 : vector<32x2x1x16xf16> to vector<32x2x1x16xi32> - - %5 = xetile.tile_unpack %4 {inner_blocks = array}: vector<32x2x1x16xi32> -> vector<32x32xi32> - %6 = xetile.init_tile %arg1[0, 0] : memref<32x32xi32> -> !xetile.tile<32x32xi32, #xetile.tile_attr> - %7 = xetile.tile_pack %5 {inner_blocks = array} : vector<32x32xi32> -> vector<4x2x8x16xi32> - xetile.store_tile %7, %6 : vector<4x2x8x16xi32>, !xetile.tile<32x32xi32, #xetile.tile_attr> - gpu.return - } - - gpu.func @sglevel_fptoui_test(%arg0: memref<32x32xf16>, %arg1: memref<32x32xi32>) { - %0 = xetile.init_tile %arg0[0, 0] : memref<32x32xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - %1 = xetile.load_tile %0 { padding = 0.0 : f32 } : !xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<1x1x32x32xf16> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x1x32x32xf16> -> vector<32x32xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x32xf16> -> vector<32x2x1x16xf16> - - //CHECK-COUNT-64: arith.fptoui {{.*}} : vector<1x16xf16> to vector<1x16xi32> - %4 = arith.fptoui %3 : vector<32x2x1x16xf16> to vector<32x2x1x16xi32> - - %5 = xetile.tile_unpack %4 {inner_blocks = array}: vector<32x2x1x16xi32> -> vector<32x32xi32> - %6 = xetile.init_tile %arg1[0, 0] : memref<32x32xi32> -> !xetile.tile<32x32xi32, #xetile.tile_attr> - %7 = xetile.tile_pack %5 {inner_blocks = array} : vector<32x32xi32> -> vector<4x2x8x16xi32> - xetile.store_tile %7, %6 : vector<4x2x8x16xi32>, !xetile.tile<32x32xi32, #xetile.tile_attr> - gpu.return - } - - gpu.func @sglevel_sitofp_test(%arg0: memref<32x32xi32>, %arg1: memref<32x32xf32>) { - %0 = xetile.init_tile %arg0[0, 
0] : memref<32x32xi32> -> !xetile.tile<32x32xi32, #xetile.tile_attr> - %1 = xetile.load_tile %0 { padding = 0 : i32 } : !xetile.tile<32x32xi32, #xetile.tile_attr> -> vector<1x2x32x16xi32> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x2x32x16xi32> -> vector<32x32xi32> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x32xi32> -> vector<32x2x1x16xi32> - //CHECK-COUNT-64: arith.sitofp {{.*}} : vector<1x16xi32> to vector<1x16xf32> - %4 = arith.sitofp %3 : vector<32x2x1x16xi32> to vector<32x2x1x16xf32> - %5 = xetile.tile_unpack %4 {inner_blocks = array}: vector<32x2x1x16xf32> -> vector<32x32xf32> - %6 = xetile.init_tile %arg1[0, 0] : memref<32x32xf32> -> !xetile.tile<32x32xf32, #xetile.tile_attr> - %7 = xetile.tile_pack %5 {inner_blocks = array} : vector<32x32xf32> -> vector<4x2x8x16xf32> - xetile.store_tile %7, %6 : vector<4x2x8x16xf32>, !xetile.tile<32x32xf32, #xetile.tile_attr> - gpu.return - } - - gpu.func @sglevel_uitofp_test(%arg0: memref<32x32xi32>, %arg1: memref<32x32xf32>) { - %0 = xetile.init_tile %arg0[0, 0] : memref<32x32xi32> -> !xetile.tile<32x32xi32, #xetile.tile_attr> - %1 = xetile.load_tile %0 { padding = 0 : i32 } : !xetile.tile<32x32xi32, #xetile.tile_attr> -> vector<1x2x32x16xi32> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x2x32x16xi32> -> vector<32x32xi32> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x32xi32> -> vector<32x2x1x16xi32> - //CHECK-COUNT-64: arith.uitofp {{.*}} : vector<1x16xi32> to vector<1x16xf32> - %4 = arith.uitofp %3 : vector<32x2x1x16xi32> to vector<32x2x1x16xf32> - %5 = xetile.tile_unpack %4 {inner_blocks = array}: vector<32x2x1x16xf32> -> vector<32x32xf32> - %6 = xetile.init_tile %arg1[0, 0] : memref<32x32xf32> -> !xetile.tile<32x32xf32, #xetile.tile_attr> - %7 = xetile.tile_pack %5 {inner_blocks = array} : vector<32x32xf32> -> vector<4x2x8x16xf32> - xetile.store_tile %7, %6 : vector<4x2x8x16xf32>, !xetile.tile<32x32xf32, #xetile.tile_attr> - gpu.return - } - - gpu.func 
@sglevel_truncf_test(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf16>) { - %0 = xetile.init_tile %arg0[0, 0] : memref<32x32xf32> -> !xetile.tile<32x32xf32, #xetile.tile_attr> - %1 = xetile.load_tile %0 { padding = 0.0 : f32 } : !xetile.tile<32x32xf32, #xetile.tile_attr> -> vector<1x2x32x16xf32> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x2x32x16xf32> -> vector<32x32xf32> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x32xf32> -> vector<32x2x1x16xf32> - //CHECK-COUNT-64: arith.truncf {{.*}} : vector<1x16xf32> to vector<1x16xf16> - %4 = arith.truncf %3 : vector<32x2x1x16xf32> to vector<32x2x1x16xf16> - %5 = xetile.tile_unpack %4 {inner_blocks = array}: vector<32x2x1x16xf16> -> vector<32x32xf16> - %6 = xetile.init_tile %arg1[0, 0] : memref<32x32xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - %7 = xetile.tile_pack %5 {inner_blocks = array} : vector<32x32xf16> -> vector<4x2x8x16xf16> - xetile.store_tile %7, %6 : vector<4x2x8x16xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr> - gpu.return - } - - gpu.func @sglevel_trunci_test(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi16>) { - %0 = xetile.init_tile %arg0[0, 0] : memref<32x32xi32> -> !xetile.tile<32x32xi32, #xetile.tile_attr> - %1 = xetile.load_tile %0 { padding = 0 : i32 } : !xetile.tile<32x32xi32, #xetile.tile_attr> -> vector<1x2x32x16xi32> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x2x32x16xi32> -> vector<32x32xi32> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x32xi32> -> vector<32x2x1x16xi32> - //CHECK-COUNT-64: arith.trunci {{.*}} : vector<1x16xi32> to vector<1x16xi16> - %4 = arith.trunci %3 : vector<32x2x1x16xi32> to vector<32x2x1x16xi16> - %5 = xetile.tile_unpack %4 {inner_blocks = array}: vector<32x2x1x16xi16> -> vector<32x32xi16> - %6 = xetile.init_tile %arg1[0, 0] : memref<32x32xi16> -> !xetile.tile<32x32xi16, #xetile.tile_attr> - %7 = xetile.tile_pack %5 {inner_blocks = array} : vector<32x32xi16> -> vector<4x2x8x16xi16> - 
xetile.store_tile %7, %6 : vector<4x2x8x16xi16>, !xetile.tile<32x32xi16, #xetile.tile_attr> - gpu.return - } - - - gpu.func @sglevel_and_test(%arg0: memref<1x4096xi8>, %arg1: memref<1x4096xi8>, %arg2: memref<1x4096xi8>) { - %c0 = arith.constant 0 : index - %c4096 = arith.constant 4096 : index - %c32 = arith.constant 32 : index - %c1024_i32 = arith.constant 1024 : i32 - %thread_id_x = gpu.thread_id x - %thread_id_y = gpu.thread_id y - %block_dim_y = gpu.block_dim y - %0 = arith.muli %thread_id_x, %block_dim_y : index - %1 = arith.addi %0, %thread_id_y : index - %block_id_x = gpu.block_id x - %2 = arith.index_cast %block_id_x : index to i32 - %3 = arith.muli %2, %c1024_i32 : i32 - %4 = arith.index_cast %3 : i32 to index - %5 = arith.remsi %1, %c32 : index - %6 = arith.muli %5, %c32 : index - %7 = arith.remsi %6, %c4096 : index - %8 = arith.addi %7, %4 : index - %9 = xetile.init_tile %arg0[%c0, %8] : memref<1x4096xi8> -> !xetile.tile<1x32xi8, #xetile.tile_attr> - %10 = xetile.load_tile %9 {padding = 0 : i32} : !xetile.tile<1x32xi8, #xetile.tile_attr> -> vector<1x1x1x32xi8> - %11 = xetile.tile_unpack %10 {inner_blocks = array} : vector<1x1x1x32xi8> -> vector<1x32xi8> - %12 = xetile.init_tile %arg1[%c0, %8] : memref<1x4096xi8> -> !xetile.tile<1x32xi8, #xetile.tile_attr> - %13 = xetile.load_tile %12 {padding = 0 : i32} : !xetile.tile<1x32xi8, #xetile.tile_attr> -> vector<1x1x1x32xi8> - %14 = xetile.tile_unpack %13 {inner_blocks = array} : vector<1x1x1x32xi8> -> vector<1x32xi8> - %15 = xetile.tile_pack %11 {inner_blocks = array} : vector<1x32xi8> -> vector<1x1x1x32xi8> - %16 = xetile.tile_pack %14 {inner_blocks = array} : vector<1x32xi8> -> vector<1x1x1x32xi8> - //CHECK: %{{.*}} = arith.andi %{{.*}}, %{{.*}} : vector<1x32xi8> - %17 = arith.andi %15, %16 : vector<1x1x1x32xi8> - %18 = xetile.tile_unpack %17 {inner_blocks = array} : vector<1x1x1x32xi8> -> vector<1x32xi8> - %19 = xetile.init_tile %arg2[%c0, %8] : memref<1x4096xi8> -> !xetile.tile<1x32xi8, #xetile.tile_attr> - 
%20 = xetile.tile_pack %18 {inner_blocks = array} : vector<1x32xi8> -> vector<1x1x1x32xi8> - xetile.store_tile %20, %19 : vector<1x1x1x32xi8>, !xetile.tile<1x32xi8, #xetile.tile_attr> - gpu.return - } - } diff --git a/test/Conversion/XeTileToXeGPU/gemm_preop.mlir b/test/Conversion/XeTileToXeGPU/gemm_preop.mlir index 19d471218..d19a4d607 100755 --- a/test/Conversion/XeTileToXeGPU/gemm_preop.mlir +++ b/test/Conversion/XeTileToXeGPU/gemm_preop.mlir @@ -1,4 +1,4 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking --cse \ +// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" --cse \ // RUN: --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s #map = affine_map<() -> (0)> #map1 = affine_map<() -> (64)> diff --git a/test/Conversion/XeTileToXeGPU/lit.local.cfg b/test/Conversion/XeTileToXeGPU/lit.local.cfg index 097b2470c..fb5f46296 100644 --- a/test/Conversion/XeTileToXeGPU/lit.local.cfg +++ b/test/Conversion/XeTileToXeGPU/lit.local.cfg @@ -5,4 +5,14 @@ excludes_slm_tests = [ 'sg_gemm_1k_1k_1k_f16_f32_slm.mlir', ] +excludes_array_length_tests = [ + 'sg_tile_mma.mlir', + 'array_length_load.mlir', + 'sg_gemm_1k_1k_1k_f16_f32.mlir', + 'sg_gemm_1k_1k_1k_i8_i32.mlir', + 'sg_gemm_1k_1k_1k_tf32_tf32.mlir', + 'sg_gemm_transpose_b.mlir', + ] + config.excludes.update(excludes_slm_tests) +config.excludes.update(excludes_array_length_tests) diff --git a/test/Conversion/XeTileToXeGPU/non_pow2_stacking.mlir b/test/Conversion/XeTileToXeGPU/non_pow2_stacking.mlir deleted file mode 100644 index b2560df6e..000000000 --- a/test/Conversion/XeTileToXeGPU/non_pow2_stacking.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s - -module @test_module attributes {gpu.container_module} { - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, 
api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<24x32xf16>, %B: memref<24x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %a_tile = xetile.init_tile %A[%c0, %c0] : memref<24x32xf16> -> !xetile.tile<24x32xf16, #xetile.tile_attr> - %b_tile = xetile.init_tile %B[%c0, %c0] : memref<24x32xf16> -> !xetile.tile<24x32xf16, #xetile.tile_attr> - %a_value = xetile.load_tile %a_tile {padding = 0.000000e+00 : f32} : !xetile.tile<24x32xf16, #xetile.tile_attr> -> vector<3x1x8x32xf16> - %b_value = xetile.load_tile %b_tile {padding = 0.000000e+00 : f32} : !xetile.tile<24x32xf16, #xetile.tile_attr> -> vector<4x2x6x16xf16> - - %a_valuee = xetile.tile_unpack %a_value {inner_blocks = array}: vector<3x1x8x32xf16> -> vector<24x32xf16> - %b_valuee = xetile.tile_unpack %b_value {inner_blocks = array} : vector<4x2x6x16xf16> -> vector<24x32xf16> - - %c_value = arith.addf %a_valuee, %b_valuee : vector<24x32xf16> - //CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice %{{.*}} {offsets = {{.*}}, sizes = [6, 32], strides = [1, 1]} : vector<24x32xf16> to vector<6x32xf16> - //CHECK-COUNT-8: {{.*}} = vector.extract_strided_slice %{{.*}} {offsets = {{.*}}, sizes = [6, 16], strides = [1, 1]} : vector<6x32xf16> to vector<6x16xf16> - %c_valuee = xetile.tile_pack %c_value {inner_blocks = array} : vector<24x32xf16> -> vector<4x2x6x16xf16> - xetile.store_tile %c_valuee, %b_tile : vector<4x2x6x16xf16>, !xetile.tile<24x32xf16, #xetile.tile_attr> - - gpu.return - } - } -} diff --git a/test/Conversion/XeTileToXeGPU/prefetch.mlir b/test/Conversion/XeTileToXeGPU/prefetch.mlir deleted file mode 100644 index 0b1ccb01b..000000000 --- a/test/Conversion/XeTileToXeGPU/prefetch.mlir +++ /dev/null @@ -1,29 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu %s -verify-diagnostics -o -| FileCheck %s - -// CHECK-LABEL: test_prefetch -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = 
#xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: xegpu.prefetch_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> -// CHECK: 
gpu.return -gpu.module @test_kernel { -gpu.func @test_prefetch(%a: memref<2x64xf16>) { - %c0 = arith.constant 0 : index - %0 = xetile.init_tile %a[%c0, %c0] : memref<2x64xf16> -> !xetile.tile<2x64xf16, #xetile.tile_attr> - xetile.prefetch_tile %0 : !xetile.tile<2x64xf16, #xetile.tile_attr> - xetile.prefetch_tile %0 {l1_hint = #xetile.cache_hint, l3_hint = #xetile.cache_hint} : !xetile.tile<2x64xf16, #xetile.tile_attr> - gpu.return -} -} diff --git a/test/Conversion/XeTileToXeGPU/reduction.mlir b/test/Conversion/XeTileToXeGPU/reduction.mlir deleted file mode 100644 index 9a1eeebea..000000000 --- a/test/Conversion/XeTileToXeGPU/reduction.mlir +++ /dev/null @@ -1,273 +0,0 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-canonicalization --xetile-blocking \ -// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s -module { - gpu.module @test_kernel { - - //CHECK: gpu.func @inner_reduction(%[[arg0:.*]]: memref<128x256xf16>, %[[arg1:.*]]: memref<128x256xf16>) { - gpu.func @inner_reduction(%a: memref<128x256xf16>, %b: memref<128x256xf16>) { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - %c0 = arith.constant 0 : index - //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c0]]] : memref<128x256xf16> -> !xegpu.tensor_desc<16x32xf16, #xegpu.block_tdesc_attr> - %t = xetile.init_tile %a[%c0, %c0] : memref<128x256xf16> -> !xetile.tile<16x32xf16> - //CHECK: %[[r1:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x32xf16, #xegpu.block_tdesc_attr> -> vector<16x32xf16> - //CHECK-COUNT-16: {{.*}} = vector.extract_strided_slice %[[r1]] {offsets = {{.*}}, sizes = [1, 32], strides = [1, 1]} : vector<16x32xf16> to vector<1x32xf16> - %v = xetile.load_tile %t : !xetile.tile<16x32xf16> -> vector<16x32xf16> - //CHECK-COUNT-16: {{.*}} = math.exp %{{.*}} : vector<1x32xf16> - %e = math.exp %v: vector<16x32xf16> - //CHECK-COUNT-16: 
{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : 
vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} 
[0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 32, 33, 
34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29, 32, 33, 36, 37, 40, 41, 44, 45, 48, 49, 52, 53, 56, 57, 60, 61] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31, 34, 35, 38, 39, 42, 43, 46, 47, 50, 51, 54, 55, 58, 59, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31] : vector<32xf16>, vector<32xf16> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<16xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : 
vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : 
i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16> - //CHECK-COUNT-8: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK-COUNT-4: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK-COUNT-2: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x1xf16>, vector<8x1xf16> - %r = xetile.reduction , %e[1] : vector<16x32xf16> -> vector<16x1xf16> - //CHECK: %[[r177:.*]] = vector.shape_cast {{.*}} : vector<16x1xf16> to vector<2x8xf16> - %c = vector.shape_cast %r: vector<16x1xf16> to vector<2x8xf16> - //CHECK: %[[r178:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c0]], %[[c0]]] : memref<128x256xf16> -> !xegpu.tensor_desc<2x8xf16, #xegpu.block_tdesc_attr> - %s = xetile.init_tile %b[%c0, %c0] : memref<128x256xf16> -> !xetile.tile<2x8xf16> - //CHECK: xegpu.store_nd %[[r177]], %[[r178]] <{{.*}}> : vector<2x8xf16>, !xegpu.tensor_desc<2x8xf16, #xegpu.block_tdesc_attr> - xetile.store_tile %c, %s : vector<2x8xf16>, !xetile.tile<2x8xf16> - gpu.return - } - - gpu.func @inner_reduction_1(%a: memref<8x32xf32>, %b: memref<8x1xf32>) { - %c0 = arith.constant 0 : index - //CHECK: %[[c0:.*]] = 
arith.constant 0 : index - //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c0]]] : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK: %[[c16:.*]] = arith.constant 16 : index - //CHECK: %[[r1:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c16]]] : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %a_tile = xetile.init_tile %a[%c0, %c0] : memref<8x32xf32> -> !xetile.tile<8x32xf32> - //CHECK: %[[r2:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c0]], %[[c0]]] : memref<8x1xf32> -> !xegpu.tensor_desc<8x1xf32, #xegpu.block_tdesc_attr> - %b_tile = xetile.init_tile %b[%c0, %c0] : memref<8x1xf32> -> !xetile.tile<8x1xf32> - //CHECK: %[[r3:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> -> vector<8x16xf32> - //CHECK: %[[r4:.*]] = xegpu.load_nd %[[r1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> -> vector<8x16xf32> - //CHECK-COUNT-8: {{.*}} = vector.extract_strided_slice %[[r3]] {offsets = {{.*}}, sizes = [1, 16], strides = [1, 1]} : vector<8x16xf32> to vector<1x16xf32> - //CHECK-COUNT-8: {{.*}} = vector.extract_strided_slice %[[r4]] {offsets = {{.*}}, sizes = [1, 16], strides = [1, 1]} : vector<8x16xf32> to vector<1x16xf32> - %a_loaded = xetile.load_tile %a_tile: !xetile.tile<8x32xf32> -> vector<8x32xf32> - - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<1x16xf32> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf32> to vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<1x16xf32> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf32> to vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<1x16xf32> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf32> to vector<16xf32> - //CHECK: {{.*}} 
= arith.maximumf %{{.*}}, %{{.*}} : vector<1x16xf32> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf32> to vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<1x16xf32> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf32> to vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<1x16xf32> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf32> to vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<1x16xf32> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf32> to vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = arith.maximumf 
%{{.*}}, %{{.*}} : vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = arith.maximumf %{{.*}}, %{{.*}} : vector<8xf32> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32> - //CHECK: {{.*}} = 
arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32> - //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32> - //CHECK-COUNT-4: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x1xf32>, vector<1x1xf32> - //CHECK-COUNT-2: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x1xf32>, vector<2x1xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf32>, vector<4x1xf32> - %3 = xetile.reduction , %a_loaded[1] : vector<8x32xf32> -> vector<8x1xf32> // fastmath is implicit here - //CHECK: xegpu.store_nd {{.*}} : vector<8x1xf32>, !xegpu.tensor_desc<8x1xf32, #xegpu.block_tdesc_attr> - xetile.store_tile %3, %b_tile : vector<8x1xf32>, !xetile.tile<8x1xf32> - gpu.return - } - - gpu.func @inner_reduction_small_size_1(%arg0: memref<*xf32>, %arg1: memref<*xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %cst = arith.constant dense<0.000000e+00> : vector<1xf32> - %cst_0 = arith.constant dense : vector<1x16xi1> - %cst_1 = arith.constant dense : vector<1x1xi1> - %cst_2 = arith.constant dense<0> : vector<1x1xindex> - %cst_3 = 
arith.constant dense<0> : vector<1x16xindex> - %cast = memref.cast %arg0 : memref<*xf32> to memref - %cast_4 = memref.cast %arg1 : memref<*xf32> to memref - %0 = xetile.init_tile %cast, %cst_3 : memref, vector<1x16xindex> -> !xetile.tile<1x16xf32, #xetile.tile_attr> - %1 = xetile.load %0, %cst_0 : !xetile.tile<1x16xf32, #xetile.tile_attr>, vector<1x16xi1> -> vector<1x16xf32> - //CHECK: {{.*}} = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32> - //CHECK: {{.*}} = vector.shape_cast %{{.*}} : vector<1x16xf32> to vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf32>, vector<16xf32> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<8xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<8xf32>, vector<8xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7] : vector<8xf32>, vector<8xf32> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<4xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<4xf32>, vector<4xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [2, 3] : vector<4xf32>, vector<4xf32> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<2xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0] : vector<2xf32>, vector<2xf32> - //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [1] : vector<2xf32>, vector<2xf32> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1xf32> - //CHECK: {{.*}} = arith.constant {{.*}} : i32 - //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<1xf32> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf32> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x1xf32> to vector<1xf32> - //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1xf32> - - %2 = vector.multi_reduction , %1, %cst [1] : vector<1x16xf32> to vector<1xf32> - %3 = 
vector.shape_cast %2 : vector<1xf32> to vector<1x1xf32> - %4 = xetile.init_tile %cast_4, %cst_2 : memref, vector<1x1xindex> -> !xetile.tile<1x1xf32, #xetile.tile_attr> - xetile.store %3, %4, %cst_1 : vector<1x1xf32>, !xetile.tile<1x1xf32, #xetile.tile_attr>, vector<1x1xi1> - gpu.return - } - - //CHECK: gpu.func @outter_reduction(%[[arg0:.*]]: memref<128x256xf16>, %[[arg1:.*]]: memref<128x256xf16>) { - gpu.func @outter_reduction(%a: memref<128x256xf16>, %b: memref<128x256xf16>) { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - %c0 = arith.constant 0 : index - //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c0]]] : memref<128x256xf16> -> !xegpu.tensor_desc<16x32xf16, #xegpu.block_tdesc_attr> - %t = xetile.init_tile %a[%c0, %c0] : memref<128x256xf16> -> !xetile.tile<16x32xf16> - //CHECK: %[[r1:.*]] = xegpu.load_nd %[[r0]] <{{.*}}> : !xegpu.tensor_desc<16x32xf16, #xegpu.block_tdesc_attr> -> vector<16x32xf16> - //CHECK-COUNT-16: {{.*}} = vector.extract_strided_slice %[[r1]] {offsets = {{.*}}, sizes = [1, 32], strides = [1, 1]} : vector<16x32xf16> to vector<1x32xf16> - %v = xetile.load_tile %t : !xetile.tile<16x32xf16> -> vector<16x32xf16> - //CHECK-COUNT-16: {{.*}} = math.exp {{.*}} : vector<1x32xf16> - %e = math.exp %v: vector<16x32xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x32xf16> - %r = xetile.reduction , %e[0] : vector<16x32xf16> -> vector<1x32xf16> - //CHECK: %[[r49:.*]] = vector.shape_cast {{.*}} : vector<1x32xf16> to vector<4x8xf16> - %c = vector.shape_cast %r: vector<1x32xf16> to vector<4x8xf16> - //CHECK: %[[r50:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c0]], %[[c0]]] : memref<128x256xf16> -> !xegpu.tensor_desc<4x8xf16, #xegpu.block_tdesc_attr> - %s = xetile.init_tile %b[%c0, %c0] : memref<128x256xf16> -> !xetile.tile<4x8xf16> - //CHECK: xegpu.store_nd %[[r49]], %[[r50]] <{{.*}}> : vector<4x8xf16>, !xegpu.tensor_desc<4x8xf16, #xegpu.block_tdesc_attr> - xetile.store_tile %c, %s : vector<4x8xf16>, !xetile.tile<4x8xf16> - 
gpu.return - } - } -} diff --git a/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_f16_f32.mlir b/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_f16_f32.mlir index 606a61f73..dc30fc3d8 100644 --- a/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_f16_f32.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_f16_f32.mlir @@ -1,4 +1,4 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking --canonicalize \ +// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" --canonicalize \ // RUN: --cse --convert-xetile-to-xegpu --cse %s -o -| FileCheck %s // CHECK-LABEL: gpu.module @test_kernel { gpu.module @test_kernel { @@ -12,48 +12,48 @@ gpu.module @test_kernel { %m = arith.muli %block_id_x, %c64 : index %n = arith.muli %block_id_y, %c64 : index - //CHECK: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, 
#xegpu.block_tdesc_attr> - //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-8: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-8: xegpu.create_nd_tdesc %[[C]][{{.*}}] : memref<1024x1024xf32> -> 
!xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<64x64xf32> - //CHECK-COUNT-8: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> -> vector<32x16xf32> + //CHECK-COUNT-8: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> -> vector<32x16xf32> //CHECK-COUNT-32: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = {{.*}}, sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<64x64xf32> -> vector<64x64xf32> - //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[A]][{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[A]][{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<64x64xf16> - //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[B]][{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + //CHECK-COUNT-4: xegpu.create_nd_tdesc %[[B]][{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<64x64xf16> %out:3 = scf.for %k = %c0 to %c1024 step %c64 iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) -> (!xetile.tile<64x64xf16>, !xetile.tile<64x64xf16>, vector<64x64xf32>) { - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x16xf16> from vector<2x32x16xf16> - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : 
!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x16xf16> from vector<2x32x16xf16> - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x16xf16> from vector<2x32x16xf16> - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x16xf16> from vector<2x32x16xf16> //CHECK-COUNT-32: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = {{.*}}, sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> %a_value = xetile.load_tile %a_tile : !xetile.tile<64x64xf16> -> vector<64x64xf16> - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x16xf16> from vector<2x32x16xf16> - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x16xf16> from vector<2x32x16xf16> - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + 
//CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x16xf16> from vector<2x32x16xf16> - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x16xf16> from vector<2x32x16xf16> //CHECK-COUNT-16: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = {{.*}}, sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> %b_value = xetile.load_tile %b_tile : !xetile.tile<64x64xf16> -> vector<64x64xf16> @@ -61,13 +61,13 @@ gpu.module @test_kernel { //CHECK-COUNT-128: {{.*}} = xegpu.dpas {{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value : vector<64x64xf16>, vector<64x64xf16>, vector<64x64xf32> -> vector<64x64xf32> - //CHECK-COUNT-8: {{.*}} = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + //CHECK-COUNT-8: {{.*}} = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c64] : !xetile.tile<64x64xf16> %b_next_tile = xetile.update_tile_offset %b_tile, [%c64, %c0] : !xetile.tile<64x64xf16> scf.yield %a_next_tile, %b_next_tile, %c_new_value : !xetile.tile<64x64xf16>, !xetile.tile<64x64xf16>, vector<64x64xf32> } - //CHECK-COUNT-32: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-32: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> xetile.store_tile %out#2, %c_init_tile: vector<64x64xf32>, !xetile.tile<64x64xf32> gpu.return diff --git 
a/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_f16_f32_slm.mlir b/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_f16_f32_slm.mlir index ad84fbc62..de526420c 100644 --- a/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_f16_f32_slm.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_f16_f32_slm.mlir @@ -1,4 +1,4 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking --canonicalize \ +// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" --canonicalize \ // RUN: --cse --convert-xetile-to-xegpu --cse %s -o -| FileCheck %s diff --git a/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_i8_i32.mlir b/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_i8_i32.mlir index 2df196e78..f894e8807 100644 --- a/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_i8_i32.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_i8_i32.mlir @@ -15,46 +15,46 @@ gpu.module @test_kernel { %m = arith.muli %block_id_x, %c32 : index %n = arith.muli %block_id_y, %c32 : index - //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> - //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<32x16xi32, #xegpu.block_tdesc_attr> + //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> 
-> !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> + //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xi32> -> !xegpu.tensor_desc<32x16xi32, #xegpu.block_tdesc_attr> %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<32x32xi32> - //CHECK-COUNT-2: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xi32, #xegpu.block_tdesc_attr> -> vector<32x16xi32> + //CHECK-COUNT-2: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xi32, #xegpu.block_tdesc_attr> -> vector<32x16xi32> //CHECK-COUNT-8: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = {{.*}}, sizes = [8, 16], strides = [1, 1]} : vector<32x16xi32> to vector<8x16xi32> %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<32x32xi32> -> vector<32x32xi32> - //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xi8> -> !xegpu.tensor_desc<32x32xi8, #xegpu.block_tdesc_attr> + //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xi8> -> !xegpu.tensor_desc<32x32xi8, #xegpu.block_tdesc_attr> %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xi8> -> !xetile.tile<32x32xi8> - //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg1]][{{.*}}] : memref<1024x1024xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.block_tdesc_attr> + //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg1]][{{.*}}] : memref<1024x1024xi8> -> 
!xegpu.tensor_desc<32x16xi8, #xegpu.block_tdesc_attr> %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xi8> -> !xetile.tile<32x32xi8> //CHECK: {{.*}} = scf.for {{.*}} %out:3 = scf.for %k = %c0 to %c1024 step %c32 iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) -> (!xetile.tile<32x32xi8>, !xetile.tile<32x32xi8>, vector<32x32xi32>) { - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x32xi8, #xegpu.block_tdesc_attr> -> vector<32x32xi8> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x32xi8, #xegpu.block_tdesc_attr> -> vector<32x32xi8> //CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = {{.*}}, sizes = [8, 32], strides = [1, 1]} : vector<32x32xi8> to vector<8x32xi8> %a_value = xetile.load_tile %a_tile : !xetile.tile<32x32xi8> -> vector<32x32xi8> - //CHECK: {{.*}} = xegpu.load_nd {{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xi8, #xegpu.block_tdesc_attr> -> vector<2x32x16xi8> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xi8, #xegpu.block_tdesc_attr> -> vector<2x32x16xi8> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x16xi8> from vector<2x32x16xi8> %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xi8> -> vector<32x32xi8> //CHECK-COUNT-8: {{.*}} = xegpu.dpas {{.*}} : vector<8x32xi8>, vector<32x16xi8>, vector<8x16xi32> -> vector<8x16xi32> %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value : vector<32x32xi8>, vector<32x32xi8>, vector<32x32xi32> -> vector<32x32xi32> - //CHECK: {{.*}} = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<32x32xi8, #xegpu.block_tdesc_attr> + //CHECK: {{.*}} = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<32x32xi8, #xegpu.block_tdesc_attr> %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] : !xetile.tile<32x32xi8> - 
//CHECK: {{.*}} = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<32x16xi8, #xegpu.block_tdesc_attr> + //CHECK: {{.*}} = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<32x16xi8, #xegpu.block_tdesc_attr> %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] : !xetile.tile<32x32xi8> scf.yield %a_next_tile, %b_next_tile, %c_new_value : !xetile.tile<32x32xi8>, !xetile.tile<32x32xi8>, vector<32x32xi32> } - //CHECK-COUNT-8: xegpu.store_nd {{.*}} : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-8: xegpu.store_nd {{.*}} : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32, #xegpu.block_tdesc_attr> xetile.store_tile %out#2, %c_init_tile {innner_blocks = [8, 16]}: vector<32x32xi32>, !xetile.tile<32x32xi32> gpu.return } diff --git a/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_tf32_tf32.mlir b/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_tf32_tf32.mlir index 30403bdf8..ee580a81d 100755 --- a/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_tf32_tf32.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_gemm_1k_1k_1k_tf32_tf32.mlir @@ -15,49 +15,49 @@ gpu.module @test_kernel { %1 = arith.muli %block_id_y, %c64 : index - //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> 
!xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> + //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg2]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> - //CHECK-COUNT-2: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> -> vector<32x16xf32> + //CHECK-COUNT-2: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> -> vector<32x16xf32> //CHECK-COUNT-8: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = {{.*}}, sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> %3 = xetile.load_tile %2 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf32> -> vector<32x32xf32> - //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xtf32> -> !xegpu.tensor_desc<32x8xtf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xtf32> -> !xegpu.tensor_desc<32x8xtf32, #xegpu.block_tdesc_attr> %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xtf32> -> !xetile.tile<32x32xtf32> - //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg1]][{{.*}}] : 
memref<1024x1024xtf32> -> !xegpu.tensor_desc<32x16xtf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.create_nd_tdesc %[[arg1]][{{.*}}] : memref<1024x1024xtf32> -> !xegpu.tensor_desc<32x16xtf32, #xegpu.block_tdesc_attr> %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xtf32> -> !xetile.tile<32x32xtf32> //CHECK: {{.*}} = scf.for %6:3 = scf.for %arg3 = %c0 to %c1024 step %c64 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<32x32xtf32>, !xetile.tile<32x32xtf32>, vector<32x32xf32>) { - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x8xtf32, #xegpu.block_tdesc_attr> -> vector<2x32x8xtf32> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x8xtf32, #xegpu.block_tdesc_attr> -> vector<2x32x8xtf32> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x8xtf32> from vector<2x32x8xtf32> - //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x8xtf32, #xegpu.block_tdesc_attr> -> vector<2x32x8xtf32> + //CHECK: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x8xtf32, #xegpu.block_tdesc_attr> -> vector<2x32x8xtf32> //CHECK-COUNT-2: {{.*}} = vector.extract {{.*}} : vector<32x8xtf32> from vector<2x32x8xtf32> //CHECK-COUNT-16: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = {{.*}}, sizes = [8, 8], strides = [1, 1]} : vector<32x8xtf32> to vector<8x8xtf32> %7 = xetile.load_tile %arg4 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xtf32> -> vector<32x32xtf32> - //CHECK-COUNT-2: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xtf32, #xegpu.block_tdesc_attr> -> vector<32x16xtf32> + //CHECK-COUNT-2: {{.*}} = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<32x16xtf32, #xegpu.block_tdesc_attr> -> vector<32x16xtf32> //CHECK-COUNT-8: {{.*}} = vector.extract_strided_slice {{.*}} {offsets = {{.*}}, sizes = [8, 16], strides = [1, 1]} : vector<32x16xtf32> to vector<8x16xtf32> %8 = xetile.load_tile %arg5 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xtf32> -> vector<32x32xtf32> //CHECK-COUNT-32: 
{{.*}} = xegpu.dpas {{.*}} : vector<8x8xtf32>, vector<8x16xtf32>, vector<8x16xf32> -> vector<8x16xf32> %9 = xetile.tile_mma %7, %8, %arg6 : vector<32x32xtf32>, vector<32x32xtf32>, vector<32x32xf32> -> vector<32x32xf32> - //CHECK-COUNT-2: {{.*}} = xegpu.update_nd_offset %{{.*}}, [{{.*}}] : !xegpu.tensor_desc<32x8xtf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.update_nd_offset %{{.*}}, [{{.*}}] : !xegpu.tensor_desc<32x8xtf32, #xegpu.block_tdesc_attr> %10 = xetile.update_tile_offset %arg4, [%c0, %c64] : !xetile.tile<32x32xtf32> - //CHECK-COUNT-2: {{.*}} = xegpu.update_nd_offset %{{.*}}, [{{.*}}] : !xegpu.tensor_desc<32x16xtf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-2: {{.*}} = xegpu.update_nd_offset %{{.*}}, [{{.*}}] : !xegpu.tensor_desc<32x16xtf32, #xegpu.block_tdesc_attr> %11 = xetile.update_tile_offset %arg5, [%c64, %c0] : !xetile.tile<32x32xtf32> scf.yield %10, %11, %9 : !xetile.tile<32x32xtf32>, !xetile.tile<32x32xtf32>, vector<32x32xf32> } - //CHECK-COUNT-8: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK-COUNT-8: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> xetile.store_tile %6#2, %2 : vector<32x32xf32>, !xetile.tile<32x32xf32> gpu.return } diff --git a/test/Conversion/XeTileToXeGPU/sg_init_tile.mlir b/test/Conversion/XeTileToXeGPU/sg_init_tile.mlir index f2fc09c31..1cf526729 100644 --- a/test/Conversion/XeTileToXeGPU/sg_init_tile.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_init_tile.mlir @@ -1,7 +1,5 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \ -// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s // RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" \ -// RUN: --cse --convert-xetile-to-xegpu="enable-2d-transform=true" --cse %s -verify-diagnostics -o -| FileCheck %s +// RUN: --cse --convert-xetile-to-xegpu 
--cse %s -verify-diagnostics -o -| FileCheck %s gpu.module @test_kernel { //CHECK: gpu.func @sg_init_tile(%[[arg0:.*]]: memref<1024x1024xf32>, %[[arg1:.*]]: memref) { diff --git a/test/Conversion/XeTileToXeGPU/sg_load_tile.mlir b/test/Conversion/XeTileToXeGPU/sg_load_tile.mlir index 38f38ac68..9772f4884 100644 --- a/test/Conversion/XeTileToXeGPU/sg_load_tile.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_load_tile.mlir @@ -1,7 +1,6 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \ -// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s // RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" \ -// RUN: --cse --convert-xetile-to-xegpu="enable-2d-transform=true" --cse %s -verify-diagnostics -o -| FileCheck %s +// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s + gpu.module @test_kernel { //CHECK: gpu.func @sg_load_tile(%[[arg0:.*]]: memref<1024x1024xf16>, %[[arg1:.*]]: memref<1024x1024xf16>, %[[arg2:.*]]: memref<1024x1024xf32>) { gpu.func @sg_load_tile(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) { diff --git a/test/Conversion/XeTileToXeGPU/sg_mixed_scf.mlir b/test/Conversion/XeTileToXeGPU/sg_mixed_scf.mlir index cd98fae2e..ccdcf9c5d 100755 --- a/test/Conversion/XeTileToXeGPU/sg_mixed_scf.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_mixed_scf.mlir @@ -1,5 +1,5 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \ +// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" \ // RUN: --cse --convert-xetile-to-xegpu --cse --canonicalize %s -verify-diagnostics -o -| FileCheck %s // RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" \ // RUN: --cse --convert-xetile-to-xegpu="enable-2d-transform=true" --cse --canonicalize %s -verify-diagnostics -o -| FileCheck %s diff 
--git a/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir b/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir index f044f1d03..4a344ba2c 100644 --- a/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_scattered_ops.mlir @@ -1,5 +1,5 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \ -// RUN: --cse --convert-xetile-to-xegpu --cse --canonicalize %s -verify-diagnostics -o -| FileCheck %s +// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" \ +// RUN: --cse --convert-xetile-to-xegpu --cse --canonicalize --cse %s -verify-diagnostics -o -| FileCheck %s gpu.module @test { //CHECK-LABEL: @test_init_tile_for_scattered @@ -7,19 +7,19 @@ gpu.module @test { gpu.func @test_init_tile_for_scattered(%arg0: memref<1024xf16>) { //CHECK: %[[cst:.*]] = arith.constant dense : vector<32xi1> //CHECK: %[[cst_0:.*]] = arith.constant dense<1> : vector<32xindex> - //CHECK: %[[r0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst_0]] : memref<1024xf16>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r1:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> - //CHECK: %[[r2:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> - //CHECK: %[[r3:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> - //CHECK: %[[r4:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : 
!xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> - //CHECK: %[[r5:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> - //CHECK: %[[r6:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> - //CHECK: %[[r7:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> - //CHECK: %[[r8:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> - //CHECK: xegpu.store %[[r1]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> - //CHECK: xegpu.store %[[r2]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> - //CHECK: xegpu.store %[[r3]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> - //CHECK: xegpu.store %[[r4]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: %[[r0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst_0]] : memref<1024xf16>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r1:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r2:.*]] = xegpu.load %[[r0]], %[[cst]] 
<{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r3:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r4:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r5:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r6:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r7:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r8:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: xegpu.store %[[r1]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r2]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r3]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r4]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = 
#xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> %cst = arith.constant dense : vector<4x32xi1> %cst_0 = arith.constant dense<1> : vector<4x32xindex> %0 = xetile.init_tile %arg0, %cst_0 : memref<1024xf16>, vector<4x32xindex> -> !xetile.tile<4x32xf16, #xetile.tile_attr> @@ -29,24 +29,26 @@ gpu.module @test { gpu.return } + //----- + //CHECK-LABEL: @test_init_tile_for_scattered_cache_attr //CHECK-SAME: %[[arg0:.*]]: memref<1024xf16> gpu.func @test_init_tile_for_scattered_cache_attr(%arg0: memref<1024xf16>) { //CHECK: %[[cst:.*]] = arith.constant dense : vector<32xi1> //CHECK: %[[cst_0:.*]] = arith.constant dense<1> : vector<32xindex> - //CHECK: %[[r0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst_0]] : memref<1024xf16>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r1:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> - //CHECK: %[[r2:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> - //CHECK: %[[r3:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> - //CHECK: %[[r4:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> - //CHECK: %[[r5:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> - //CHECK: %[[r6:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : 
!xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> - //CHECK: %[[r7:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> - //CHECK: %[[r8:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> - //CHECK: xegpu.store %[[r1]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> - //CHECK: xegpu.store %[[r2]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> - //CHECK: xegpu.store %[[r3]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> - //CHECK: xegpu.store %[[r4]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: %[[r0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst_0]] : memref<1024xf16>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r1:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r2:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r3:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint 
= #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r4:.*]] = xegpu.load %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> -> vector<32xf16> + //CHECK: %[[r5:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r6:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r7:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: %[[r8:.*]] = xegpu.update_offset %[[r0]], %[[cst_0]] : !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xindex> + //CHECK: xegpu.store %[[r1]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r2]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r3]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> + //CHECK: xegpu.store %[[r4]], %[[r0]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16, #xegpu.scatter_tdesc_attr>, vector<32xi1> %cst = arith.constant dense : vector<4x32xi1> %cst_0 = arith.constant dense<1> : vector<4x32xindex> %0 = xetile.init_tile %arg0, %cst_0 : memref<1024xf16>, vector<4x32xindex> -> 
!xetile.tile<4x32xf16, #xetile.tile_attr> @@ -56,6 +58,8 @@ gpu.module @test { gpu.return } + //----- + //CHECK-LABEL: @add_kernel //CHECK-SAME: %[[arg0:.*]]: memref<*xf32>, %[[arg1:.*]]: memref<*xf32>, %[[arg2:.*]]: memref<*xf32> gpu.func @add_kernel(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) { @@ -68,23 +72,23 @@ gpu.module @test { //CHECK: %[[r0:.*]] = arith.muli %[[block_id_x]], %[[c1024]] : index //CHECK: %[[r1:.*]] = vector.splat %[[r0]] : vector<1x16xindex> //CHECK: %[[r2:.*]] = vector.shape_cast %[[r1]] : vector<1x16xindex> to vector<16xindex> - //CHECK: %[[r3:.*]] = xegpu.create_tdesc %[[cast]], %[[r2]] : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r4:.*]] = xegpu.load %[[r3]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> + //CHECK: %[[r3:.*]] = xegpu.create_tdesc %[[cast]], %[[r2]] : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r4:.*]] = xegpu.load %[[r3]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> //CHECK: %[[r5:.*]] = vector.shape_cast %[[r4]] : vector<16xf32> to vector<1x16xf32> - //CHECK: %[[r6:.*]] = xegpu.load %[[r3]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> + //CHECK: %[[r6:.*]] = xegpu.load %[[r3]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> //CHECK: %[[r7:.*]] = vector.shape_cast %[[r6]] : vector<16xf32> to vector<1x16xf32> - //CHECK: %[[r8:.*]] = 
xegpu.create_tdesc %[[cast_0]], %[[r2]] : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r9:.*]] = xegpu.load %[[r8]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> + //CHECK: %[[r8:.*]] = xegpu.create_tdesc %[[cast_0]], %[[r2]] : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r9:.*]] = xegpu.load %[[r8]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> //CHECK: %[[r10:.*]] = vector.shape_cast %[[r9]] : vector<16xf32> to vector<1x16xf32> - //CHECK: %[[r11:.*]] = xegpu.load %[[r8]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> + //CHECK: %[[r11:.*]] = xegpu.load %[[r8]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> //CHECK: %[[r12:.*]] = vector.shape_cast %[[r11]] : vector<16xf32> to vector<1x16xf32> //CHECK: %[[r13:.*]] = arith.addf %[[r5]], %[[r10]] : vector<1x16xf32> //CHECK: %[[r14:.*]] = arith.addf %[[r7]], %[[r12]] : vector<1x16xf32> - //CHECK: %[[r15:.*]] = xegpu.create_tdesc %[[cast_1]], %[[r2]] : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r15:.*]] = xegpu.create_tdesc %[[cast_1]], %[[r2]] : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr> //CHECK: %[[r16:.*]] = vector.shape_cast %[[r13]] : vector<1x16xf32> to vector<16xf32> - //CHECK: xegpu.store %[[r16]], %[[r15]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = 
#xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + //CHECK: xegpu.store %[[r16]], %[[r15]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> //CHECK: %[[r17:.*]] = vector.shape_cast %[[r14]] : vector<1x16xf32> to vector<16xf32> - //CHECK: xegpu.store %[[r17]], %[[r15]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + //CHECK: xegpu.store %[[r17]], %[[r15]], %[[cst]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> %c1024 = arith.constant 1024 : index %cst = arith.constant dense : vector<1x32xi1> %cast = memref.cast %arg0 : memref<*xf32> to memref diff --git a/test/Conversion/XeTileToXeGPU/sg_scf_for.mlir b/test/Conversion/XeTileToXeGPU/sg_scf_for.mlir index da746ef52..caec6f383 100644 --- a/test/Conversion/XeTileToXeGPU/sg_scf_for.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_scf_for.mlir @@ -1,11 +1,12 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \ +// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" \ // RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s gpu.module @test_kernel { //CHECK: gpu.func @sglevel_tiled_gemm(%[[arg0:.*]]: memref<1024x1024xf16>, %[[arg1:.*]]: memref<1024x1024xf16>) gpu.func @sglevel_tiled_gemm(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>) { - //CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<8x32xf16> - //CHECK: %[[r0:.*]] = vector.shuffle %[[cst]], %[[cst]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] 
: vector<8x32xf16>, vector<8x32xf16> - //CHECK: %[[r1:.*]] = vector.shuffle %[[r0]], %[[r0]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x32xf16>, vector<16x32xf16> + //CHECK: %[[c24:.*]] = arith.constant 24 : index + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[c8:.*]] = arith.constant 8 : index + //CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<32x32xf16> //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c64:.*]] = arith.constant 64 : index //CHECK: %[[c1024:.*]] = arith.constant 1024 : index @@ -14,40 +15,37 @@ gpu.module @test_kernel { %c64 = arith.constant 64 : index %c1024 = arith.constant 1024 : index - //CHECK: %[[r2:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r2:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> %1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> - //CHECK: %[[r3:.*]]:2 = scf.for %[[arg2:.*]] = %[[c0]] to %[[c1024]] step %[[c64]] iter_args(%[[arg3:.*]] = %[[r2]], %[[arg4:.*]] = %[[r1]]) -> (!xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr>, vector<32x32xf16>) { + //CHECK: %[[r3:.*]]:2 = scf.for %[[arg2:.*]] = %[[c0]] to %[[c1024]] step %[[c64]] iter_args(%[[arg3:.*]] = %[[r2]], %[[arg4:.*]] = %[[cst]]) -> (!xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr>, vector<32x32xf16>) { %nexta, %res = scf.for %k= %c0 to %c1024 step %c64 iter_args(%subA = %1, %subB = %cst) -> (!xetile.tile<32x32xf16>, vector<32x32xf16>) { - //CHECK: %[[r12:.*]] = xegpu.load_nd %[[arg3]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> + //CHECK: %[[r12:.*]] = xegpu.load_nd %[[arg3]] 
<{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> %3 = xetile.load_tile %subA : !xetile.tile<32x32xf16> -> vector<32x32xf16> - //CHECK: %[[r13:.*]] = xegpu.update_nd_offset %[[arg3]], [%[[c0]], %[[c64]]] : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r13:.*]] = xegpu.update_nd_offset %[[arg3]], [%[[c0]], %[[c64]]] : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> %5 = xetile.update_tile_offset %subA, [%c0, %c64]: !xetile.tile<32x32xf16> - //CHECK: scf.yield %[[r13]], %[[r12]] : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr>, vector<32x32xf16> + //CHECK: scf.yield %[[r13]], %[[r12]] : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr>, vector<32x32xf16> scf.yield %5, %3: !xetile.tile<32x32xf16>, vector<32x32xf16> } + + //CHECK: %[[r8:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c0]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r9:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c8]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r10:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c16]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r11:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c24]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> + %5 = xetile.init_tile %b[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + //CHECK: %[[r4:.*]] = vector.extract_strided_slice %[[r3]]#1 {offsets = [0, 0], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> //CHECK: %[[r5:.*]] = vector.extract_strided_slice %[[r3]]#1 {offsets = [8, 0], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> //CHECK: %[[r6:.*]] = vector.extract_strided_slice %[[r3]]#1 {offsets = [16, 0], 
sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> //CHECK: %[[r7:.*]] = vector.extract_strided_slice %[[r3]]#1 {offsets = [24, 0], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> - - //CHECK: %[[r8:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c0]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - //CHECK: %[[c8:.*]] = arith.constant 8 : index - //CHECK: %[[r9:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c8]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - //CHECK: %[[c16:.*]] = arith.constant 16 : index - //CHECK: %[[r10:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c16]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - //CHECK: %[[c24:.*]] = arith.constant 24 : index - //CHECK: %[[r11:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c24]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - %5 = xetile.init_tile %b[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> - - //CHECK: xegpu.store_nd %[[r4]], %[[r8]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - //CHECK: xegpu.store_nd %[[r5]], %[[r9]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - //CHECK: xegpu.store_nd %[[r6]], %[[r10]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - //CHECK: xegpu.store_nd %[[r7]], %[[r11]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %[[r4]], %[[r8]] <{l1_hint = 
#xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %[[r5]], %[[r9]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %[[r6]], %[[r10]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %[[r7]], %[[r11]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> xetile.store_tile %res, %5: vector<32x32xf16>, !xetile.tile<32x32xf16> //CHECK: gpu.return diff --git a/test/Conversion/XeTileToXeGPU/sg_softmax.mlir b/test/Conversion/XeTileToXeGPU/sg_softmax.mlir index 60c657297..990d56e0d 100644 --- a/test/Conversion/XeTileToXeGPU/sg_softmax.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_softmax.mlir @@ -1,16 +1,21 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s +// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" \ +// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s gpu.module @test_kernel { //CHECK-LABEL: @sglevel_softmax_dim_0 //CHECK-SAME: (%[[arg0:.*]]: memref<1024x1024xf16>) gpu.func @sglevel_softmax_dim_0(%a: memref<1024x1024xf16>) { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c0]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[c24:.*]] = arith.constant 24 : index + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[c8:.*]] = 
arith.constant 8 : index //CHECK: %[[c32:.*]] = arith.constant 32 : index - //CHECK: %[[r1:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c32]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[c0:.*]] = arith.constant 0 : index + + //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c0]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r1:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c32]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> %1 = xetile.init_tile %a[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16> - //CHECK: %[[r2:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> - //CHECK: %[[r3:.*]] = xegpu.load_nd %[[r1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> + //CHECK: %[[r2:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> + //CHECK: %[[r3:.*]] = xegpu.load_nd %[[r1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> %2 = xetile.load_tile %1: !xetile.tile<32x64xf16> -> vector<32x64xf16> //CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice %[[r2]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> @@ -29,13 +34,13 @@ gpu.module @test_kernel { //CHECK-LABEL: @sglevel_softmax_dim_1 //CHECK-SAME: (%[[arg0:.*]]: memref<1024x1024xf16>) gpu.func @sglevel_softmax_dim_1(%a: memref<1024x1024xf16>) { - //CHECK: %[[c0:.*]] = arith.constant 
0 : index - //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c0]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> //CHECK: %[[c32:.*]] = arith.constant 32 : index - //CHECK: %[[r1:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c32]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c0]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r1:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c32]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> %1 = xetile.init_tile %a[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16> - //CHECK: %[[r2:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> - //CHECK: %[[r3:.*]] = xegpu.load_nd %[[r1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> + //CHECK: %[[r2:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> + //CHECK: %[[r3:.*]] = xegpu.load_nd %[[r1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> %2 = xetile.load_tile %1: !xetile.tile<32x64xf16> -> vector<32x64xf16> //CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice %[[r2]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> //CHECK-COUNT-4: {{.*}} = vector.extract_strided_slice %[[r3]] {offsets = [{{.*}}], sizes = [8, 
32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> @@ -198,136 +203,22 @@ gpu.module @test_kernel { //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62] : vector<32xf16>, vector<32xf16> //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63] : vector<32xf16>, vector<32xf16> //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<32xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = 
vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - 
//CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> + //CHECK-COUNT-32: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<32xf16> + //CHECK-COUNT-32: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> %4 = xetile.reduction , %3 [1]: vector<32x64xf16> -> vector<32x1xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - 
//CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from 
vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + 
//CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> %5 = xetile.broadcast %4 [1]: vector<32x1xf16> -> vector<32x64xf16> // CHECK-COUNT-8: {{.*}} = arith.divf {{.*}}, {{.*}} : vector<8x32xf16> %6 = arith.divf %3, %5: vector<32x64xf16> diff --git a/test/Conversion/XeTileToXeGPU/sg_store_tile.mlir b/test/Conversion/XeTileToXeGPU/sg_store_tile.mlir index a9a64e0b5..aa07f69d7 100644 --- a/test/Conversion/XeTileToXeGPU/sg_store_tile.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_store_tile.mlir @@ -1,8 +1,5 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \ -// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s - // RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" \ -// RUN: --cse --convert-xetile-to-xegpu="enable-2d-transform=true" --cse %s -verify-diagnostics -o -| FileCheck %s +// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s gpu.module @test_kernel { //CHECK: gpu.func @sg_tiled_store(%[[arg0:.*]]: memref<1024x1024xf32>) { diff --git a/test/Conversion/XeTileToXeGPU/sg_tile_mma.mlir 
b/test/Conversion/XeTileToXeGPU/sg_tile_mma.mlir index bc1871af5..1c3ba97a4 100644 --- a/test/Conversion/XeTileToXeGPU/sg_tile_mma.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_tile_mma.mlir @@ -1,18 +1,18 @@ -// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \ -// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s +// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking="enable-2d-transform=true" \ +// RUN: --convert-xetile-to-xegpu %s -verify-diagnostics -o -| FileCheck %s gpu.module @test_kernel { //CHECK: s_tiled_gemm(%[[arg0:.*]]: memref<1024x1024xf16>, %[[arg1:.*]]: memref<1024x1024xf16>) gpu.func @s_tiled_gemm(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>) { + //CHECK: %[[c32:.*]] = arith.constant 32 : index //CHECK: %[[c0:.*]] = arith.constant 0 : index - %c0 = arith.constant 0 : index - //CHECK: %[[c64:.*]] = arith.constant 64 : index + %c0 = arith.constant 0 : index %c64 = arith.constant 64 : index - //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> %1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> - //CHECK: %[[r1:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + //CHECK: %[[r1:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK: %[[r2:.*]] = vector.extract %[[r1]][0] : vector<32x16xf16> from vector<2x32x16xf16> //CHECK: %[[r3:.*]] = vector.extract %[[r1]][1] : 
vector<32x16xf16> from vector<2x32x16xf16> //CHECK: %[[r4:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> @@ -24,15 +24,14 @@ gpu.module @test_kernel { //CHECK: %[[r10:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> //CHECK: %[[r11:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> %2 = xetile.load_tile %1 : !xetile.tile<32x32xf16> -> vector<32x32xf16> - //CHECK: %[[r12:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c64]], %[[c0]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - //CHECK: %[[c32:.*]] = arith.constant 32 : index - //CHECK: %[[r13:.*]] = xegpu.create_nd_tdesc %arg1[%[[c64]], %[[c32]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r12:.*]] = xegpu.create_nd_tdesc %[[arg1]][%[[c64]], %[[c0]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r13:.*]] = xegpu.create_nd_tdesc %arg1[%[[c64]], %[[c32]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> %3 = xetile.init_tile %b[%c64, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16> - //CHECK: %[[r14:.*]] = xegpu.load_nd %[[r12]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + //CHECK: %[[r14:.*]] = xegpu.load_nd %[[r12]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK: %[[r15:.*]] = vector.extract %[[r14]][0] : vector<32x16xf16> from vector<2x32x16xf16> //CHECK: %[[r16:.*]] = vector.extract %[[r14]][1] : vector<32x16xf16> from 
vector<2x32x16xf16> - //CHECK: %[[r17:.*]] = xegpu.load_nd %[[r13]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + //CHECK: %[[r17:.*]] = xegpu.load_nd %[[r13]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> //CHECK: %[[r18:.*]] = vector.extract %[[r17]][0] : vector<32x16xf16> from vector<2x32x16xf16> //CHECK: %[[r19:.*]] = vector.extract %[[r17]][1] : vector<32x16xf16> from vector<2x32x16xf16> //CHECK: %[[r20:.*]] = vector.extract_strided_slice %[[r15]] {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> diff --git a/test/Conversion/XeTileToXeGPU/sg_tiled_broadcast.mlir b/test/Conversion/XeTileToXeGPU/sg_tiled_broadcast.mlir deleted file mode 100644 index f503b4dea..000000000 --- a/test/Conversion/XeTileToXeGPU/sg_tiled_broadcast.mlir +++ /dev/null @@ -1,95 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s -gpu.module @test_kernel { - // CHECK-LABEL: @sglevel_broadcast_test_1 - gpu.func @sglevel_broadcast_test_1(%arg0: memref<1024x1024xf16>) { - // CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<1x16xf16> - %cst = arith.constant dense<0.000000e+00> : vector<1x4x1x16xf16> - %0 = xetile.tile_unpack %cst {inner_blocks = array}: vector<1x4x1x16xf16> -> vector<1x64xf16> - %1 = xetile.tile_pack %0 {inner_blocks = array}: vector<1x64xf16> -> vector<1x4x1x16xf16> - %2 = xetile.broadcast %1 [0, 2] : vector<1x4x1x16xf16> -> vector<32x4x1x16xf16> - %3 = xetile.tile_unpack %2 {inner_blocks = array}: vector<32x4x1x16xf16> -> vector<32x64xf16> - %4 = xetile.init_tile %arg0[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16, #xetile.tile_attr> - %5 = xetile.tile_pack %3 {inner_blocks = array}: 
vector<32x64xf16> -> vector<32x4x1x16xf16> - // CHECK-COUNT-128: xegpu.store_nd %[[cst]], %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<1x16xf16>, !xegpu.tensor_desc<1x16xf16, #xegpu.block_tdesc_attr> - xetile.store_tile %5, %4 : vector<32x4x1x16xf16>, !xetile.tile<32x64xf16, #xetile.tile_attr> - gpu.return - } - - // CHECK-LABEL: @sglevel_broadcast_test_2 - gpu.func @sglevel_broadcast_test_2(%arg0: memref<1024x1024xf16>) { - %cst = arith.constant dense<0.000000e+00> : vector<32x1xf16> - %0 = xetile.tile_pack %cst {inner_blocks = array} : vector<32x1xf16> -> vector<32x1x1x1xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from 
vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : 
f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - //CHECK: vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: vector.splat %{{.*}} : vector<1x16xf16> - %1 = xetile.broadcast %0 [1, 3] : vector<32x1x1x1xf16> -> vector<32x4x1x16xf16> - %2 = xetile.tile_unpack %1 {inner_blocks = array}: vector<32x4x1x16xf16> -> vector<32x64xf16> - %3 = xetile.init_tile %arg0[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16, #xetile.tile_attr> - %4 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x64xf16> -> vector<32x4x1x16xf16> - // CHECK-COUNT-128: xegpu.store_nd %{{.*}}, %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<1x16xf16>, !xegpu.tensor_desc<1x16xf16, #xegpu.block_tdesc_attr> - xetile.store_tile %4, %3 : vector<32x4x1x16xf16>, !xetile.tile<32x64xf16, #xetile.tile_attr> - gpu.return - } - -} diff --git a/test/Conversion/XeTileToXeGPU/sg_tiled_load_tile.mlir b/test/Conversion/XeTileToXeGPU/sg_tiled_load_tile.mlir deleted file mode 100644 index 92d59f0ee..000000000 --- a/test/Conversion/XeTileToXeGPU/sg_tiled_load_tile.mlir +++ /dev/null @@ -1,20 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu -cse %s -verify-diagnostics -o -| FileCheck %s -gpu.module @test_kernel { - //CHECK: sg_load_tile(%[[ARG:.*]]: memref<1024x1024xf16>) - gpu.func @sg_load_tile(%arg0: memref<1024x1024xf16>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - %c0 = arith.constant 0 : index - - // CHECK: %[[C64:.*]] = arith.constant 64 : index - %c64 = arith.constant 64 : index - - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[%[[C0]], %[[C64]]] : memref<1024x1024xf16> - // CHECK-SAME: !xegpu.tensor_desc<32x32xf16, 
#xegpu.block_tdesc_attr> - %0 = xetile.init_tile %arg0[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - - // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> - // CHECK-SAME: !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> - %1 = xetile.load_tile %0 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<1x1x32x32xf16> - gpu.return - } -} diff --git a/test/Conversion/XeTileToXeGPU/sg_tiled_scattered_ops.mlir b/test/Conversion/XeTileToXeGPU/sg_tiled_scattered_ops.mlir deleted file mode 100644 index 86bceef91..000000000 --- a/test/Conversion/XeTileToXeGPU/sg_tiled_scattered_ops.mlir +++ /dev/null @@ -1,69 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --cse --canonicalize %s -verify-diagnostics -o -| FileCheck %s - -gpu.module @test { - //CHECK-LABEL: @test_init_tile_for_scattered - //CHECK-SAME: %[[arg0:.*]]: memref<1024xf16> - gpu.func @test_init_tile_for_scattered(%arg0: memref<1024xf16>) { - - //CHECK: %[[cst:.*]] = arith.constant dense<1> : vector<16xindex> - //CHECK: %[[cst_0:.*]] = arith.constant dense : vector<16xi1> - //CHECK: %[[cst_1:.*]] = arith.constant dense<{{.*}}0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15{{.*}}> : vector<1x16xindex> - //CHECK: %[[cst_2:.*]] = arith.constant dense<{{.*}}16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31{{.*}}]> : vector<1x16xindex> - //CHECK: %[[cst_3:.*]] = arith.constant dense<{{.*}}32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47{{.*}}]> : vector<1x16xindex> - //CHECK: %[[cst_4:.*]] = arith.constant dense<{{.*}}48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63{{.*}}]> : vector<1x16xindex> - //CHECK: %[[cst_5:.*]] = arith.constant dense<{{.*}}64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79{{.*}}]> : vector<1x16xindex> - //CHECK: %[[cst_6:.*]] = arith.constant 
dense<{{.*}}80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95{{.*}}]> : vector<1x16xindex> - //CHECK: %[[cst_7:.*]] = arith.constant dense<{{.*}}96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111{{.*}}> : vector<1x16xindex> - //CHECK: %[[cst_8:.*]] = arith.constant dense<{{.*}}112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127{{.*}}> : vector<1x16xindex> - //CHECK: %[[r0:.*]] = vector.shape_cast %[[cst_1]] : vector<1x16xindex> to vector<16xindex> - //CHECK: %[[r1:.*]] = xegpu.create_tdesc %[[arg0]], %[[r0]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r2:.*]] = vector.shape_cast %[[cst_2]] : vector<1x16xindex> to vector<16xindex> - //CHECK: %[[r3:.*]] = xegpu.create_tdesc %[[arg0]], %[[r2]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r4:.*]] = vector.shape_cast %[[cst_3]] : vector<1x16xindex> to vector<16xindex> - //CHECK: %[[r5:.*]] = xegpu.create_tdesc %[[arg0]], %[[r4]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r6:.*]] = vector.shape_cast %[[cst_4]] : vector<1x16xindex> to vector<16xindex> - //CHECK: %[[r7:.*]] = xegpu.create_tdesc %[[arg0]], %[[r6]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r8:.*]] = vector.shape_cast %[[cst_5]] : vector<1x16xindex> to vector<16xindex> - //CHECK: %[[r9:.*]] = xegpu.create_tdesc %[[arg0]], %[[r8]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r10:.*]] = vector.shape_cast %[[cst_6]] : vector<1x16xindex> to vector<16xindex> - //CHECK: %[[r11:.*]] = xegpu.create_tdesc %[[arg0]], %[[r10]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r12:.*]] = vector.shape_cast %[[cst_7]] : vector<1x16xindex> 
to vector<16xindex> - //CHECK: %[[r13:.*]] = xegpu.create_tdesc %[[arg0]], %[[r12]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r14:.*]] = vector.shape_cast %[[cst_8]] : vector<1x16xindex> to vector<16xindex> - //CHECK: %[[r15:.*]] = xegpu.create_tdesc %[[arg0]], %[[r14]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> - //CHECK: %[[r16:.*]] = xegpu.load %[[r1]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> - //CHECK: %[[r17:.*]] = xegpu.load %[[r3]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> - //CHECK: %[[r18:.*]] = xegpu.load %[[r5]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> - //CHECK: %[[r19:.*]] = xegpu.load %[[r7]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> - //CHECK: %[[r20:.*]] = xegpu.load %[[r9]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> - //CHECK: %[[r21:.*]] = xegpu.load %[[r11]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> - //CHECK: %[[r22:.*]] = xegpu.load %[[r13]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : 
!xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> - //CHECK: %[[r23:.*]] = xegpu.load %[[r15]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> - //CHECK: %[[r24:.*]] = xegpu.update_offset %[[r1]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> - //CHECK: %[[r25:.*]] = xegpu.update_offset %[[r3]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> - //CHECK: %[[r26:.*]] = xegpu.update_offset %[[r5]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> - //CHECK: %[[r27:.*]] = xegpu.update_offset %[[r7]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> - //CHECK: %[[r28:.*]] = xegpu.update_offset %[[r9]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> - //CHECK: %[[r29:.*]] = xegpu.update_offset %[[r11]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> - //CHECK: %[[r30:.*]] = xegpu.update_offset %[[r13]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> - //CHECK: %[[r31:.*]] = xegpu.update_offset %[[r15]], %[[cst]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> - //CHECK: xegpu.store %[[r16]], %[[r1]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> - //CHECK: xegpu.store %[[r17]], %[[r3]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> - //CHECK: xegpu.store %[[r18]], %[[r5]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, 
l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> - //CHECK: xegpu.store %[[r19]], %[[r7]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> - //CHECK: xegpu.store %[[r20]], %[[r9]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> - //CHECK: xegpu.store %[[r21]], %[[r11]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> - //CHECK: xegpu.store %[[r22]], %[[r13]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> - //CHECK: xegpu.store %[[r23]], %[[r15]], %[[cst_0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> - - - %cst = arith.constant dense : vector<4x2x1x16xi1> - %cst_0 = arith.constant 
dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C00000
0000000007D000000000000007E000000000000007F00000000000000"> : vector<4x2x1x16xindex> - %offsets = arith.constant dense<1> : vector<4x2x1x16xindex> - %0 = xetile.init_tile %arg0, %cst_0 : memref<1024xf16>, vector<4x2x1x16xindex> -> !xetile.tile<4x32xf16, #xetile.tile_attr> - %1 = xetile.load %0, %cst : !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x2x1x16xi1> -> vector<4x2x1x16xf16> - %2 = xetile.update_tile_offset %0, %offsets : !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x2x1x16xindex> - xetile.store %1, %0, %cst : vector<4x2x1x16xf16>, !xetile.tile<4x32xf16, #xetile.tile_attr>, vector<4x2x1x16xi1> - gpu.return - } -} diff --git a/test/Conversion/XeTileToXeGPU/sg_tiled_scf_for.mlir b/test/Conversion/XeTileToXeGPU/sg_tiled_scf_for.mlir deleted file mode 100644 index 82d87f6d4..000000000 --- a/test/Conversion/XeTileToXeGPU/sg_tiled_scf_for.mlir +++ /dev/null @@ -1,57 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s - gpu.module @test_kernel { - // CHECK: sg_scf_for(%[[ARG0:.*]]: memref<1024x1024xf16>, %[[ARG1:.*]]: memref<1024x1024xf16>) - gpu.func @sg_scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>) { - // CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<32x32xf16> - %cst = arith.constant dense<0.000000e+00> : vector<1x1x32x32xf16> - - // CHECK: %[[c0:.*]] = arith.constant 0 : index - %c0 = arith.constant 0 : index - - // CHECK: %[[c64:.*]] = arith.constant 64 : index - %c64 = arith.constant 64 : index - - // CHECK: %[[c1024:.*]] = arith.constant 1024 : index - %c1024 = arith.constant 1024 : index - - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[c0]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> - %0 = xetile.init_tile %arg0[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - - // CHECK: %[[R1:.*]]:2 = scf.for %[[arg2:.*]] = %[[c0]] to %[[c1024]] step %[[c64]] 
- // CHECK-SAME: iter_args(%[[arg3:.*]] = %[[R0]], %[[arg4:.*]] = %[[cst]]) -> (!xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr>, vector<32x32xf16>) - %1:2 = scf.for %arg2 = %c0 to %c1024 step %c64 iter_args(%arg3 = %0, %arg4 = %cst) -> (!xetile.tile<32x32xf16, #xetile.tile_attr>, vector<1x1x32x32xf16>) { - - // CHECK: %[[R10:.*]] = xegpu.load_nd %[[arg3]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> - %5 = xetile.load_tile %arg3 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<1x1x32x32xf16> - - // CHECK: %[[R11:.*]] = xegpu.update_nd_offset %[[arg3]], [%[[c0]], %[[c64]]] : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> - %6 = xetile.update_tile_offset %arg3, [%c0, %c64] : !xetile.tile<32x32xf16, #xetile.tile_attr> - - // CHECK: scf.yield %[[R11]], %[[R10]] : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr>, vector<32x32xf16> - scf.yield %6, %5 : !xetile.tile<32x32xf16, #xetile.tile_attr>, vector<1x1x32x32xf16> - } - - // CHECK: %[[R2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[c0]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - // CHECK: %[[c8:.*]] = arith.constant 8 : index - // CHECK: %[[R3:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[c8]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - // CHECK: %[[c16:.*]] = arith.constant 16 : index - // CHECK: %[[R4:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[c16]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - // CHECK: %[[c24:.*]] = arith.constant 24 : index - // CHECK: %[[R5:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[c24]], %[[c64]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - %2 = xetile.init_tile %arg1[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16, 
#xetile.tile_attr> - - // CHECK: %[[R6:.*]] = vector.extract_strided_slice %[[R1]]#1 {offsets = [0, 0], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> - // CHECK: %[[R7:.*]] = vector.extract_strided_slice %[[R1]]#1 {offsets = [8, 0], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> - // CHECK: %[[R8:.*]] = vector.extract_strided_slice %[[R1]]#1 {offsets = [16, 0], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> - // CHECK: %[[R9:.*]] = vector.extract_strided_slice %[[R1]]#1 {offsets = [24, 0], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> - %3 = xetile.tile_unpack %1#1 {inner_blocks = array} : vector<1x1x32x32xf16> -> vector<32x32xf16> - %4 = xetile.tile_pack %3 {inner_blocks = array}: vector<32x32xf16> -> vector<4x1x8x32xf16> - - // CHECK: xegpu.store_nd %[[R6]], %[[R2]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - // CHECK: xegpu.store_nd %[[R7]], %[[R3]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - // CHECK: xegpu.store_nd %[[R8]], %[[R4]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - // CHECK: xegpu.store_nd %[[R9]], %[[R5]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16, #xegpu.block_tdesc_attr> - xetile.store_tile %4, %2 : vector<4x1x8x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr> - gpu.return - } - } diff --git a/test/Conversion/XeTileToXeGPU/sg_tiled_softmax.mlir b/test/Conversion/XeTileToXeGPU/sg_tiled_softmax.mlir deleted file mode 100644 index 9929788d0..000000000 --- 
a/test/Conversion/XeTileToXeGPU/sg_tiled_softmax.mlir +++ /dev/null @@ -1,346 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s -gpu.module @test_kernel { - //CHECK-LABEL: @sglevel_softmax_dim_0 - //CHECK-SAME: (%[[arg0:.*]]: memref<1024x1024xf16>) - gpu.func @sglevel_softmax_dim_0(%arg0: memref<1024x1024xf16>) { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c0]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> - //CHECK: %[[c32:.*]] = arith.constant 32 : index - //CHECK: %[[r1:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c32]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> - %0 = xetile.init_tile %arg0[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16, #xetile.tile_attr> - - //CHECK: %[[r2:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> - //CHECK: %[[r3:.*]] = xegpu.load_nd %[[r1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> - %1 = xetile.load_tile %0 {padding = 0.000000e+00 : f32} : !xetile.tile<32x64xf16, #xetile.tile_attr> -> vector<1x2x32x32xf16> - - //CHECK-COUNT-32: {{.*}} = vector.extract_strided_slice %[[r2]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK-COUNT-32: {{.*}} = vector.extract_strided_slice %[[r3]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK-COUNT-128: {{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf16> to vector<1x16xf16> - %2 = xetile.tile_unpack %1 {inner_blocks = 
array} : vector<1x2x32x32xf16> -> vector<32x64xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x64xf16> -> vector<32x4x1x16xf16> - //CHECK-COUNT-128: {{.*}} = math.exp %{{.*}} : vector<1x16xf16> - %4 = math.exp %3 : vector<32x4x1x16xf16> - //CHECK-COUNT-124: arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - %5 = xetile.reduction , %4 [0, 2] : vector<32x4x1x16xf16> -> vector<1x4x1x16xf16> - %6 = xetile.broadcast %5 [0, 2] : vector<1x4x1x16xf16> -> vector<32x4x1x16xf16> - //CHECK-COUNT-128: arith.divf {{.*}}, {{.*}} : vector<1x16xf16> - %7 = arith.divf %4, %6 : vector<32x4x1x16xf16> - %8 = xetile.tile_unpack %7 {inner_blocks = array}: vector<32x4x1x16xf16> -> vector<32x64xf16> - %9 = xetile.init_tile %arg0[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16, #xetile.tile_attr> - %10 = xetile.tile_pack %8 {inner_blocks = array}: vector<32x64xf16> -> vector<4x2x8x32xf16> - xetile.store_tile %10, %9 : vector<4x2x8x32xf16>, !xetile.tile<32x64xf16, #xetile.tile_attr> - gpu.return - } - //CHECK-LABEL: @sglevel_softmax_dim_1 - //CHECK-SAME: (%[[arg0:.*]]: memref<1024x1024xf16>) - gpu.func @sglevel_softmax_dim_1(%arg0: memref<1024x1024xf16>) { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c0]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> - //CHECK: %[[c32:.*]] = arith.constant 32 : index - //CHECK: %[[r1:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c32]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> - %0 = xetile.init_tile %arg0[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16, #xetile.tile_attr> - //CHECK: %[[r2:.*]] = xegpu.load_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> - //CHECK: %[[r3:.*]] = xegpu.load_nd %[[r1]] <{l1_hint = #xegpu.cache_hint, l2_hint = 
#xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x32xf16, #xegpu.block_tdesc_attr> -> vector<32x32xf16> - %1 = xetile.load_tile %0 {padding = 0.000000e+00 : f32} : !xetile.tile<32x64xf16, #xetile.tile_attr> -> vector<1x2x32x32xf16> - //CHECK-COUNT-32: {{.*}} = vector.extract_strided_slice %[[r2]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK-COUNT-32: {{.*}} = vector.extract_strided_slice %[[r3]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK-COUNT-128: {{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf16> to vector<1x16xf16> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x2x32x32xf16> -> vector<32x64xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array}: vector<32x64xf16> -> vector<32x4x1x16xf16> - //CHECK-COUNT-128: {{.*}} = math.exp %{{.*}} : vector<1x16xf16> - %4 = math.exp %3 : vector<32x4x1x16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = 
arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - 
//CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to vector<16xf16> - //CHECK-COUNT-3: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x16xf16> to 
vector<16xf16> - - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 
28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - 
//CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, 
{{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: 
{{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [1, 3, 5, 
7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = vector.shuffle {{.*}}, {{.*}} [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31] : vector<16xf16>, vector<16xf16> - //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<16xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : 
vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement 
{{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - //CHECK: {{.*}} = vector.extractelement {{.*}}[{{.*}} : i32] : vector<16xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf16> - %5 = xetile.reduction , %4 [1, 3] : vector<32x4x1x16xf16> -> vector<32x1x1x1xf16> - - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : 
f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract 
{{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - //CHECK: {{.*}} = vector.extract {{.*}}[0, 0] : f16 from vector<1x1xf16> - //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x16xf16> - %6 = xetile.broadcast %5 [1, 3] : vector<32x1x1x1xf16> -> vector<32x4x1x16xf16> - // CHECK-COUNT-128: {{.*}} = arith.divf {{.*}}, {{.*}} : vector<1x16xf16> - %7 = arith.divf %4, %6 : vector<32x4x1x16xf16> - %8 = xetile.tile_unpack %7 {inner_blocks = array}: vector<32x4x1x16xf16> -> vector<32x64xf16> - %9 = xetile.init_tile %arg0[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16, #xetile.tile_attr> - %10 = xetile.tile_pack %8 {inner_blocks = array}: vector<32x64xf16> -> vector<4x2x8x32xf16> - xetile.store_tile %10, %9 : vector<4x2x8x32xf16>, !xetile.tile<32x64xf16, #xetile.tile_attr> - gpu.return - } -} diff --git a/test/Conversion/XeTileToXeGPU/sg_tiled_store_tile.mlir 
b/test/Conversion/XeTileToXeGPU/sg_tiled_store_tile.mlir deleted file mode 100644 index 52ba152c0..000000000 --- a/test/Conversion/XeTileToXeGPU/sg_tiled_store_tile.mlir +++ /dev/null @@ -1,58 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s - gpu.module @test_kernel { - // CHECK: sg_tiled_store(%[[arg0:.*]]: memref<1024x1024xf32>) - gpu.func @sg_tiled_store(%arg0: memref<1024x1024xf32>) { - // CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<32x16xf32> - %cst = arith.constant dense<0.000000e+00> : vector<1x2x32x16xf32> - - // CHECK: %[[c0:.*]] = arith.constant 0 : index - // CHECK: %[[c32:.*]] = arith.constant 32 : index - // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c32]]] : memref<1024x1024xf32> - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: %[[c48:.*]] = arith.constant 48 : index - // CHECK: %[[R1:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c0]], %[[c48]]] : memref<1024x1024xf32> - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: %[[c8:.*]] = arith.constant 8 : index - // CHECK: %[[R2:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c8]], %[[c32]]] : memref<1024x1024xf32> - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: %[[R3:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c8]], %[[c48]]] : memref<1024x1024xf32> - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: %[[c16:.*]] = arith.constant 16 : index - // CHECK: %[[R4:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c16]], %[[c32]]] : memref<1024x1024xf32> - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: %[[R5:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c16]], %[[c48]]] : memref<1024x1024xf32> - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: %[[c24:.*]] = arith.constant 24 : index - // CHECK: %[[R6:.*]] = xegpu.create_nd_tdesc 
%[[arg0]][%[[c24]], %[[c32]]] : memref<1024x1024xf32> - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: %[[R7:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[c24]], %[[c48]]] : memref<1024x1024xf32> - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %0 = xetile.init_tile %arg0[0, 32] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32, #xetile.tile_attr> - - // CHECK: %[[R8:.*]] = vector.extract_strided_slice %[[cst]] {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - // CHECK: %[[R9:.*]] = vector.extract_strided_slice %[[cst]] {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - // CHECK: %[[R10:.*]] = vector.extract_strided_slice %[[cst]] {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - // CHECK: %[[R11:.*]] = vector.extract_strided_slice %[[cst]] {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %1 = xetile.tile_unpack %cst {inner_blocks = array} : vector<1x2x32x16xf32> -> vector<32x32xf32> - %2 = xetile.tile_pack %1 {inner_blocks = array} : vector<32x32xf32> -> vector<4x2x8x16xf32> - - // CHECK: xegpu.store_nd %[[R8]], %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> - // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: xegpu.store_nd %[[R8]], %[[R1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> - // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: xegpu.store_nd %[[R9]], %[[R2]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> - // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: xegpu.store_nd %[[R9]], %[[R3]] <{l1_hint = #xegpu.cache_hint, l2_hint = 
#xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> - // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: xegpu.store_nd %[[R10]], %[[R4]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> - // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: xegpu.store_nd %[[R10]], %[[R5]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> - // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: xegpu.store_nd %[[R11]], %[[R6]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> - // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - // CHECK: xegpu.store_nd %[[R11]], %[[R7]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> - // CHECK-SAME: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xetile.store_tile %2, %0 : vector<4x2x8x16xf32>, !xetile.tile<32x32xf32, #xetile.tile_attr> - gpu.return - } - } diff --git a/test/Conversion/XeTileToXeGPU/sg_tiled_tile_mma.mlir b/test/Conversion/XeTileToXeGPU/sg_tiled_tile_mma.mlir deleted file mode 100644 index 6079e9f29..000000000 --- a/test/Conversion/XeTileToXeGPU/sg_tiled_tile_mma.mlir +++ /dev/null @@ -1,91 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s -gpu.module @test_kernel { - // CHECK: sg_tiled_gemm(%[[ARG0:.*]]: memref<1024x1024xf16>, %[[ARG1:.*]]: memref<1024x1024xf16>) - gpu.func @sg_tiled_gemm(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - %c0 = arith.constant 0 : index - - // CHECK: %[[C64:.*]] = arith.constant 64 : index - %c64 = arith.constant 64 : index - - // CHECK: %[[REG0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[C0]], %[[C64]]] : 
memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %0 = xetile.init_tile %arg0[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - - // CHECK: %[[REG1:.*]] = xegpu.load_nd %[[REG0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> - // CHECK: %[[REG2:.*]] = vector.extract %[[REG1]][0] : vector<32x16xf16> from vector<2x32x16xf16> - // CHECK: %[[REG3:.*]] = vector.extract %[[REG1]][1] : vector<32x16xf16> from vector<2x32x16xf16> - %1 = xetile.load_tile %0 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<1x2x32x16xf16> - - - // CHECK: %[[REG4:.*]] = xegpu.create_nd_tdesc %arg1[%[[C64]], %[[C0]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - // CHECK: %[[C32:.*]] = arith.constant 32 : index - // CHECK: %[[REG5:.*]] = xegpu.create_nd_tdesc %arg1[%[[C64]], %[[C32]]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %2 = xetile.init_tile %arg1[%c64, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16, #xetile.tile_attr> - - // CHECK: %[[REG6:.*]] = xegpu.load_nd %[[REG4]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> - // CHECK: %[[REG7:.*]] = vector.extract %[[REG6]][0] : vector<32x16xf16> from vector<2x32x16xf16> - // CHECK: %[[REG8:.*]] = vector.extract %[[REG6]][1] : vector<32x16xf16> from vector<2x32x16xf16> - // CHECK: %[[REG9:.*]] = xegpu.load_nd %[[REG5]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> - // CHECK: %[[REG10:.*]] = vector.extract %[[REG9]][0] : vector<32x16xf16> from vector<2x32x16xf16> - // CHECK: %[[REG11:.*]] = 
vector.extract %[[REG9]][1] : vector<32x16xf16> from vector<2x32x16xf16> - %3 = xetile.load_tile %2 {padding = 0.000000e+00 : f32} : !xetile.tile<32x64xf16, #xetile.tile_attr> -> vector<1x4x32x16xf16> - - //CHECK: %[[REG12:.*]] = vector.extract_strided_slice %[[REG2]] {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - //CHECK: %[[REG13:.*]] = vector.extract_strided_slice %[[REG2]] {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - //CHECK: %[[REG14:.*]] = vector.extract_strided_slice %[[REG2]] {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - //CHECK: %[[REG15:.*]] = vector.extract_strided_slice %[[REG2]] {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - //CHECK: %[[REG16:.*]] = vector.extract_strided_slice %[[REG3]] {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - //CHECK: %[[REG17:.*]] = vector.extract_strided_slice %[[REG3]] {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - //CHECK: %[[REG18:.*]] = vector.extract_strided_slice %[[REG3]] {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - //CHECK: %[[REG19:.*]] = vector.extract_strided_slice %[[REG3]] {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - %4 = xetile.tile_unpack %1 {inner_blocks = array} : vector<1x2x32x16xf16> -> vector<32x32xf16> - %5 = xetile.tile_pack %4 {inner_blocks = array} : vector<32x32xf16> -> vector<4x2x8x16xf16> - - // CHECK: %[[REG20:.*]] = vector.extract_strided_slice %[[REG7]] {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - // CHECK: %[[REG21:.*]] = vector.extract_strided_slice %[[REG7]] {offsets = [16, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to 
vector<16x16xf16> - // CHECK: %[[REG22:.*]] = vector.extract_strided_slice %[[REG8]] {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - // CHECK: %[[REG23:.*]] = vector.extract_strided_slice %[[REG8]] {offsets = [16, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - // CHECK: %[[REG24:.*]] = vector.extract_strided_slice %[[REG10]] {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - // CHECK: %[[REG25:.*]] = vector.extract_strided_slice %[[REG10]] {offsets = [16, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - // CHECK: %[[REG26:.*]] = vector.extract_strided_slice %[[REG11]] {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - // CHECK: %[[REG27:.*]] = vector.extract_strided_slice %[[REG11]] {offsets = [16, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - %6 = xetile.tile_unpack %3 {inner_blocks = array} : vector<1x4x32x16xf16> -> vector<32x64xf16> - %7 = xetile.tile_pack %6 {inner_blocks = array} : vector<32x64xf16> -> vector<2x4x16x16xf16> - - // CHECK: %[[REG28:.*]] = xegpu.dpas %[[REG12]], %[[REG20]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG29:.*]] = xegpu.dpas %[[REG16]], %[[REG21]], %[[REG28]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG30:.*]] = xegpu.dpas %[[REG12]], %[[REG22]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG31:.*]] = xegpu.dpas %[[REG16]], %[[REG23]], %[[REG30]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG32:.*]] = xegpu.dpas %[[REG12]], %[[REG24]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG33:.*]] = xegpu.dpas %[[REG16]], %[[REG25]], %[[REG32]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> 
vector<8x16xf32> - // CHECK: %[[REG34:.*]] = xegpu.dpas %[[REG12]], %[[REG26]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG35:.*]] = xegpu.dpas %[[REG16]], %[[REG27]], %[[REG34]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG36:.*]] = xegpu.dpas %[[REG13]], %[[REG20]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG37:.*]] = xegpu.dpas %[[REG17]], %[[REG21]], %[[REG36]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG38:.*]] = xegpu.dpas %[[REG13]], %[[REG22]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG39:.*]] = xegpu.dpas %[[REG17]], %[[REG23]], %[[REG38]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG40:.*]] = xegpu.dpas %[[REG13]], %[[REG24]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG41:.*]] = xegpu.dpas %[[REG17]], %[[REG25]], %[[REG40]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG42:.*]] = xegpu.dpas %[[REG13]], %[[REG26]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG43:.*]] = xegpu.dpas %[[REG17]], %[[REG27]], %[[REG42]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG44:.*]] = xegpu.dpas %[[REG14]], %[[REG20]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG45:.*]] = xegpu.dpas %[[REG18]], %[[REG21]], %[[REG44]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG46:.*]] = xegpu.dpas %[[REG14]], %[[REG22]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG47:.*]] = xegpu.dpas %[[REG18]], %[[REG23]], %[[REG46]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG48:.*]] = xegpu.dpas %[[REG14]], %[[REG24]] : vector<8x16xf16>, 
vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG49:.*]] = xegpu.dpas %[[REG18]], %[[REG25]], %[[REG48]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG50:.*]] = xegpu.dpas %[[REG14]], %[[REG26]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG51:.*]] = xegpu.dpas %[[REG18]], %[[REG27]], %[[REG50]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG52:.*]] = xegpu.dpas %[[REG15]], %[[REG20]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG53:.*]] = xegpu.dpas %[[REG19]], %[[REG21]], %[[REG52]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG54:.*]] = xegpu.dpas %[[REG15]], %[[REG22]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG55:.*]] = xegpu.dpas %[[REG19]], %[[REG23]], %[[REG54]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG56:.*]] = xegpu.dpas %[[REG15]], %[[REG24]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG57:.*]] = xegpu.dpas %[[REG19]], %[[REG25]], %[[REG56]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - // CHECK: %[[REG58:.*]] = xegpu.dpas %[[REG15]], %[[REG26]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> - // CHECK: %[[REG59:.*]] = xegpu.dpas %[[REG19]], %[[REG27]], %[[REG58]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %8 = xetile.tile_mma %5, %7 : vector<4x2x8x16xf16>, vector<2x4x16x16xf16> -> vector<4x4x8x16xf32> - - gpu.return - } -} diff --git a/test/Conversion/XeTileToXeGPU/test_order.mlir b/test/Conversion/XeTileToXeGPU/test_order.mlir index 7bb5b861d..2672328ef 100644 --- a/test/Conversion/XeTileToXeGPU/test_order.mlir +++ b/test/Conversion/XeTileToXeGPU/test_order.mlir @@ -1,18 +1,19 @@ -// RUN: imex-opt --split-input-file --xetile-canonicalization 
--xetile-blocking --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s +// RUN: imex-opt --split-input-file --xetile-canonicalization --xetile-blocking="enable-2d-transform=true" \ +// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s // CHECK-LABEL: @test_func // CHECK-SAME: (%[[ARG0:.*]]: memref<128x64xf16>, %[[ARG1:.*]]: memref<64x128xf16, strided<[1, 64]>>) { // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C16:.*]] = arith.constant 16 : index // CHECK: %[[R_CAST:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<64x128xf16, strided<[1, 64]>> to memref<128x64xf16, strided<[64, 1]>> -// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[R_CAST]][%[[C0]], %[[C0]]] : memref<128x64xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[R_CAST]][%[[C16]], %[[C0]]] : memref<128x64xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -// CHECK: %[[T8:.*]] = xegpu.load_nd %[[T1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> -// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T2]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> -// CHECK: %[[T19:.*]] = xegpu.update_nd_offset %[[T1]], [%[[C0]], %[[C16]]] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -// CHECK: %[[T20:.*]] = xegpu.update_nd_offset %[[T2]], [%[[C0]], %[[C16]]] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -// CHECK: %[[T26:.*]] = xegpu.load_nd %[[T19]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> -// CHECK: %[[T27:.*]] 
= xegpu.load_nd %[[T20]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> +// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[R_CAST]][%[[C0]], %[[C0]]] : memref<128x64xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[R_CAST]][%[[C16]], %[[C0]]] : memref<128x64xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T8:.*]] = xegpu.load_nd %[[T1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> +// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T2]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> +// CHECK: %[[T19:.*]] = xegpu.update_nd_offset %[[T1]], [%[[C0]], %[[C16]]] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T20:.*]] = xegpu.update_nd_offset %[[T2]], [%[[C0]], %[[C16]]] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T26:.*]] = xegpu.load_nd %[[T19]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> +// CHECK: %[[T27:.*]] = xegpu.load_nd %[[T20]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> gpu.module @test_kernel { func.func @test_func(%A : memref<128x64xf16>, %B : memref<64x128xf16, strided<[1, 64], offset: 0>>) { %c0 = arith.constant 0 : index diff --git a/test/Conversion/XeTileToXeGPU/unit_tests.mlir b/test/Conversion/XeTileToXeGPU/unit_tests.mlir new file mode 100644 index 000000000..c3cf82b9c --- /dev/null +++ 
b/test/Conversion/XeTileToXeGPU/unit_tests.mlir @@ -0,0 +1,75 @@ +// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu %s -verify-diagnostics -o -| FileCheck %s + +gpu.module @test_kernel { + //CHECK-LABEL: gpu.func @sg_store_tile + //CHECK-SAME: (%[[arg0:.*]]: memref<32x32xf32>) { + gpu.func @sg_store_tile(%arg0: memref<32x32xf32>) { // stores a constant-zero 8x16 f32 vector through a tile created at [0, 0] of %arg0 + //CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %[[cst]], %[[r0]] {{.*}}: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32> + %0 = xetile.init_tile %arg0[0, 0] : memref<32x32xf32> -> !xetile.tile<8x16xf32> + xetile.store_tile %cst, %0 : vector<8x16xf32>, !xetile.tile<8x16xf32> + gpu.return + } + + //----- sg_tile_mma: loads an 8x16 and a 16x16 f16 tile, multiplies them with xetile.tile_mma (expected to become xegpu.dpas) and stores the 8x16 f32 result. + + // CHECK: gpu.func @sg_tile_mma(%[[arg0:.*]]: memref<8x16xf16>, %[[arg1:.*]]: memref<16x16xf16>, %[[arg2:.*]]: memref<8x16xf32>) + gpu.func @sg_tile_mma(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { + //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r1:.*]] = xegpu.load_nd %[[r0]] {{.*}}: !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr> -> vector<8x16xf16> + //CHECK: %[[r2:.*]] = xegpu.create_nd_tdesc %[[arg1]][0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + //CHECK: %[[r3:.*]] = xegpu.load_nd %[[r2]] {{.*}}: !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> + //CHECK: %[[r4:.*]] = xegpu.dpas %[[r1]], %[[r3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + //CHECK: %[[r5:.*]] = xegpu.create_nd_tdesc %[[arg2]][0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + //CHECK: xegpu.store_nd %[[r4]], %[[r5]] 
{{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + %0 = xetile.init_tile %arg0[0, 0] : memref<8x16xf16> -> !xetile.tile<8x16xf16> + %1 = xetile.load_tile %0 {padding = 0.000000e+00 : f32} : !xetile.tile<8x16xf16> -> vector<8x16xf16> + %2 = xetile.init_tile %arg1[0, 0] : memref<16x16xf16> -> !xetile.tile<16x16xf16> + %3 = xetile.load_tile %2 {padding = 0.000000e+00 : f32} : !xetile.tile<16x16xf16> -> vector<16x16xf16> + %4 = xetile.tile_mma %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + %5 = xetile.init_tile %arg2[0, 0] : memref<8x16xf32> -> !xetile.tile<8x16xf32> + xetile.store_tile %4, %5 : vector<8x16xf32>, !xetile.tile<8x16xf32> + gpu.return + } + + //----- sg_prefetch_tile: prefetches the same 2x16 f16 tile twice, first with no hints and then with explicit l1/l3 hints; both are expected to become xegpu.prefetch_nd on one descriptor. + //CHECK: gpu.func @sg_prefetch_tile(%[[arg0:.*]]: memref<2x64xf16>) + gpu.func @sg_prefetch_tile(%a: memref<2x64xf16>) { + //CHECK: %[[r0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<2x64xf16> -> !xegpu.tensor_desc<2x16xf16, #xegpu.block_tdesc_attr> + //CHECK: xegpu.prefetch_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<2x16xf16, #xegpu.block_tdesc_attr> + //CHECK: xegpu.prefetch_nd %[[r0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<2x16xf16, #xegpu.block_tdesc_attr> + %0 = xetile.init_tile %a[0, 0] : memref<2x64xf16> -> !xetile.tile<2x16xf16> + xetile.prefetch_tile %0 : !xetile.tile<2x16xf16> + xetile.prefetch_tile %0 {l1_hint = #xetile.cache_hint, l3_hint = #xetile.cache_hint} : !xetile.tile<2x16xf16> + gpu.return + } + + //----- sg_scattered_ops: masked load, offset update and masked store on a 1x16 f16 tile built from an index vector; the checks expect shape_casts between 1x16 and 16-element vectors around the xegpu ops. + + //CHECK: gpu.func @sg_scattered_ops(%[[arg0:.*]]: memref<1024xf16>) + gpu.func @sg_scattered_ops(%arg0: memref<1024xf16>) { + //CHECK: %[[cst:.*]] = arith.constant dense : vector<1x16xi1> + //CHECK: %[[cst_0:.*]] = arith.constant dense<{{.*}}> : vector<1x16xindex> + //CHECK: %[[cst_1:.*]] = arith.constant dense<16> : vector<1x16xindex> + //CHECK: %[[r0:.*]] = 
vector.shape_cast %[[cst_0]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r1:.*]] = xegpu.create_tdesc %[[arg0]], %[[r0]] : memref<1024xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> + //CHECK: %[[r2:.*]] = vector.shape_cast %[[cst]] : vector<1x16xi1> to vector<16xi1> + //CHECK: %[[r3:.*]] = xegpu.load %[[r1]], %[[r2]] {{.*}} : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> + //CHECK: %[[r4:.*]] = vector.shape_cast %[[r3]] : vector<16xf16> to vector<1x16xf16> + //CHECK: %[[r5:.*]] = vector.shape_cast %[[cst_1]] : vector<1x16xindex> to vector<16xindex> + //CHECK: %[[r6:.*]] = xegpu.update_offset %[[r1]], %[[r5]] : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xindex> + //CHECK: %[[r7:.*]] = vector.shape_cast %[[cst]] : vector<1x16xi1> to vector<16xi1> + //CHECK: %[[r8:.*]] = vector.shape_cast %[[r4]] : vector<1x16xf16> to vector<16xf16> + //CHECK: xegpu.store %[[r8]], %[[r6]], %[[r7]] {{.*}} : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %mask = arith.constant dense : vector<1x16xi1> + %idx = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex> + %offsets = arith.constant dense<16> : vector<1x16xindex> + %0 = xetile.init_tile %arg0, %idx : memref<1024xf16>, vector<1x16xindex> -> !xetile.tile<1x16xf16, #xetile.tile_attr> + %1 = xetile.load %0, %mask : !xetile.tile<1x16xf16, #xetile.tile_attr>, vector<1x16xi1> -> vector<1x16xf16> + %2 = xetile.update_tile_offset %0, %offsets : !xetile.tile<1x16xf16, #xetile.tile_attr>, vector<1x16xindex> + xetile.store %1, %2, %mask : vector<1x16xf16>, !xetile.tile<1x16xf16, #xetile.tile_attr>, vector<1x16xi1> + gpu.return + } +} diff --git a/test/Conversion/XeTileToXeGPU/unpack_pack.mlir b/test/Conversion/XeTileToXeGPU/unpack_pack.mlir deleted file mode 100644 index ec9b4470d..000000000 --- a/test/Conversion/XeTileToXeGPU/unpack_pack.mlir +++ 
/dev/null @@ -1,45 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --cse -verify-diagnostics %s -o - | FileCheck %s - -gpu.module @unpack_pack_non_compatible { - gpu.func @unpack_pack_non_compatible(%arg0: memref<88x32xf16>, %arg1: memref<88x32xf16>) { - %c0 = arith.constant 0 : index - %0 = xetile.init_tile %arg0[%c0, %c0] : memref<88x32xf16> -> !xetile.tile<88x32xf16, #xetile.tile_attr> - %1 = xetile.load_tile %0 {padding = 0.000000e+00 : f32} : !xetile.tile<88x32xf16, #xetile.tile_attr> -> vector<4x2x22x16xf16> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<4x2x22x16xf16> -> vector<88x32xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array} : vector<88x32xf16> -> vector<11x2x8x16xf16> - %4 = xetile.init_tile %arg1[%c0, %c0] : memref<88x32xf16> -> !xetile.tile<88x32xf16, #xetile.tile_attr> - xetile.store_tile %3, %4 : vector<11x2x8x16xf16>, !xetile.tile<88x32xf16, #xetile.tile_attr> - gpu.return - } -} - -// Since the unpack/pack are lowered independently, the unpack op is lowered to -// a series of shuffles followed by a shape cast and the pack op is lowered to -// a series of extract_strided_slice. 
- -// CHECK: gpu.func @unpack_pack_non_compatible -// CHECK-COUNT-4: vector.shuffle {{.*}} vector<22x16xf16> -// CHECK: vector.shape_cast {{.*}} vector<88x32xf16> -// CHECK-COUNT-22: vector.extract_strided_slice {{.*}} vector<8x16xf16> - -// ----- - -gpu.module @unpack_pack_compatible { - gpu.func @unpack_pack_compatible(%arg0: memref<64x32xf16>, %arg1: memref<64x32xf16>) { - %c0 = arith.constant 0 : index - %0 = xetile.init_tile %arg0[%c0, %c0] : memref<64x32xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr> - %1 = xetile.load_tile %0 {padding = 0.000000e+00 : f32} : !xetile.tile<64x32xf16, #xetile.tile_attr> -> vector<2x2x32x16xf16> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<2x2x32x16xf16> -> vector<64x32xf16> - %3 = xetile.tile_pack %2 {inner_blocks = array} : vector<64x32xf16> -> vector<8x2x8x16xf16> - %4 = xetile.init_tile %arg1[%c0, %c0] : memref<64x32xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr> - xetile.store_tile %3, %4 : vector<8x2x8x16xf16>, !xetile.tile<64x32xf16, #xetile.tile_attr> - gpu.return - } -} - -// Since the unpack/pack are lowered jointly, there will be a series of -// extract_strided_slice from the unpack inner blocks (32x16) to the pack inner -// blocks (8x16). 
- -// CHECK: gpu.func @unpack_pack_compatible -// CHECK-COUNT-16: vector.extract_strided_slice {{.*}} vector<32x16xf16> to vector<8x16xf16> diff --git a/test/Dialect/XeTile/Transforms/Blocking/unit_tests_transform.mlir b/test/Dialect/XeTile/Transforms/Blocking/unit_tests_transform.mlir index ec137fb71..a30b7b283 100644 --- a/test/Dialect/XeTile/Transforms/Blocking/unit_tests_transform.mlir +++ b/test/Dialect/XeTile/Transforms/Blocking/unit_tests_transform.mlir @@ -41,12 +41,16 @@ gpu.module @test_kernel { %c32 = arith.constant 32 : index %c20 = arith.constant 20 : index - //CHECK: %[[r0:.*]] = vector.constant_mask [8, 20] : vector<8x32xi1> + //CHECK-2D-FOR-LATER: %[[r0:.*]] = vector.constant_mask [8, 20] : vector<8x32xi1> + //CHECK: %[[r0:.*]] = vector.constant_mask [1, 20] : vector<1x32xi1> %mask = vector.create_mask %c32, %c20 : vector<32x32xi1> - //CHECK-COUNT-4: vector.extract_strided_slice %[[arg0]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> - //CHECK-COUNT-4: vector.extract_strided_slice %[[arg1]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> - //CHECK-COUNT-4: arith.select %[[r0]], %{{.*}}, %{{.*}} : vector<8x32xi1>, vector<8x32xf16> + //CHECK-2D-FOR-LATER-COUNT-4: vector.extract_strided_slice %[[arg0]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> + //CHECK-2D-FOR-LATER-COUNT-4: vector.extract_strided_slice %[[arg1]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> + //CHECK-2D-FOR-LATER-COUNT-4: arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<8x32xi1>, vector<8x32xf16> + //CHECK-COUNT-32: vector.extract_strided_slice %[[arg0]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> + //CHECK-COUNT-32: vector.extract_strided_slice %[[arg1]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to 
vector<1x32xf16> + //CHECK-COUNT-32: arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<1x32xi1>, vector<1x32xf16> %select = arith.select %mask, %a, %b : vector<32x32xi1>, vector<32x32xf16> //CHECK-COUNT-4: xetile.init_tile %[[arg2]][{{.*}}] : memref<32x32xf16> -> !xetile.tile<8x32xf16> @@ -64,14 +68,19 @@ gpu.module @test_kernel { gpu.func @create_mask_2(%a: vector<32x32xf16>, %b: vector<32x32xf16>, %c: memref<32x32xf16>) { %c20 = arith.constant 20 : index %c32 = arith.constant 32 : index - //CHECK: %[[r0:.*]] = vector.constant_mask [8, 32] : vector<8x32xi1> - //CHECK: %[[r1:.*]] = vector.constant_mask [4, 32] : vector<8x32xi1> - //CHECK: %[[r2:.*]] = vector.constant_mask [0, 0] : vector<8x32xi1> + //CHECK-2D-FOR-LATER: %[[r0:.*]] = vector.constant_mask [8, 32] : vector<8x32xi1> + //CHECK-2D-FOR-LATER: %[[r1:.*]] = vector.constant_mask [4, 32] : vector<8x32xi1> + //CHECK-2D-FOR-LATER: %[[r2:.*]] = vector.constant_mask [0, 0] : vector<8x32xi1> + //CHECK: %[[r0:.*]] = vector.constant_mask [1, 32] : vector<1x32xi1> + //CHECK: %[[r1:.*]] = vector.constant_mask [0, 0] : vector<1x32xi1> %mask = vector.create_mask %c20, %c32 : vector<32x32xi1> - //CHECK-COUNT-4: vector.extract_strided_slice %[[arg0]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> - //CHECK-COUNT-4: vector.extract_strided_slice %[[arg1]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> - //CHECK-COUNT-4: arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<8x32xi1>, vector<8x32xf16> + //CHECK-2D-FOR-LATER-COUNT-4: vector.extract_strided_slice %[[arg0]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> + //CHECK-2D-FOR-LATER-COUNT-4: vector.extract_strided_slice %[[arg1]] {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<32x32xf16> to vector<8x32xf16> + //CHECK-2D-FOR-LATER-COUNT-4: arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<8x32xi1>, vector<8x32xf16> + 
//CHECK-COUNT-32: vector.extract_strided_slice %[[arg0]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> + //CHECK-COUNT-32: vector.extract_strided_slice %[[arg1]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> + //CHECK-COUNT-32: arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<1x32xi1>, vector<1x32xf16> %select = arith.select %mask, %a, %b : vector<32x32xi1>, vector<32x32xf16> //CHECK-COUNT-4: xetile.init_tile %[[arg2]][%{{.*}}, %{{.*}}] : memref<32x32xf16> -> !xetile.tile<8x32xf16> @@ -586,13 +595,32 @@ gpu.module @test_kernel { //CHECK-COUNT-32: %{{.*}} = vector.extract_strided_slice %[[r3]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> //CHECK-COUNT-62: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> - //CHECK-COUNT-2: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> - //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<64xf16> to vector<1x64xf16> - //CHECK: %{{.*}} = xetile.broadcast %{{.*}} [0] : vector<1x64xf16> -> vector<32x64xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, 
vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, 
vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-8: %{{.*}} = xetile.init_tile %[[arg0]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK-COUNT-4: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 64], strides = [1, 1]} : vector<32x64xf16> to vector<8x64xf16> - //CHECK-COUNT-8: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> //CHECK-COUNT-8: xetile.store_tile %{{.*}}, %{{.*}} : vector<8x32xf16>, !xetile.tile<8x32xf16> %1 = xetile.init_tile %a[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16> %2 = xetile.load_tile %1: !xetile.tile<32x64xf16> -> vector<32x64xf16> @@ -611,389 +639,201 @@ gpu.module @test_kernel { //CHECK: %[[c24:.*]] = arith.constant 24 : index //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[c8:.*]] = arith.constant 8 : index - //CHECK: %[[c31_i32:.*]] = arith.constant 31 : i32 - //CHECK: %[[c30_i32:.*]] = arith.constant 30 : i32 - //CHECK: %[[c29_i32:.*]] = arith.constant 29 : i32 - //CHECK: %[[c28_i32:.*]] = arith.constant 28 : i32 - //CHECK: %[[c27_i32:.*]] = arith.constant 27 : i32 - //CHECK: %[[c26_i32:.*]] = arith.constant 26 : i32 - //CHECK: %[[c25_i32:.*]] = arith.constant 25 : i32 - //CHECK: %[[c24_i32:.*]] = arith.constant 24 : i32 - //CHECK: %[[c23_i32:.*]] = arith.constant 23 : i32 - //CHECK: %[[c22_i32:.*]] = arith.constant 22 : i32 - //CHECK: %[[c21_i32:.*]] = arith.constant 21 : i32 - //CHECK: %[[c20_i32:.*]] = arith.constant 20 : i32 - //CHECK: %[[c19_i32:.*]] = arith.constant 19 : i32 - //CHECK: %[[c18_i32:.*]] = arith.constant 18 : i32 - //CHECK: %[[c17_i32:.*]] = arith.constant 17 : i32 - //CHECK: %[[c16_i32:.*]] = arith.constant 16 : i32 - //CHECK: %[[c15_i32:.*]] = arith.constant 15 : i32 - //CHECK: %[[c14_i32:.*]] = arith.constant 14 : i32 - //CHECK: 
%[[c13_i32:.*]] = arith.constant 13 : i32 - //CHECK: %[[c12_i32:.*]] = arith.constant 12 : i32 - //CHECK: %[[c11_i32:.*]] = arith.constant 11 : i32 - //CHECK: %[[c10_i32:.*]] = arith.constant 10 : i32 - //CHECK: %[[c9_i32:.*]] = arith.constant 9 : i32 - //CHECK: %[[c8_i32:.*]] = arith.constant 8 : i32 - //CHECK: %[[c7_i32:.*]] = arith.constant 7 : i32 - //CHECK: %[[c6_i32:.*]] = arith.constant 6 : i32 - //CHECK: %[[c5_i32:.*]] = arith.constant 5 : i32 - //CHECK: %[[c4_i32:.*]] = arith.constant 4 : i32 - //CHECK: %[[c3_i32:.*]] = arith.constant 3 : i32 - //CHECK: %[[c2_i32:.*]] = arith.constant 2 : i32 - //CHECK: %[[c1_i32:.*]] = arith.constant 1 : i32 - //CHECK: %[[c0_i32:.*]] = arith.constant 0 : i32 //CHECK: %[[c32:.*]] = arith.constant 32 : index //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[r0:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> //CHECK: %[[r1:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c32]]] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> //CHECK: %[[r2:.*]] = xetile.load_tile %[[r0]] : !xetile.tile<32x32xf16> -> vector<32x32xf16> //CHECK: %[[r3:.*]] = xetile.load_tile %[[r1]] : !xetile.tile<32x32xf16> -> vector<32x32xf16> - //CHECK: %[[r4:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [0, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r5:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [1, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r6:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [2, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r7:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [3, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r8:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [4, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to 
vector<1x32xf16> - //CHECK: %[[r9:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [5, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r10:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [6, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r11:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [7, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r12:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [8, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r13:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [9, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r14:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [10, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r15:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [11, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r16:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [12, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r17:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [13, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r18:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [14, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r19:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [15, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r20:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [16, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r21:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [17, 0], sizes = [1, 32], 
strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r22:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [18, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r23:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [19, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r24:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [20, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r25:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [21, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r26:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [22, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r27:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [23, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r28:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [24, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r29:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [25, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r30:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [26, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r31:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [27, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r32:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [28, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r33:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [29, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r34:.*]] = vector.extract_strided_slice 
%[[r2]] {offsets = [30, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r35:.*]] = vector.extract_strided_slice %[[r2]] {offsets = [31, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r36:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [0, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r37:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [1, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r38:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [2, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r39:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [3, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r40:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [4, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r41:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [5, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r42:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [6, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r43:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [7, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r44:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [8, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r45:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [9, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r46:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [10, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r47:.*]] 
= vector.extract_strided_slice %[[r3]] {offsets = [11, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r48:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [12, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r49:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [13, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r50:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [14, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r51:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [15, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r52:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [16, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r53:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [17, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r54:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [18, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r55:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [19, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r56:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [20, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r57:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [21, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r58:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [22, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r59:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [23, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> 
to vector<1x32xf16> - //CHECK: %[[r60:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [24, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r61:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [25, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r62:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [26, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r63:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [27, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r64:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [28, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r65:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [29, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r66:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [30, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r67:.*]] = vector.extract_strided_slice %[[r3]] {offsets = [31, 0], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> - //CHECK: %[[r68:.*]] = arith.addf %[[r4]], %[[r36]] : vector<1x32xf16> - //CHECK: %[[r69:.*]] = vector.shape_cast %[[r68]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r70:.*]] = arith.addf %[[r5]], %[[r37]] : vector<1x32xf16> - //CHECK: %[[r71:.*]] = vector.shape_cast %[[r70]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r72:.*]] = arith.addf %[[r6]], %[[r38]] : vector<1x32xf16> - //CHECK: %[[r73:.*]] = vector.shape_cast %[[r72]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r74:.*]] = arith.addf %[[r7]], %[[r39]] : vector<1x32xf16> - //CHECK: %[[r75:.*]] = vector.shape_cast %[[r74]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r76:.*]] = arith.addf %[[r8]], %[[r40]] : vector<1x32xf16> - 
//CHECK: %[[r77:.*]] = vector.shape_cast %[[r76]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r78:.*]] = arith.addf %[[r9]], %[[r41]] : vector<1x32xf16> - //CHECK: %[[r79:.*]] = vector.shape_cast %[[r78]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r80:.*]] = arith.addf %[[r10]], %[[r42]] : vector<1x32xf16> - //CHECK: %[[r81:.*]] = vector.shape_cast %[[r80]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r82:.*]] = arith.addf %[[r11]], %[[r43]] : vector<1x32xf16> - //CHECK: %[[r83:.*]] = vector.shape_cast %[[r82]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r84:.*]] = arith.addf %[[r12]], %[[r44]] : vector<1x32xf16> - //CHECK: %[[r85:.*]] = vector.shape_cast %[[r84]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r86:.*]] = arith.addf %[[r13]], %[[r45]] : vector<1x32xf16> - //CHECK: %[[r87:.*]] = vector.shape_cast %[[r86]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r88:.*]] = arith.addf %[[r14]], %[[r46]] : vector<1x32xf16> - //CHECK: %[[r89:.*]] = vector.shape_cast %[[r88]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r90:.*]] = arith.addf %[[r15]], %[[r47]] : vector<1x32xf16> - //CHECK: %[[r91:.*]] = vector.shape_cast %[[r90]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r92:.*]] = arith.addf %[[r16]], %[[r48]] : vector<1x32xf16> - //CHECK: %[[r93:.*]] = vector.shape_cast %[[r92]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r94:.*]] = arith.addf %[[r17]], %[[r49]] : vector<1x32xf16> - //CHECK: %[[r95:.*]] = vector.shape_cast %[[r94]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r96:.*]] = arith.addf %[[r18]], %[[r50]] : vector<1x32xf16> - //CHECK: %[[r97:.*]] = vector.shape_cast %[[r96]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r98:.*]] = arith.addf %[[r19]], %[[r51]] : vector<1x32xf16> - //CHECK: %[[r99:.*]] = vector.shape_cast %[[r98]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r100:.*]] = arith.addf %[[r20]], %[[r52]] : vector<1x32xf16> - //CHECK: %[[r101:.*]] = vector.shape_cast 
%[[r100]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r102:.*]] = arith.addf %[[r21]], %[[r53]] : vector<1x32xf16> - //CHECK: %[[r103:.*]] = vector.shape_cast %[[r102]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r104:.*]] = arith.addf %[[r22]], %[[r54]] : vector<1x32xf16> - //CHECK: %[[r105:.*]] = vector.shape_cast %[[r104]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r106:.*]] = arith.addf %[[r23]], %[[r55]] : vector<1x32xf16> - //CHECK: %[[r107:.*]] = vector.shape_cast %[[r106]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r108:.*]] = arith.addf %[[r24]], %[[r56]] : vector<1x32xf16> - //CHECK: %[[r109:.*]] = vector.shape_cast %[[r108]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r110:.*]] = arith.addf %[[r25]], %[[r57]] : vector<1x32xf16> - //CHECK: %[[r111:.*]] = vector.shape_cast %[[r110]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r112:.*]] = arith.addf %[[r26]], %[[r58]] : vector<1x32xf16> - //CHECK: %[[r113:.*]] = vector.shape_cast %[[r112]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r114:.*]] = arith.addf %[[r27]], %[[r59]] : vector<1x32xf16> - //CHECK: %[[r115:.*]] = vector.shape_cast %[[r114]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r116:.*]] = arith.addf %[[r28]], %[[r60]] : vector<1x32xf16> - //CHECK: %[[r117:.*]] = vector.shape_cast %[[r116]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r118:.*]] = arith.addf %[[r29]], %[[r61]] : vector<1x32xf16> - //CHECK: %[[r119:.*]] = vector.shape_cast %[[r118]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r120:.*]] = arith.addf %[[r30]], %[[r62]] : vector<1x32xf16> - //CHECK: %[[r121:.*]] = vector.shape_cast %[[r120]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r122:.*]] = arith.addf %[[r31]], %[[r63]] : vector<1x32xf16> - //CHECK: %[[r123:.*]] = vector.shape_cast %[[r122]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r124:.*]] = arith.addf %[[r32]], %[[r64]] : vector<1x32xf16> - //CHECK: %[[r125:.*]] = vector.shape_cast 
%[[r124]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r126:.*]] = arith.addf %[[r33]], %[[r65]] : vector<1x32xf16> - //CHECK: %[[r127:.*]] = vector.shape_cast %[[r126]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r128:.*]] = arith.addf %[[r34]], %[[r66]] : vector<1x32xf16> - //CHECK: %[[r129:.*]] = vector.shape_cast %[[r128]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r130:.*]] = arith.addf %[[r35]], %[[r67]] : vector<1x32xf16> - //CHECK: %[[r131:.*]] = vector.shape_cast %[[r130]] : vector<1x32xf16> to vector<32xf16> - //CHECK: %[[r132:.*]] = vector.shuffle %[[r69]], %[[r71]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r133:.*]] = vector.shuffle %[[r69]], %[[r71]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r134:.*]] = arith.addf %[[r132]], %[[r133]] : vector<32xf16> - //CHECK: %[[r135:.*]] = vector.shuffle %[[r73]], %[[r75]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r136:.*]] = vector.shuffle %[[r73]], %[[r75]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r137:.*]] = arith.addf %[[r135]], %[[r136]] : vector<32xf16> - //CHECK: %[[r138:.*]] = vector.shuffle %[[r77]], %[[r79]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r139:.*]] = vector.shuffle %[[r77]], %[[r79]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - 
//CHECK: %[[r140:.*]] = arith.addf %[[r138]], %[[r139]] : vector<32xf16> - //CHECK: %[[r141:.*]] = vector.shuffle %[[r81]], %[[r83]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r142:.*]] = vector.shuffle %[[r81]], %[[r83]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r143:.*]] = arith.addf %[[r141]], %[[r142]] : vector<32xf16> - //CHECK: %[[r144:.*]] = vector.shuffle %[[r85]], %[[r87]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r145:.*]] = vector.shuffle %[[r85]], %[[r87]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r146:.*]] = arith.addf %[[r144]], %[[r145]] : vector<32xf16> - //CHECK: %[[r147:.*]] = vector.shuffle %[[r89]], %[[r91]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r148:.*]] = vector.shuffle %[[r89]], %[[r91]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r149:.*]] = arith.addf %[[r147]], %[[r148]] : vector<32xf16> - //CHECK: %[[r150:.*]] = vector.shuffle %[[r93]], %[[r95]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r151:.*]] = vector.shuffle %[[r93]], %[[r95]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 
63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r152:.*]] = arith.addf %[[r150]], %[[r151]] : vector<32xf16> - //CHECK: %[[r153:.*]] = vector.shuffle %[[r97]], %[[r99]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r154:.*]] = vector.shuffle %[[r97]], %[[r99]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r155:.*]] = arith.addf %[[r153]], %[[r154]] : vector<32xf16> - //CHECK: %[[r156:.*]] = vector.shuffle %[[r101]], %[[r103]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r157:.*]] = vector.shuffle %[[r101]], %[[r103]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r158:.*]] = arith.addf %[[r156]], %[[r157]] : vector<32xf16> - //CHECK: %[[r159:.*]] = vector.shuffle %[[r105]], %[[r107]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r160:.*]] = vector.shuffle %[[r105]], %[[r107]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r161:.*]] = arith.addf %[[r159]], %[[r160]] : vector<32xf16> - //CHECK: %[[r162:.*]] = vector.shuffle %[[r109]], %[[r111]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r163:.*]] = vector.shuffle %[[r109]], %[[r111]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 
51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r164:.*]] = arith.addf %[[r162]], %[[r163]] : vector<32xf16> - //CHECK: %[[r165:.*]] = vector.shuffle %[[r113]], %[[r115]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r166:.*]] = vector.shuffle %[[r113]], %[[r115]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r167:.*]] = arith.addf %[[r165]], %[[r166]] : vector<32xf16> - //CHECK: %[[r168:.*]] = vector.shuffle %[[r117]], %[[r119]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r169:.*]] = vector.shuffle %[[r117]], %[[r119]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r170:.*]] = arith.addf %[[r168]], %[[r169]] : vector<32xf16> - //CHECK: %[[r171:.*]] = vector.shuffle %[[r121]], %[[r123]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r172:.*]] = vector.shuffle %[[r121]], %[[r123]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r173:.*]] = arith.addf %[[r171]], %[[r172]] : vector<32xf16> - //CHECK: %[[r174:.*]] = vector.shuffle %[[r125]], %[[r127]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r175:.*]] = vector.shuffle %[[r125]], %[[r127]] [16, 17, 18, 19, 20, 21, 
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r176:.*]] = arith.addf %[[r174]], %[[r175]] : vector<32xf16> - //CHECK: %[[r177:.*]] = vector.shuffle %[[r129]], %[[r131]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r178:.*]] = vector.shuffle %[[r129]], %[[r131]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r179:.*]] = arith.addf %[[r177]], %[[r178]] : vector<32xf16> - //CHECK: %[[r180:.*]] = vector.shuffle %[[r134]], %[[r137]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r181:.*]] = vector.shuffle %[[r134]], %[[r137]] [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r182:.*]] = arith.addf %[[r180]], %[[r181]] : vector<32xf16> - //CHECK: %[[r183:.*]] = vector.shuffle %[[r140]], %[[r143]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r184:.*]] = vector.shuffle %[[r140]], %[[r143]] [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r185:.*]] = arith.addf %[[r183]], %[[r184]] : vector<32xf16> - //CHECK: %[[r186:.*]] = vector.shuffle %[[r146]], %[[r149]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r187:.*]] = 
vector.shuffle %[[r146]], %[[r149]] [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r188:.*]] = arith.addf %[[r186]], %[[r187]] : vector<32xf16> - //CHECK: %[[r189:.*]] = vector.shuffle %[[r152]], %[[r155]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r190:.*]] = vector.shuffle %[[r152]], %[[r155]] [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r191:.*]] = arith.addf %[[r189]], %[[r190]] : vector<32xf16> - //CHECK: %[[r192:.*]] = vector.shuffle %[[r158]], %[[r161]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r193:.*]] = vector.shuffle %[[r158]], %[[r161]] [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r194:.*]] = arith.addf %[[r192]], %[[r193]] : vector<32xf16> - //CHECK: %[[r195:.*]] = vector.shuffle %[[r164]], %[[r167]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r196:.*]] = vector.shuffle %[[r164]], %[[r167]] [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r197:.*]] = arith.addf %[[r195]], %[[r196]] : vector<32xf16> - //CHECK: %[[r198:.*]] = vector.shuffle %[[r170]], %[[r173]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : 
vector<32xf16>, vector<32xf16> - //CHECK: %[[r199:.*]] = vector.shuffle %[[r170]], %[[r173]] [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r200:.*]] = arith.addf %[[r198]], %[[r199]] : vector<32xf16> - //CHECK: %[[r201:.*]] = vector.shuffle %[[r176]], %[[r179]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r202:.*]] = vector.shuffle %[[r176]], %[[r179]] [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r203:.*]] = arith.addf %[[r201]], %[[r202]] : vector<32xf16> - //CHECK: %[[r204:.*]] = vector.shuffle %[[r182]], %[[r185]] [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r205:.*]] = vector.shuffle %[[r182]], %[[r185]] [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r206:.*]] = arith.addf %[[r204]], %[[r205]] : vector<32xf16> - //CHECK: %[[r207:.*]] = vector.shuffle %[[r188]], %[[r191]] [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r208:.*]] = vector.shuffle %[[r188]], %[[r191]] [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r209:.*]] = arith.addf %[[r207]], %[[r208]] : vector<32xf16> - //CHECK: %[[r210:.*]] = vector.shuffle %[[r194]], %[[r197]] [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 32, 33, 34, 
35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r211:.*]] = vector.shuffle %[[r194]], %[[r197]] [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r212:.*]] = arith.addf %[[r210]], %[[r211]] : vector<32xf16> - //CHECK: %[[r213:.*]] = vector.shuffle %[[r200]], %[[r203]] [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r214:.*]] = vector.shuffle %[[r200]], %[[r203]] [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r215:.*]] = arith.addf %[[r213]], %[[r214]] : vector<32xf16> - //CHECK: %[[r216:.*]] = vector.shuffle %[[r206]], %[[r209]] [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29, 32, 33, 36, 37, 40, 41, 44, 45, 48, 49, 52, 53, 56, 57, 60, 61] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r217:.*]] = vector.shuffle %[[r206]], %[[r209]] [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31, 34, 35, 38, 39, 42, 43, 46, 47, 50, 51, 54, 55, 58, 59, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r218:.*]] = arith.addf %[[r216]], %[[r217]] : vector<32xf16> - //CHECK: %[[r219:.*]] = vector.shuffle %[[r212]], %[[r215]] [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29, 32, 33, 36, 37, 40, 41, 44, 45, 48, 49, 52, 53, 56, 57, 60, 61] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r220:.*]] = vector.shuffle %[[r212]], %[[r215]] [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31, 34, 35, 38, 39, 42, 43, 46, 47, 50, 51, 54, 55, 58, 59, 62, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r221:.*]] = arith.addf %[[r219]], %[[r220]] : vector<32xf16> - //CHECK: %[[r222:.*]] = vector.shuffle %[[r218]], %[[r221]] [0, 2, 4, 6, 8, 
10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r223:.*]] = vector.shuffle %[[r218]], %[[r221]] [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63] : vector<32xf16>, vector<32xf16> - //CHECK: %[[r224:.*]] = arith.addf %[[r222]], %[[r223]] : vector<32xf16> - //CHECK: %[[r225:.*]] = vector.extractelement %[[r224]][%[[c0_i32]] : i32] : vector<32xf16> - //CHECK: %[[r226:.*]] = vector.splat %[[r225]] : vector<1x1xf16> - //CHECK: %[[r227:.*]] = vector.extractelement %[[r224]][%[[c1_i32]] : i32] : vector<32xf16> - //CHECK: %[[r228:.*]] = vector.splat %[[r227]] : vector<1x1xf16> - //CHECK: %[[r229:.*]] = vector.extractelement %224[%c2_i32 : i32] : vector<32xf16> - //CHECK: %[[r230:.*]] = vector.splat %[[r229]] : vector<1x1xf16> - //CHECK: %[[r231:.*]] = vector.extractelement %[[r224]][%[[c3_i32]] : i32] : vector<32xf16> - //CHECK: %[[r232:.*]] = vector.splat %[[r231]] : vector<1x1xf16> - //CHECK: %[[r233:.*]] = vector.extractelement %[[r224]][%[[c4_i32]] : i32] : vector<32xf16> - //CHECK: %[[r234:.*]] = vector.splat %[[r233]] : vector<1x1xf16> - //CHECK: %[[r235:.*]] = vector.extractelement %[[r224]][%[[c5_i32]] : i32] : vector<32xf16> - //CHECK: %[[r236:.*]] = vector.splat %[[r235]] : vector<1x1xf16> - //CHECK: %[[r237:.*]] = vector.extractelement %[[r224]][%[[c6_i32]] : i32] : vector<32xf16> - //CHECK: %[[r238:.*]] = vector.splat %[[r237]] : vector<1x1xf16> - //CHECK: %[[r239:.*]] = vector.extractelement %[[r224]][%[[c7_i32]] : i32] : vector<32xf16> - //CHECK: %[[r240:.*]] = vector.splat %[[r239]] : vector<1x1xf16> - //CHECK: %[[r241:.*]] = vector.extractelement %[[r224]][%[[c8_i32]] : i32] : vector<32xf16> - //CHECK: %[[r242:.*]] = vector.splat %[[r241]] : vector<1x1xf16> - //CHECK: %[[r243:.*]] = vector.extractelement %[[r224]][%[[c9_i32]] : i32] : vector<32xf16> - //CHECK: %[[r244:.*]] = 
vector.splat %[[r243]] : vector<1x1xf16> - //CHECK: %[[r245:.*]] = vector.extractelement %[[r224]][%[[c10_i32]] : i32] : vector<32xf16> - //CHECK: %[[r246:.*]] = vector.splat %[[r245]] : vector<1x1xf16> - //CHECK: %[[r247:.*]] = vector.extractelement %[[r224]][%[[c11_i32]] : i32] : vector<32xf16> - //CHECK: %[[r248:.*]] = vector.splat %[[r247]] : vector<1x1xf16> - //CHECK: %[[r249:.*]] = vector.extractelement %[[r224]][%[[c12_i32]] : i32] : vector<32xf16> - //CHECK: %[[r250:.*]] = vector.splat %[[r249]] : vector<1x1xf16> - //CHECK: %[[r251:.*]] = vector.extractelement %[[r224]][%[[c13_i32]] : i32] : vector<32xf16> - //CHECK: %[[r252:.*]] = vector.splat %[[r251]] : vector<1x1xf16> - //CHECK: %[[r253:.*]] = vector.extractelement %[[r224]][%[[c14_i32]] : i32] : vector<32xf16> - //CHECK: %[[r254:.*]] = vector.splat %[[r253]] : vector<1x1xf16> - //CHECK: %[[r255:.*]] = vector.extractelement %[[r224]][%[[c15_i32]] : i32] : vector<32xf16> - //CHECK: %[[r256:.*]] = vector.splat %[[r255]] : vector<1x1xf16> - //CHECK: %[[r257:.*]] = vector.extractelement %[[r224]][%[[c16_i32]] : i32] : vector<32xf16> - //CHECK: %[[r258:.*]] = vector.splat %[[r257]] : vector<1x1xf16> - //CHECK: %[[r259:.*]] = vector.extractelement %[[r224]][%[[c17_i32]] : i32] : vector<32xf16> - //CHECK: %[[r260:.*]] = vector.splat %[[r259]] : vector<1x1xf16> - //CHECK: %[[r261:.*]] = vector.extractelement %[[r224]][%[[c18_i32]] : i32] : vector<32xf16> - //CHECK: %[[r262:.*]] = vector.splat %[[r261]] : vector<1x1xf16> - //CHECK: %[[r263:.*]] = vector.extractelement %[[r224]][%[[c19_i32]] : i32] : vector<32xf16> - //CHECK: %[[r264:.*]] = vector.splat %[[r263]] : vector<1x1xf16> - //CHECK: %[[r265:.*]] = vector.extractelement %[[r224]][%[[c20_i32]] : i32] : vector<32xf16> - //CHECK: %[[r266:.*]] = vector.splat %[[r265]] : vector<1x1xf16> - //CHECK: %[[r267:.*]] = vector.extractelement %[[r224]][%[[c21_i32]] : i32] : vector<32xf16> - //CHECK: %[[r268:.*]] = vector.splat %[[r267]] : vector<1x1xf16> - //CHECK: 
%[[r269:.*]] = vector.extractelement %[[r224]][%[[c22_i32]] : i32] : vector<32xf16> - //CHECK: %[[r270:.*]] = vector.splat %[[r269]] : vector<1x1xf16> - //CHECK: %[[r271:.*]] = vector.extractelement %[[r224]][%[[c23_i32]] : i32] : vector<32xf16> - //CHECK: %[[r272:.*]] = vector.splat %[[r271]] : vector<1x1xf16> - //CHECK: %[[r273:.*]] = vector.extractelement %[[r224]][%[[c24_i32]] : i32] : vector<32xf16> - //CHECK: %[[r274:.*]] = vector.splat %[[r273]] : vector<1x1xf16> - //CHECK: %[[r275:.*]] = vector.extractelement %[[r224]][%[[c25_i32]] : i32] : vector<32xf16> - //CHECK: %[[r276:.*]] = vector.splat %[[r275]] : vector<1x1xf16> - //CHECK: %[[r277:.*]] = vector.extractelement %[[r224]][%[[c26_i32]] : i32] : vector<32xf16> - //CHECK: %[[r278:.*]] = vector.splat %[[r277]] : vector<1x1xf16> - //CHECK: %[[r279:.*]] = vector.extractelement %[[r224]][%[[c27_i32]] : i32] : vector<32xf16> - //CHECK: %[[r280:.*]] = vector.splat %[[r279]] : vector<1x1xf16> - //CHECK: %[[r281:.*]] = vector.extractelement %[[r224]][%c28_i32 : i32] : vector<32xf16> - //CHECK: %[[r282:.*]] = vector.splat %[[r281]] : vector<1x1xf16> - //CHECK: %[[r283:.*]] = vector.extractelement %[[r224]][%[[c29_i32]] : i32] : vector<32xf16> - //CHECK: %[[r284:.*]] = vector.splat %[[r283]] : vector<1x1xf16> - //CHECK: %[[r285:.*]] = vector.extractelement %[[r224]][%[[c30_i32]] : i32] : vector<32xf16> - //CHECK: %[[r286:.*]] = vector.splat %[[r285]] : vector<1x1xf16> - //CHECK: %[[r287:.*]] = vector.extractelement %[[r224]][%[[c31_i32]] : i32] : vector<32xf16> - //CHECK: %[[r288:.*]] = vector.splat %[[r287]] : vector<1x1xf16> - //CHECK: %[[r289:.*]] = vector.shuffle %[[r226]], %[[r228]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r290:.*]] = vector.shuffle %[[r230]], %[[r232]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r291:.*]] = vector.shuffle %[[r234]], %[[r236]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r292:.*]] = vector.shuffle %[[r238]], %[[r240]] [0, 1] : 
vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r293:.*]] = vector.shuffle %[[r242]], %[[r244]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r294:.*]] = vector.shuffle %[[r246]], %[[r248]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r295:.*]] = vector.shuffle %[[r250]], %[[r252]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r296:.*]] = vector.shuffle %[[r254]], %[[r256]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r297:.*]] = vector.shuffle %[[r258]], %[[r260]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r298:.*]] = vector.shuffle %[[r262]], %[[r264]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r299:.*]] = vector.shuffle %[[r266]], %[[r268]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r300:.*]] = vector.shuffle %[[r270]], %[[r272]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r301:.*]] = vector.shuffle %[[r274]], %[[r276]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r302:.*]] = vector.shuffle %[[r278]], %[[r280]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r303:.*]] = vector.shuffle %[[r282]], %[[r284]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r304:.*]] = vector.shuffle %[[r286]], %[[r288]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r305:.*]] = vector.shuffle %[[r289]], %[[r290]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r306:.*]] = vector.shuffle %[[r291]], %[[r292]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r307:.*]] = vector.shuffle %[[r293]], %[[r294]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r308:.*]] = vector.shuffle %[[r295]], %[[r296]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r309:.*]] = vector.shuffle %[[r297]], %[[r298]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r310:.*]] = vector.shuffle %[[r299]], %[[r300]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r311:.*]] = vector.shuffle %[[r301]], 
%[[r302]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r312:.*]] = vector.shuffle %[[r303]], %[[r304]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r313:.*]] = vector.shuffle %[[r305]], %[[r306]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK: %[[r314:.*]] = vector.shuffle %[[r307]], %[[r308]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK: %[[r315:.*]] = vector.shuffle %[[r309]], %[[r310]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK: %[[r316:.*]] = vector.shuffle %[[r311]], %[[r312]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK: %[[r317:.*]] = vector.shuffle %[[r313]], %[[r314]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x1xf16>, vector<8x1xf16> - //CHECK: %[[r318:.*]] = vector.shuffle %[[r315]], %[[r316]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x1xf16>, vector<8x1xf16> - //CHECK: %[[r319:.*]] = vector.shuffle %[[r317]], %[[r318]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x1xf16>, vector<16x1xf16> - //CHECK: %[[r320:.*]] = xetile.broadcast %[[r319]] [1] : vector<32x1xf16> -> vector<32x64xf16> - //CHECK: %[[r321:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r322:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c32]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r323:.*]] = xetile.init_tile %[[arg0]][%[[c8]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r324:.*]] = xetile.init_tile %[[arg0]][%[[c8]], %[[c32]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r325:.*]] = xetile.init_tile %[[arg0]][%[[c16]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r326:.*]] = xetile.init_tile %[[arg0]][%[[c16]], %[[c32]]] : memref<1024x1024xf16> -> 
!xetile.tile<8x32xf16> - //CHECK: %[[r327:.*]] = xetile.init_tile %[[arg0]][%[[c24]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r328:.*]] = xetile.init_tile %[[arg0]][%[[c24]], %[[c32]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r329:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [0, 0], sizes = [8, 64], strides = [1, 1]} : vector<32x64xf16> to vector<8x64xf16> - //CHECK: %[[r330:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [8, 0], sizes = [8, 64], strides = [1, 1]} : vector<32x64xf16> to vector<8x64xf16> - //CHECK: %[[r331:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [16, 0], sizes = [8, 64], strides = [1, 1]} : vector<32x64xf16> to vector<8x64xf16> - //CHECK: %[[r332:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [24, 0], sizes = [8, 64], strides = [1, 1]} : vector<32x64xf16> to vector<8x64xf16> - //CHECK: %[[r333:.*]] = vector.extract_strided_slice %[[r329]] {offsets = [0, 0], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> - //CHECK: %[[r334:.*]] = vector.extract_strided_slice %[[r329]] {offsets = [0, 32], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> - //CHECK: %[[r335:.*]] = vector.extract_strided_slice %[[r330]] {offsets = [0, 0], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> - //CHECK: %[[r336:.*]] = vector.extract_strided_slice %[[r330]] {offsets = [0, 32], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> - //CHECK: %[[r337:.*]] = vector.extract_strided_slice %[[r331]] {offsets = [0, 0], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> - //CHECK: %[[r338:.*]] = vector.extract_strided_slice %[[r331]] {offsets = [0, 32], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> - //CHECK: %[[r339:.*]] = vector.extract_strided_slice %[[r332]] {offsets = [0, 0], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to 
vector<8x32xf16> - //CHECK: %[[r340:.*]] = vector.extract_strided_slice %[[r332]] {offsets = [0, 32], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> - //CHECK: xetile.store_tile %[[r333]], %[[r321]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r334]], %[[r322]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r335]], %[[r323]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r336]], %[[r324]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r337]], %[[r325]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r338]], %[[r326]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r339]], %[[r327]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r340]], %[[r328]] : vector<8x32xf16>, !xetile.tile<8x32xf16> + //CHECK-COUNT-32: %{{.*}} = vector.extract_strided_slice %[[r2]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> + //CHECK-COUNT-32: %{{.*}} = vector.extract_strided_slice %[[r3]] {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<32x32xf16> to vector<1x32xf16> + + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to 
vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + 
//CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> + //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> + + + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : 
vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle 
%{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 
12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 
42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + + + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, 
vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle 
%{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49, 50, 51, 52, 53, 54, 55] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + + + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 
21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, 36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29, 32, 33, 36, 37, 40, 41, 44, 45, 48, 49, 52, 53, 56, 57, 60, 61] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31, 34, 35, 38, 39, 42, 43, 46, 47, 50, 51, 54, 55, 58, 59, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29, 32, 33, 36, 37, 40, 41, 44, 45, 48, 49, 52, 53, 56, 57, 60, 61] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31, 34, 35, 38, 39, 42, 43, 46, 
47, 50, 51, 54, 55, 58, 59, 62, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63] : vector<32xf16>, vector<32xf16> + //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + + //CHECK-COUNT-32: %{{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<32xf16> + //CHECK-COUNT-32: %{{.*}} = vector.splat %{{.*}} : vector<1x32xf16> + + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + 
//CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + + //CHECK-COUNT-8: %{{.*}} = xetile.init_tile %[[arg0]][{{.*}}] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> + //CHECK-COUNT-8: xetile.store_tile %{{.*}}, %{{.*}} : vector<8x32xf16>, !xetile.tile<8x32xf16> %1 = xetile.init_tile %a[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16> %2 = xetile.load_tile %1: !xetile.tile<32x64xf16> -> vector<32x64xf16> %3 = xetile.reduction , %2 [1]: vector<32x64xf16> -> vector<32x1xf16> @@ -1272,134 +1112,54 @@ gpu.module @test_kernel { //CHECK: %[[r222:.*]] = vector.shuffle %[[r218]], %[[r221]] [0, 2, 4, 6, 8, 
10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62] : vector<32xf16>, vector<32xf16> //CHECK: %[[r223:.*]] = vector.shuffle %[[r218]], %[[r221]] [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63] : vector<32xf16>, vector<32xf16> //CHECK: %[[r224:.*]] = arith.addf %222, %223 : vector<32xf16> - //CHECK: %[[r225:.*]] = vector.extractelement %224[%c0_i32 : i32] : vector<32xf16> - //CHECK: %[[r226:.*]] = vector.splat %225 : vector<1x1xf16> - //CHECK: %[[r227:.*]] = vector.extractelement %224[%c1_i32 : i32] : vector<32xf16> - //CHECK: %[[r228:.*]] = vector.splat %227 : vector<1x1xf16> - //CHECK: %[[r229:.*]] = vector.extractelement %224[%c2_i32 : i32] : vector<32xf16> - //CHECK: %[[r230:.*]] = vector.splat %229 : vector<1x1xf16> - //CHECK: %[[r231:.*]] = vector.extractelement %224[%c3_i32 : i32] : vector<32xf16> - //CHECK: %[[r232:.*]] = vector.splat %231 : vector<1x1xf16> - //CHECK: %[[r233:.*]] = vector.extractelement %224[%c4_i32 : i32] : vector<32xf16> - //CHECK: %[[r234:.*]] = vector.splat %233 : vector<1x1xf16> - //CHECK: %[[r235:.*]] = vector.extractelement %224[%c5_i32 : i32] : vector<32xf16> - //CHECK: %[[r236:.*]] = vector.splat %235 : vector<1x1xf16> - //CHECK: %[[r237:.*]] = vector.extractelement %224[%c6_i32 : i32] : vector<32xf16> - //CHECK: %[[r238:.*]] = vector.splat %237 : vector<1x1xf16> - //CHECK: %[[r239:.*]] = vector.extractelement %224[%c7_i32 : i32] : vector<32xf16> - //CHECK: %[[r240:.*]] = vector.splat %239 : vector<1x1xf16> - //CHECK: %[[r241:.*]] = vector.extractelement %224[%c8_i32 : i32] : vector<32xf16> - //CHECK: %[[r242:.*]] = vector.splat %241 : vector<1x1xf16> - //CHECK: %[[r243:.*]] = vector.extractelement %224[%c9_i32 : i32] : vector<32xf16> - //CHECK: %[[r244:.*]] = vector.splat %243 : vector<1x1xf16> - //CHECK: %[[r245:.*]] = vector.extractelement %224[%c10_i32 : i32] : vector<32xf16> - //CHECK: 
%[[r246:.*]] = vector.splat %245 : vector<1x1xf16> - //CHECK: %[[r247:.*]] = vector.extractelement %224[%c11_i32 : i32] : vector<32xf16> - //CHECK: %[[r248:.*]] = vector.splat %247 : vector<1x1xf16> - //CHECK: %[[r249:.*]] = vector.extractelement %224[%c12_i32 : i32] : vector<32xf16> - //CHECK: %[[r250:.*]] = vector.splat %249 : vector<1x1xf16> - //CHECK: %[[r251:.*]] = vector.extractelement %224[%c13_i32 : i32] : vector<32xf16> - //CHECK: %[[r252:.*]] = vector.splat %251 : vector<1x1xf16> - //CHECK: %[[r253:.*]] = vector.extractelement %224[%c14_i32 : i32] : vector<32xf16> - //CHECK: %[[r254:.*]] = vector.splat %253 : vector<1x1xf16> - //CHECK: %[[r255:.*]] = vector.extractelement %224[%c15_i32 : i32] : vector<32xf16> - //CHECK: %[[r256:.*]] = vector.splat %255 : vector<1x1xf16> - //CHECK: %[[r257:.*]] = vector.extractelement %224[%c16_i32 : i32] : vector<32xf16> - //CHECK: %[[r258:.*]] = vector.splat %257 : vector<1x1xf16> - //CHECK: %[[r259:.*]] = vector.extractelement %224[%c17_i32 : i32] : vector<32xf16> - //CHECK: %[[r260:.*]] = vector.splat %259 : vector<1x1xf16> - //CHECK: %[[r261:.*]] = vector.extractelement %224[%c18_i32 : i32] : vector<32xf16> - //CHECK: %[[r262:.*]] = vector.splat %261 : vector<1x1xf16> - //CHECK: %[[r263:.*]] = vector.extractelement %224[%c19_i32 : i32] : vector<32xf16> - //CHECK: %[[r264:.*]] = vector.splat %263 : vector<1x1xf16> - //CHECK: %[[r265:.*]] = vector.extractelement %224[%c20_i32 : i32] : vector<32xf16> - //CHECK: %[[r266:.*]] = vector.splat %265 : vector<1x1xf16> - //CHECK: %[[r267:.*]] = vector.extractelement %224[%c21_i32 : i32] : vector<32xf16> - //CHECK: %[[r268:.*]] = vector.splat %267 : vector<1x1xf16> - //CHECK: %[[r269:.*]] = vector.extractelement %224[%c22_i32 : i32] : vector<32xf16> - //CHECK: %[[r270:.*]] = vector.splat %269 : vector<1x1xf16> - //CHECK: %[[r271:.*]] = vector.extractelement %224[%c23_i32 : i32] : vector<32xf16> - //CHECK: %[[r272:.*]] = vector.splat %271 : vector<1x1xf16> - //CHECK: %[[r273:.*]] 
= vector.extractelement %224[%c24_i32 : i32] : vector<32xf16> - //CHECK: %[[r274:.*]] = vector.splat %273 : vector<1x1xf16> - //CHECK: %[[r275:.*]] = vector.extractelement %224[%c25_i32 : i32] : vector<32xf16> - //CHECK: %[[r276:.*]] = vector.splat %275 : vector<1x1xf16> - //CHECK: %[[r277:.*]] = vector.extractelement %224[%c26_i32 : i32] : vector<32xf16> - //CHECK: %[[r278:.*]] = vector.splat %277 : vector<1x1xf16> - //CHECK: %[[r279:.*]] = vector.extractelement %224[%c27_i32 : i32] : vector<32xf16> - //CHECK: %[[r280:.*]] = vector.splat %279 : vector<1x1xf16> - //CHECK: %[[r281:.*]] = vector.extractelement %224[%c28_i32 : i32] : vector<32xf16> - //CHECK: %[[r282:.*]] = vector.splat %281 : vector<1x1xf16> - //CHECK: %[[r283:.*]] = vector.extractelement %224[%c29_i32 : i32] : vector<32xf16> - //CHECK: %[[r284:.*]] = vector.splat %283 : vector<1x1xf16> - //CHECK: %[[r285:.*]] = vector.extractelement %224[%c30_i32 : i32] : vector<32xf16> - //CHECK: %[[r286:.*]] = vector.splat %285 : vector<1x1xf16> - //CHECK: %[[r287:.*]] = vector.extractelement %224[%c31_i32 : i32] : vector<32xf16> - //CHECK: %[[r288:.*]] = vector.splat %287 : vector<1x1xf16> - //CHECK: %[[r289:.*]] = vector.shuffle %[[r226]], %[[r228]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r290:.*]] = vector.shuffle %[[r230]], %[[r232]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r291:.*]] = vector.shuffle %[[r234]], %[[r236]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r292:.*]] = vector.shuffle %[[r238]], %[[r240]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r293:.*]] = vector.shuffle %[[r242]], %[[r244]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r294:.*]] = vector.shuffle %[[r246]], %[[r248]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r295:.*]] = vector.shuffle %[[r250]], %[[r252]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r296:.*]] = vector.shuffle %[[r254]], %[[r256]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - 
//CHECK: %[[r297:.*]] = vector.shuffle %[[r258]], %[[r260]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r298:.*]] = vector.shuffle %[[r262]], %[[r264]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r299:.*]] = vector.shuffle %[[r266]], %[[r268]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r300:.*]] = vector.shuffle %[[r270]], %[[r272]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r301:.*]] = vector.shuffle %[[r274]], %[[r276]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r302:.*]] = vector.shuffle %[[r278]], %[[r280]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r303:.*]] = vector.shuffle %[[r282]], %[[r284]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r304:.*]] = vector.shuffle %[[r286]], %[[r288]] [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK: %[[r305:.*]] = vector.shuffle %[[r289]], %[[r290]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r306:.*]] = vector.shuffle %[[r291]], %[[r292]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r307:.*]] = vector.shuffle %[[r293]], %[[r294]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r308:.*]] = vector.shuffle %[[r295]], %[[r296]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r309:.*]] = vector.shuffle %[[r297]], %[[r298]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r310:.*]] = vector.shuffle %[[r299]], %[[r300]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r311:.*]] = vector.shuffle %[[r301]], %[[r302]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r312:.*]] = vector.shuffle %[[r303]], %[[r304]] [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK: %[[r313:.*]] = vector.shuffle %[[r305]], %[[r306]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK: %[[r314:.*]] = vector.shuffle %[[r307]], %[[r308]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK: %[[r315:.*]] = vector.shuffle 
%[[r309]], %[[r310]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK: %[[r316:.*]] = vector.shuffle %[[r311]], %[[r312]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK: %[[r317:.*]] = vector.shuffle %[[r313]], %[[r314]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x1xf16>, vector<8x1xf16> - //CHECK: %[[r318:.*]] = vector.shuffle %[[r315]], %[[r316]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x1xf16>, vector<8x1xf16> - //CHECK: %[[r319:.*]] = vector.shuffle %[[r317]], %[[r318]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x1xf16>, vector<16x1xf16> - //CHECK: %[[r320:.*]] = xetile.broadcast %[[r319]] [1] : vector<32x1xf16> -> vector<32x64xf16> - //CHECK: %[[r321:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [0, 0], sizes = [32, 8], strides = [1, 1]} : vector<32x64xf16> to vector<32x8xf16> - //CHECK: %[[r322:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [0, 8], sizes = [32, 8], strides = [1, 1]} : vector<32x64xf16> to vector<32x8xf16> - //CHECK: %[[r323:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [0, 16], sizes = [32, 8], strides = [1, 1]} : vector<32x64xf16> to vector<32x8xf16> - //CHECK: %[[r324:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [0, 24], sizes = [32, 8], strides = [1, 1]} : vector<32x64xf16> to vector<32x8xf16> - //CHECK: %[[r325:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [0, 32], sizes = [32, 8], strides = [1, 1]} : vector<32x64xf16> to vector<32x8xf16> - //CHECK: %[[r326:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [0, 40], sizes = [32, 8], strides = [1, 1]} : vector<32x64xf16> to vector<32x8xf16> - //CHECK: %[[r327:.*]] = vector.extract_strided_slice %[[r320]] {offsets = [0, 48], sizes = [32, 8], strides = [1, 1]} : vector<32x64xf16> to vector<32x8xf16> - //CHECK: %[[r328:.*]] = vector.extract_strided_slice 
%[[r320]] {offsets = [0, 56], sizes = [32, 8], strides = [1, 1]} : vector<32x64xf16> to vector<32x8xf16> - //CHECK: %[[r329:.*]] = xetile.transpose %[[r321]], [1, 0] : vector<32x8xf16> -> vector<8x32xf16> - //CHECK: %[[r330:.*]] = xetile.transpose %[[r322]], [1, 0] : vector<32x8xf16> -> vector<8x32xf16> - //CHECK: %[[r331:.*]] = xetile.transpose %[[r323]], [1, 0] : vector<32x8xf16> -> vector<8x32xf16> - //CHECK: %[[r332:.*]] = xetile.transpose %[[r324]], [1, 0] : vector<32x8xf16> -> vector<8x32xf16> - //CHECK: %[[r333:.*]] = xetile.transpose %[[r325]], [1, 0] : vector<32x8xf16> -> vector<8x32xf16> - //CHECK: %[[r334:.*]] = xetile.transpose %[[r326]], [1, 0] : vector<32x8xf16> -> vector<8x32xf16> - //CHECK: %[[r335:.*]] = xetile.transpose %[[r327]], [1, 0] : vector<32x8xf16> -> vector<8x32xf16> - //CHECK: %[[r336:.*]] = xetile.transpose %[[r328]], [1, 0] : vector<32x8xf16> -> vector<8x32xf16> - //CHECK: %[[r337:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r338:.*]] = xetile.init_tile %[[arg0]][%[[c8]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r339:.*]] = xetile.init_tile %[[arg0]][%[[c16]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r340:.*]] = xetile.init_tile %[[arg0]][%[[c24]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r341:.*]] = xetile.init_tile %[[arg0]][%[[c32]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r342:.*]] = xetile.init_tile %[[arg0]][%[[c40]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r343:.*]] = xetile.init_tile %[[arg0]][%[[c48]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: %[[r344:.*]] = xetile.init_tile %[[arg0]][%[[c56]], %[[c0]]] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r329]], %[[r337]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: 
xetile.store_tile %[[r330]], %[[r338]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r331]], %[[r339]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r332]], %[[r340]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r333]], %[[r341]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r334]], %[[r342]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r335]], %[[r343]] : vector<8x32xf16>, !xetile.tile<8x32xf16> - //CHECK: xetile.store_tile %[[r336]], %[[r344]] : vector<8x32xf16>, !xetile.tile<8x32xf16> + + //CHECK-COUNT-32: %{{.*}} = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<32xf16> + //CHECK-COUNT-32: %{{.*}} = vector.splat %{{.*}} : vector<1x8xf16> + + //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x8xf16>, vector<1x8xf16> + //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x8xf16>, vector<2x8xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x8xf16>, vector<4x8xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x8xf16>, vector<8x8xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x8xf16>, vector<16x8xf16> + //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x8xf16>, vector<1x8xf16> + //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x8xf16>, vector<2x8xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x8xf16>, vector<4x8xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x8xf16>, vector<8x8xf16> + //CHECK: 
%{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x8xf16>, vector<16x8xf16> + //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x8xf16>, vector<1x8xf16> + //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x8xf16>, vector<2x8xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x8xf16>, vector<4x8xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x8xf16>, vector<8x8xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x8xf16>, vector<16x8xf16> + //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x8xf16>, vector<1x8xf16> + //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x8xf16>, vector<2x8xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x8xf16>, vector<4x8xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x8xf16>, vector<8x8xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x8xf16>, vector<16x8xf16> + //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x8xf16>, vector<1x8xf16> + //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x8xf16>, vector<2x8xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x8xf16>, vector<4x8xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 
6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x8xf16>, vector<8x8xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x8xf16>, vector<16x8xf16> + //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x8xf16>, vector<1x8xf16> + //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x8xf16>, vector<2x8xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x8xf16>, vector<4x8xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x8xf16>, vector<8x8xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x8xf16>, vector<16x8xf16> + //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x8xf16>, vector<1x8xf16> + //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x8xf16>, vector<2x8xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x8xf16>, vector<4x8xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x8xf16>, vector<8x8xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x8xf16>, vector<16x8xf16> + //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x8xf16>, vector<1x8xf16> + //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x8xf16>, vector<2x8xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x8xf16>, 
vector<4x8xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x8xf16>, vector<8x8xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x8xf16>, vector<16x8xf16> + + //CHECK-COUNT-8: %{{.*}} = xetile.transpose %{{.*}}, [1, 0] : vector<32x8xf16> -> vector<8x32xf16> + //CHECK-COUNT-8: %{{.*}} = xetile.init_tile %[[arg0]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> + //CHECK-COUNT-8: xetile.store_tile %{{.*}}, %{{.*}} : vector<8x32xf16>, !xetile.tile<8x32xf16> %1 = xetile.init_tile %a[0, 0] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16> %2 = xetile.load_tile %1: !xetile.tile<32x64xf16> -> vector<32x64xf16> %3 = xetile.reduction , %2 [1]: vector<32x64xf16> -> vector<32x1xf16> @@ -1426,15 +1186,33 @@ gpu.module @test_kernel { //CHECK-COUNT-64: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [1, 32], strides = [1, 1]} : vector<8x32xf16> to vector<1x32xf16> //CHECK-COUNT-62: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1x32xf16> - //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> - //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<1x32xf16> to vector<32xf16> - //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf16>, vector<32xf16> + + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 
6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : 
vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> %4 = xetile.reduction , %3 [0]: vector<32x64xf16> -> vector<1x64xf16> - //CHECK: %{{.*}} = vector.shape_cast %{{.*}} : vector<64xf16> to vector<1x64xf16> - //CHECK: %{{.*}} = xetile.broadcast %{{.*}} [0] : vector<1x64xf16> -> vector<32x64xf16> %5 = xetile.broadcast %4 [0]: vector<1x64xf16> -> vector<32x64xf16> - //CHECK-COUNT-4: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 64], strides = [1, 1]} : vector<32x64xf16> to vector<8x64xf16> - //CHECK-COUNT-8: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> //CHECK-COUNT-8: %{{.*}} = arith.divf %{{.*}}, %{{.*}} : vector<8x32xf16> %6 = arith.divf %3, %5: vector<32x64xf16> //CHECK-COUNT-8: %{{.*}} = xetile.init_tile %[[arg0]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> @@ -1477,14 +1255,34 @@ gpu.module @test_kernel { //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63] : vector<32xf16>, vector<32xf16> //CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<32xf16> - //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 
13, 14, 15] : vector<8x1xf16>, vector<8x1xf16> - //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x1xf16>, vector<16x1xf16> - //CHECK: %{{.*}} = xetile.broadcast %{{.*}} [1] : vector<32x1xf16> -> vector<32x64xf16> - //CHECK-COUNT-4: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 64], strides = [1, 1]} : vector<32x64xf16> to vector<8x64xf16> - //CHECK-COUNT-8: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> + //CHECK-COUNT-32: %{{.*}} = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<32xf16> + //CHECK-COUNT-32: %{{.*}} = vector.splat %{{.*}} : vector<1x32xf16> + + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 
2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> + //CHECK-COUNT-8: %{{.*}} = arith.divf %{{.*}}, %{{.*}} : vector<8x32xf16> //CHECK-COUNT-8: %{{.*}} = xetile.init_tile %[[arg0]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xetile.tile<8x32xf16> //CHECK-COUNT-8: xetile.store_tile %{{.*}}, %{{.*}} : vector<8x32xf16>, !xetile.tile<8x32xf16> @@ -1503,13 +1301,16 @@ gpu.module @test_kernel { //CHECK-SAME(%[[arg0:.*]]: memref<1024x1024xf16>) gpu.func @sglevel_softmax_transpose(%a: memref<1024x1024xf16>) { - //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, 
%{{.*}} [0, 1] : vector<1x1xf16>, vector<1x1xf16> - //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16> - //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16> - //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x1xf16>, vector<8x1xf16> - //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x1xf16>, vector<16x1xf16> - //CHECK: %{{.*}} = xetile.broadcast %{{.*}} [1] : vector<32x1xf16> -> vector<32x64xf16> - //CHECK-COUNT-8: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [32, 8], strides = [1, 1]} : vector<32x64xf16> to vector<32x8xf16> + + //CHECK-COUNT-32: %{{.*}} = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<32xf16> + //CHECK-COUNT-32: %{{.*}} = vector.splat %{{.*}} : vector<1x8xf16> + + //CHECK-COUNT-16: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x8xf16>, vector<1x8xf16> + //CHECK-COUNT-8: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x8xf16>, vector<2x8xf16> + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x8xf16>, vector<4x8xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x8xf16>, vector<8x8xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x8xf16>, vector<16x8xf16> + //CHECK-COUNT-8: %{{.*}} = arith.divf %{{.*}}, %{{.*}} : vector<32x8xf16> //CHECK-COUNT-8: %{{.*}} = xetile.transpose %{{.*}}, [1, 0] : vector<32x8xf16> -> vector<8x32xf16> //CHECK-COUNT-8: %{{.*}} = xetile.init_tile %[[arg0]][%{{.*}}, %{{.*}}] : 
memref<1024x1024xf16> -> !xetile.tile<8x32xf16> @@ -1638,17 +1439,15 @@ gpu.module @test_kernel { %2 = xetile.load_tile %1 {padding = 0.000000e+00 : f32} : !xetile.tile<32x1xf32> -> vector<32x1xf32> //CHECK: %[[r4:.*]] = xetile.transpose %[[r2]], [1, 0] : vector<16x1xf32> -> vector<1x16xf32> //CHECK: %[[r5:.*]] = xetile.transpose %[[r3]], [1, 0] : vector<16x1xf32> -> vector<1x16xf32> - //CHECK: %[[r6:.*]] = vector.shape_cast %[[r4]] : vector<1x16xf32> to vector<16xf32> - //CHECK: %[[r7:.*]] = vector.shape_cast %[[r5]] : vector<1x16xf32> to vector<16xf32> - //CHECK: %[[r8:.*]] = vector.shuffle %[[r6]], %[[r7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> - //CHECK: %[[r9:.*]] = vector.shape_cast %[[r8]] : vector<32xf32> to vector<1x32xf32> %3 = xetile.transpose %2, [1, 0] : vector<32x1xf32> -> vector<1x32xf32> - //CHECK: %{{.*}} = xetile.broadcast %{{.*}} [0] : vector<1x32xf32> -> vector<64x32xf32> + + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x16xf32>, vector<1x16xf32> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x16xf32>, vector<2x16xf32> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x16xf32>, vector<4x16xf32> %4 = xetile.broadcast %3 [0] : vector<1x32xf32> -> vector<64x32xf32> + //CHECK-COUNT-16: %{{.*}} = xetile.init_tile %[[arg1]][%{{.*}}, %{{.*}}] : memref<256x384xf32> -> !xetile.tile<8x16xf32> %5 = xetile.init_tile %arg1[0, 0] : memref<256x384xf32> -> !xetile.tile<64x32xf32> - //CHECK-COUNT-8: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<64x32xf32> to vector<8x32xf32> - //CHECK-COUNT-16: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 16], strides = [1, 1]} : vector<8x32xf32> to vector<8x16xf32> //CHECK-COUNT-16: xetile.store_tile %{{.*}}, 
%{{.*}} : vector<8x16xf32>, !xetile.tile<8x16xf32> xetile.store_tile %4, %5 : vector<64x32xf32>, !xetile.tile<64x32xf32> gpu.return @@ -1659,15 +1458,21 @@ gpu.module @test_kernel { //CHECK: %[[r0:.*]] = xetile.init_tile %[[arg0]][0, 0] : memref<1x384xf16> -> !xetile.tile<1x32xf16> //CHECK: %[[r1:.*]] = xetile.load_tile %[[r0]] {padding = 0.000000e+00 : f32} : !xetile.tile<1x32xf16> -> vector<1x32xf16> //CHECK: %[[r2:.*]] = xetile.transpose %[[r1]], [1, 0] : vector<1x32xf16> -> vector<32x1xf16> - //CHECK: %[[r3:.*]] = xetile.broadcast %[[r2]] [1] : vector<32x1xf16> -> vector<32x64xf16> %1 = xetile.init_tile %arg0[0, 0] : memref<1x384xf16> -> !xetile.tile<1x32xf16> %2 = xetile.load_tile %1 {padding = 0.000000e+00 : f32} : !xetile.tile<1x32xf16> -> vector<1x32xf16> %3 = xetile.transpose %2, [1, 0] : vector<1x32xf16> -> vector<32x1xf16> + + //CHECK-COUNT-32: %{{.*}} = vector.extract_strided_slice %[[r2]] {offsets = [{{.*}}, 0], sizes = [1, 1], strides = [1, 1]} : vector<32x1xf16> to vector<1x1xf16> + + //CHECK: %{{.*}} = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16> + //CHECK: %{{.*}} = vector.splat %{{.*}} : vector<1x32xf16> + + //CHECK-COUNT-4: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x32xf16>, vector<1x32xf16> + //CHECK-COUNT-2: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x32xf16>, vector<2x32xf16> + //CHECK: %{{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x32xf16>, vector<4x32xf16> %4 = xetile.broadcast %3 [1] : vector<32x1xf16> -> vector<32x64xf16> //CHECK-COUNT-8: %{{.*}} = xetile.init_tile %[[arg1]][%{{.*}}, %{{.*}}] : memref<384x256xf16> -> !xetile.tile<8x32xf16> %5 = xetile.init_tile %arg1[0, 0] : memref<384x256xf16> -> !xetile.tile<32x64xf16> - //CHECK-COUNT-4: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 64], strides = [1, 1]} : vector<32x64xf16> to vector<8x64xf16> - //CHECK-COUNT-8: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = 
[{{.*}}], sizes = [8, 32], strides = [1, 1]} : vector<8x64xf16> to vector<8x32xf16> //CHECK-COUNT-8: xetile.store_tile %{{.*}}, %{{.*}} : vector<8x32xf16>, !xetile.tile<8x32xf16> xetile.store_tile %4, %5 : vector<32x64xf16>, !xetile.tile<32x64xf16> gpu.return diff --git a/test/Integration/Dialect/XeTile/xetile-to-func-vc.pp b/test/Integration/Dialect/XeTile/xetile-to-func-vc.pp index 9cd8594f6..e0335abc6 100644 --- a/test/Integration/Dialect/XeTile/xetile-to-func-vc.pp +++ b/test/Integration/Dialect/XeTile/xetile-to-func-vc.pp @@ -2,7 +2,7 @@ cse gpu.module(xetile-init-duplicate xetile-canonicalization - xetile-blocking + xetile-blocking{enable-2d-transform=true} cse convert-xetile-to-xegpu cse diff --git a/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp b/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp index 9037120f3..27352d9bc 100644 --- a/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp +++ b/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp @@ -4,7 +4,7 @@ cse xetile-init-duplicate xetile-canonicalization - xetile-blocking + xetile-blocking{enable-2d-transform=true} canonicalize convert-xetile-to-xegpu cse