diff --git a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp
index 28b881b6c..c6dfec61b 100644
--- a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp
+++ b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp
@@ -688,16 +688,31 @@ gen2DPrefetchIntrinsicCall(ConversionPatternRewriter &rewriter, Location &loc,
   auto intrinsicStr = getBlockIntrinsicStr("prefetch");
   auto nblks = tdescTy.getArrayLength();
   auto shape = tdescTy.getShape();
-  auto elemTy = tdescTy.getElementType();
   auto noRetTy = TypeRange({});
+  auto bitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+
+  // Sub-32-bit data types are packed into 32-bit data types (i32).
+  auto packFactor = 32 / bitWidth;
+
+  // If packing is needed, the innermost dimension gets scaled by the packing
+  // factor. In that case, shape[1] must be a multiple of the pack factor;
+  // otherwise, packing cannot be done correctly.
+  if (packFactor > 1) {
+    assert(
+        shape[1] % packFactor == 0 &&
+        "shape[1] must be a multiple of pack factor (32 / element bitwidth)");
+  }
 
-  // for arg8: dummy value
-  auto attr = elemTy.isInteger()
-                  ? (TypedAttr)rewriter.getIntegerAttr(elemTy, 0)
-                  : (TypedAttr)rewriter.getFloatAttr(elemTy, 0.0);
+  // for arg8: dummy value; its type always has to be the same, since the
+  // intrinsic func name for prefetch is the same regardless of the element
+  // type. Using a different dummy type would cause a type conflict when
+  // multiple calls use different dummy arg types.
+  auto attr = (TypedAttr)rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
   auto dummy = constant_val(attr);
-  return gen2DBlockIntrinsicCall(rewriter, loc, intrinsicStr, noRetTy, l1, l3,
-                                 nblks, shape, payload, dummy);
+  return gen2DBlockIntrinsicCall(
+      rewriter, loc, intrinsicStr, noRetTy, l1, l3, nblks,
+      {shape[0], bitWidth == 64 ? shape[1] * 2 : shape[1] / packFactor},
+      payload, dummy);
 }
 
 // generate a call to lsc.store.2d.ugm.* intrinsic for 2D block store, which is
diff --git a/test/Conversion/XeGPUToVC/prefetchnd.mlir b/test/Conversion/XeGPUToVC/prefetchnd.mlir
index ce5e4598a..3f71d93f1 100644
--- a/test/Conversion/XeGPUToVC/prefetchnd.mlir
+++ b/test/Conversion/XeGPUToVC/prefetchnd.mlir
@@ -1,14 +1,9 @@
-// RUN: imex-opt -convert-xegpu-to-vc --cse %s | FileCheck %s --check-prefixes=CHECK,LSC
+// RUN: imex-opt --split-input-file -convert-xegpu-to-vc --cse %s | FileCheck %s --check-prefixes=CHECK,LSC
+
+// -----
 module @gemm attributes {gpu.container_module} {
 
   gpu.module @test_kernel {
-
-    //RAW: func.func private @llvm.genx.raw.sends2.noresult.i1.v16i32.v128f32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<128xf32>) attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>}
-    //RAW: func.func private @llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32(vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>}
-    //RAW: func.func private @llvm.genx.raw.send2.v128i32.i1.v16i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<128xi32>) -> vector<128xi32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>}
-    //RAW: func.func private @llvm.genx.raw.send2.v64i32.i1.v16i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<64xi32>) -> vector<64xi32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>}
-    //RAW: func.func private @llvm.genx.raw.send2.noresult.i1.v16i32(i8, i8, i1, i8, i8, i32, i32, vector<16xi32>) attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>}
-
     gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
       //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %{{.*}} : memref<8x16xf16> -> index
@@ -55,15 +50,14 @@ module @gemm attributes {gpu.container_module} {
       //CHECK: %[[r26:.*]] = vector.insert %[[c1807_i32]], %[[r25]] [7] : i32 into vector<16xi32>
       %2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
-      //LSC: %[[cst_2:.*]] = arith.constant 0.000000e+00 : f16
       //LSC: %[[true:.*]] = arith.constant true
       //LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
       //LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
       //LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
-      //LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
       //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
-      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
-      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
+      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+      //LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
+      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
       xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
       xegpu.prefetch_nd %1 : !xegpu.tensor_desc<16x16xf16>
 
@@ -89,3 +83,34 @@
 
   }
 }
+
+// -----
+module @two_type attributes {gpu.container_module} {
+
+  gpu.module @test_kernel {
+    gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf32>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+      %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+      %2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+      //LSC: %[[c0_i32:.*]] = arith.constant 0 : i32
+      //LSC: %[[true:.*]] = arith.constant true
+      //LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
+      //LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
+      //LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
+      //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
+      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+      //LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
+      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+      xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
+      xegpu.prefetch_nd %1 : !xegpu.tensor_desc<8x16xf32>
+
+      %3 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+      %4 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+
+      xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+      gpu.return
+    }
+
+  }
+}
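Note (reviewer illustration, not part of the patch): with this change the block width passed to llvm.genx.lsc.prefetch.2d.ugm.desc.* is expressed in 32-bit units, which is why the updated checks expect a width of 8 for 8x16xf16 and 16 for 8x16xf32. Below is a minimal standalone C++ sketch of the same scaling rule; the packedWidth helper is hypothetical and exists only for this example.

    // Hypothetical helper mirroring the rule in gen2DPrefetchIntrinsicCall:
    // sub-32-bit elements are packed into i32 lanes, 64-bit elements take two.
    #include <cassert>
    #include <cstdio>

    static int packedWidth(int bitWidth, int width) {
      if (bitWidth == 64)
        return width * 2;              // one 64-bit element spans two i32 lanes
      int packFactor = 32 / bitWidth;  // f16 -> 2, i8 -> 4, f32/i32 -> 1
      assert(width % packFactor == 0 &&
             "width must be a multiple of the pack factor");
      return width / packFactor;
    }

    int main() {
      std::printf("%d\n", packedWidth(16, 16)); // 8x16xf16 -> width 8  (c8_i16 in the test)
      std::printf("%d\n", packedWidth(32, 16)); // 8x16xf32 -> width 16 (c16_i16 in the test)
      std::printf("%d\n", packedWidth(64, 16)); // 64-bit elements -> width 32
      return 0;
    }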