diff --git a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp
index 5a9c24c2b..d637c0cd1 100644
--- a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp
+++ b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp
@@ -199,7 +199,7 @@ static std::string getLSCIntrinsicStr(llvm::StringRef opName, int simd_lanes,
 // lsc.load/store/prefetch.2d.ugm. The fullname is in format of
 // 1. lsc.load.2d.ugm.desc...
 // 2. lsc.store.2d.ugm.desc..
-// 3. lsc.prefetch.2d.ugm.desc.
+// 3. lsc.prefetch.2d.ugm.desc..
 // All the types are encoded as vN[i/f]M, where N is the number of elements,
 // and M is the bit width. So for vector<16xf32>, it will be v16f32, and for
 // vector<16xi1>, it will be v16i1. cacheCtrType is fixed to vNi8, where N is
@@ -228,8 +228,8 @@ static std::string getBlockIntrinsicStr(llvm::StringRef opName,
                          cache_levels, dataTyStr)
         .str();
   } else if (opName == "prefetch") {
-    return llvm::formatv("llvm.genx.lsc.prefetch.2d.ugm.desc.v{0}i8",
-                         cache_levels)
+    return llvm::formatv("llvm.genx.lsc.prefetch.2d.ugm.desc.v{0}i8.{1}",
+                         cache_levels, dataTyStr)
         .str();
   }
   llvm_unreachable("unsupported opName");
@@ -675,34 +675,22 @@ gen2DPrefetchIntrinsicCall(ConversionPatternRewriter &rewriter, Location &loc,
   assert(tdescTy.getRank() == 2 && !tdescTy.isScattered() &&
          "Only works on 2D block TensorDesc.");
-  auto intrinsicStr = getBlockIntrinsicStr("prefetch");
   auto nblks = tdescTy.getArrayLength();
   auto shape = tdescTy.getShape();
+  auto elemTy = tdescTy.getElementType();
   auto noRetTy = TypeRange({});
-  auto bitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
-
-  // Sub 32bit data types are packed into 32bit data types (i32).
-  auto packFactor = 32 / bitWidth;
-
-  // If packing is needed, the innermost dimensions gets scaled by the packing
-  // factor. In such case, the shape[1] must be a multiple of the pack factor.
-  // Otherwise, packing cannot be done correctly
-  if (packFactor > 1) {
-    assert(
-        shape[1] % packFactor == 0 &&
-        "shape[1] must be a multiple of pack factor (32 / element bitwidth)");
-  }
+  auto bitWidth = elemTy.getIntOrFloatBitWidth();
+  auto prefix = elemTy.isInteger() ? "i" : elemTy.isBF16() ? "bf" : "f";
+  auto typeStr = llvm::formatv("{0}{1}", prefix, bitWidth).str();
+  auto intrinsicStr = getBlockIntrinsicStr("prefetch", typeStr);
 
-  // for arg8: dummy value, type has to be always the same since intrinsic
-  // func name for prefetch is the same regardless of the element type.
-  // Different type used for dummy causes type conflict in case of multiple
-  // calls with different dummy arg type.
-  auto attr = (TypedAttr)rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
+  // for arg8: dummy value
+  auto attr = elemTy.isInteger()
+                  ? (TypedAttr)rewriter.getIntegerAttr(elemTy, 0)
+                  : (TypedAttr)rewriter.getFloatAttr(elemTy, 0.0);
   auto dummy = constant_val(attr);
 
-  return gen2DBlockIntrinsicCall(
-      rewriter, loc, intrinsicStr, noRetTy, l1, l3, nblks,
-      {shape[0], bitWidth == 64 ? shape[1] * 2 : shape[1] / packFactor},
-      payload, dummy);
+  return gen2DBlockIntrinsicCall(rewriter, loc, intrinsicStr, noRetTy, l1, l3,
+                                 nblks, shape, payload, dummy);
 }
 
 // generate a call to lsc.store.2d.ugm.* intrinsic for 2D block store, which is
diff --git a/test/Conversion/XeGPUToVC/prefetchnd.mlir b/test/Conversion/XeGPUToVC/prefetchnd.mlir
index 3f71d93f1..0bcd9ee09 100644
--- a/test/Conversion/XeGPUToVC/prefetchnd.mlir
+++ b/test/Conversion/XeGPUToVC/prefetchnd.mlir
@@ -4,6 +4,7 @@ module @gemm attributes {gpu.container_module} {
   gpu.module @test_kernel {
+
     gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
       //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %{{.*}} : memref<8x16xf16> -> index
@@ -50,14 +51,15 @@ module @gemm attributes {gpu.container_module} {
       //CHECK: %[[r26:.*]] = vector.insert %[[c1807_i32]], %[[r25]] [7] : i32 into vector<16xi32>
       %2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
+      //LSC: %[[cst_2:.*]] = arith.constant 0.000000e+00 : f16
       //LSC: %[[true:.*]] = arith.constant true
       //LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
       //LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
       //LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
-      //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
-      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
       //LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
-      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+      //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
+      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
+      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
 
       xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
       xegpu.prefetch_nd %1 : !xegpu.tensor_desc<16x16xf16>
@@ -94,14 +96,16 @@ module @two_type attributes {gpu.container_module} {
       %2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
      //LSC: %[[c0_i32:.*]] = arith.constant 0 : i32
+      //LSC: %[[c0_f16:.*]] = arith.constant 0.000000e+00 : f16
       //LSC: %[[true:.*]] = arith.constant true
       //LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
       //LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
       //LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
-      //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
-      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
       //LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
-      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+      //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
+      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_f16]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
+      //LSC: %[[c0_f32:.*]] = arith.constant 0.000000e+00 : f32
+      //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f32(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_f32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f32) -> ()
 
       xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
       xegpu.prefetch_nd %1 : !xegpu.tensor_desc<8x16xf32>
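
Note on the naming scheme this patch switches to: the prefetch intrinsic name now carries a per-element-type suffix (an "i", "bf", or "f" prefix plus the element bit width) after the vNi8 cache-control type, so calls on f16 and f32 descriptors resolve to distinct intrinsic names. The sketch below is a minimal standalone illustration of that scheme; the helper names `makeTypeSuffix` and `makePrefetchIntrinsicName` are hypothetical and only mirror the `formatv` calls in the patch, not the actual LSCPatterns.cpp API.

```cpp
// Illustrative sketch only; assumes LLVM's FormatVariadic support headers.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include <iostream>
#include <string>

// Element-type suffix: "i"/"bf"/"f" prefix plus the element bit width,
// e.g. i32, bf16, f16, f32 (mirrors the prefix selection in the patch).
static std::string makeTypeSuffix(bool isInteger, bool isBF16,
                                  unsigned bitWidth) {
  const char *prefix = isInteger ? "i" : (isBF16 ? "bf" : "f");
  return llvm::formatv("{0}{1}", prefix, bitWidth).str();
}

// Prefetch intrinsic name: the data-type suffix is appended after the
// vNi8 cache-control type, matching the updated prefetch branch of
// getBlockIntrinsicStr.
static std::string makePrefetchIntrinsicName(int cacheLevels,
                                             llvm::StringRef dataTyStr) {
  return llvm::formatv("llvm.genx.lsc.prefetch.2d.ugm.desc.v{0}i8.{1}",
                       cacheLevels, dataTyStr)
      .str();
}

int main() {
  // With two cache-control levels these print the names the updated
  // prefetchnd.mlir CHECK lines expect:
  //   llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16
  //   llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f32
  std::cout << makePrefetchIntrinsicName(2, makeTypeSuffix(false, false, 16))
            << "\n";
  std::cout << makePrefetchIntrinsicName(2, makeTypeSuffix(false, false, 32))
            << "\n";
  return 0;
}
```

Because each element type now maps to its own intrinsic declaration, the dummy ninth argument no longer has to share a single i32 type to avoid conflicting declarations across calls; the patch types it as the element type itself, which is what the updated CHECK lines (`%[[cst_2]]`, `%[[c0_f16]]`, `%[[c0_f32]]`) verify.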