diff --git a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
index 05562a390..550a2257a 100644
--- a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
+++ b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
@@ -792,7 +792,7 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase<XeGPUToVCPass> {
     unsigned rank = type.getRank();
     auto elemType = type.getElementType();
-    if (rank < 1 || type.getNumElements() == 1)
+    if (rank < 1)
       return elemType;
     unsigned sum = 1;
diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
index 4cbd8815b..00c7fc5b9 100644
--- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
+++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
@@ -905,16 +905,19 @@ struct SgTileReductionOpPattern
           sources, shape, op.getKind(), loc, elemTy, rewriter);
       llvm::SmallVector<mlir::Value> newOps;
       {
-        // intermediate is a vector of values with type of vector,
-        // each value represents a portion of the reduced value. For example,
+        // intermediate is a vector of values with type of vector
+        // (where n is max of min(shape[0]/2,16) and 1),
+        // each element is the reduced value for a row. For example,
         // for vector<32x4x1x16> with reduction on dim 1 and dim 3. the
-        // intermediate values will be two vectors of vector<16xf16>. The values
+        // intermediate values will be two values of vector<16xf16>. The values
         // in the first vector represents the reduction result of the first 16
         // rows. Here we will extract each value and splat it to a vector<1x1xf16>
         // as results to their consumers.
         for (auto v : intermediates) {
           auto targetTy = mlir::VectorType::get({1, 1}, elemTy);
-          for (auto i = 0; i < shape[3]; i++) {
+          auto vecTy = mlir::dyn_cast<mlir::VectorType>(v.getType());
+          assert(vecTy && "expect vector type");
+          for (auto i = 0; i < vecTy.getShape()[0]; i++) {
             auto pos = rewriter.create<mlir::arith::ConstantOp>(
                 op.getLoc(), rewriter.getI32IntegerAttr(i));
             auto extractOp =
diff --git a/test/Conversion/XeTileToXeGPU/reduction.mlir b/test/Conversion/XeTileToXeGPU/reduction.mlir
index a428b2ade..9a1eeebea 100644
--- a/test/Conversion/XeTileToXeGPU/reduction.mlir
+++ b/test/Conversion/XeTileToXeGPU/reduction.mlir
@@ -1,4 +1,4 @@
-// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \
+// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-canonicalization --xetile-blocking \
 // RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s
 module {
   gpu.module @test_kernel {
@@ -111,54 +111,6 @@ module {
       //CHECK: {{.*}} = arith.constant {{.*}} : i32
       //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
       //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
       //CHECK-COUNT-8: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x1xf16>, vector<1x1xf16>
       //CHECK-COUNT-4: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16>
       //CHECK-COUNT-2: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16>
@@ -250,30 +202,6 @@ module {
       //CHECK: {{.*}} = arith.constant {{.*}} : i32
       //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
-      //CHECK: {{.*}} = arith.constant {{.*}} : i32
-      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
-      //CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
       //CHECK-COUNT-4: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x1xf32>, vector<1x1xf32>
       //CHECK-COUNT-2: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x1xf32>, vector<2x1xf32>
       //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf32>, vector<4x1xf32>
@@ -283,6 +211,43 @@ module {
       gpu.return
     }

+    gpu.func @inner_reduction_small_size_1(%arg0: memref<*xf32>, %arg1: memref<*xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      %cst = arith.constant dense<0.000000e+00> : vector<1xf32>
+      %cst_0 = arith.constant dense<true> : vector<1x16xi1>
+      %cst_1 = arith.constant dense<true> : vector<1x1xi1>
+      %cst_2 = arith.constant dense<0> : vector<1x1xindex>
+      %cst_3 = arith.constant dense<0> : vector<1x16xindex>
+      %cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
+      %cast_4 = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
+      %0 = xetile.init_tile %cast, %cst_3 : memref<?xf32>, vector<1x16xindex> -> !xetile.tile<1x16xf32, #xetile.tile_attr<scattered = true>>
+      %1 = xetile.load %0, %cst_0 : !xetile.tile<1x16xf32, #xetile.tile_attr<scattered = true>>, vector<1x16xi1> -> vector<1x16xf32>
+      //CHECK: {{.*}} = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32>
+      //CHECK: {{.*}} = vector.shape_cast %{{.*}} : vector<1x16xf32> to vector<16xf32>
+      //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<16xf32>, vector<16xf32>
+      //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf32>, vector<16xf32>
+      //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<8xf32>
+      //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<8xf32>, vector<8xf32>
+      //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7] : vector<8xf32>, vector<8xf32>
+      //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<4xf32>
+      //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<4xf32>, vector<4xf32>
+      //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [2, 3] : vector<4xf32>, vector<4xf32>
+      //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<2xf32>
+      //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0] : vector<2xf32>, vector<2xf32>
+      //CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [1] : vector<2xf32>, vector<2xf32>
+      //CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1xf32>
+      //CHECK: {{.*}} = arith.constant {{.*}} : i32
+      //CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<1xf32>
+      //CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf32>
+      //CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x1xf32> to vector<1xf32>
+      //CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1xf32>
+
+      %2 = vector.multi_reduction <add>, %1, %cst [1] : vector<1x16xf32> to vector<1xf32>
+      %3 = vector.shape_cast %2 : vector<1xf32> to vector<1x1xf32>
+      %4 = xetile.init_tile %cast_4, %cst_2 : memref<?xf32>, vector<1x1xindex> -> !xetile.tile<1x1xf32, #xetile.tile_attr<scattered = true>>
+      xetile.store %3, %4, %cst_1 : vector<1x1xf32>, !xetile.tile<1x1xf32, #xetile.tile_attr<scattered = true>>, vector<1x1xi1>
+      gpu.return
+    }
+
     //CHECK: gpu.func @outter_reduction(%[[arg0:.*]]: memref<128x256xf16>, %[[arg1:.*]]: memref<128x256xf16>) {
     gpu.func @outter_reduction(%a: memref<128x256xf16>, %b: memref<128x256xf16>) {
       //CHECK: %[[c0:.*]] = arith.constant 0 : index