From d96e712ae4859f247fd001be0c659b801d0a40a8 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Wed, 23 Oct 2024 12:04:02 -0700 Subject: [PATCH] LSC 2D prefetch: Unify type of dummy data arg to i32 since the generated intinrsic call is independent of dummy data arg type. Type mismatch error happens if multiple callers use different types for dummy data arg. For non 32bit element type, normalize 2D prefetch shape around i32 data type. Inner dimension gets scaled according to ratio between 32 and bitwidth of element type Add mixed type 2D prefetch test case. --- lib/Conversion/XeGPUToVC/LSCPatterns.cpp | 29 ++++++++++---- test/Conversion/XeGPUToVC/prefetchnd.mlir | 49 +++++++++++++++++------ 2 files changed, 59 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp index 28b881b6c..c6dfec61b 100644 --- a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp +++ b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp @@ -688,16 +688,31 @@ gen2DPrefetchIntrinsicCall(ConversionPatternRewriter &rewriter, Location &loc, auto intrinsicStr = getBlockIntrinsicStr("prefetch"); auto nblks = tdescTy.getArrayLength(); auto shape = tdescTy.getShape(); - auto elemTy = tdescTy.getElementType(); auto noRetTy = TypeRange({}); + auto bitWidth = tdescTy.getElementType().getIntOrFloatBitWidth(); + + // Sub 32bit data types are packed into 32bit data types (i32). + auto packFactor = 32 / bitWidth; + + // If packing is needed, the innermost dimensions gets scaled by the packing + // factor. In such case, the shape[1] must be a multiple of the pack factor. + // Otherwise, packing cannot be done correctly + if (packFactor > 1) { + assert( + shape[1] % packFactor == 0 && + "shape[1] must be a multiple of pack factor (32 / element bitwidth)"); + } - // for arg8: dummy value - auto attr = elemTy.isInteger() - ? (TypedAttr)rewriter.getIntegerAttr(elemTy, 0) - : (TypedAttr)rewriter.getFloatAttr(elemTy, 0.0); + // for arg8: dummy value, type has to be always the same since intrinsic + // func name for prefetch is the same regardless of the element type. + // Different type used for dummy causes type conflict in case of multiple + // calls with different dummy arg type. + auto attr = (TypedAttr)rewriter.getIntegerAttr(rewriter.getI32Type(), 0); auto dummy = constant_val(attr); - return gen2DBlockIntrinsicCall(rewriter, loc, intrinsicStr, noRetTy, l1, l3, - nblks, shape, payload, dummy); + return gen2DBlockIntrinsicCall( + rewriter, loc, intrinsicStr, noRetTy, l1, l3, nblks, + {shape[0], bitWidth == 64 ? shape[1] * 2 : shape[1] / packFactor}, + payload, dummy); } // generate a call to lsc.store.2d.ugm.* intrinsic for 2D block store, which is diff --git a/test/Conversion/XeGPUToVC/prefetchnd.mlir b/test/Conversion/XeGPUToVC/prefetchnd.mlir index ce5e4598a..3f71d93f1 100644 --- a/test/Conversion/XeGPUToVC/prefetchnd.mlir +++ b/test/Conversion/XeGPUToVC/prefetchnd.mlir @@ -1,14 +1,9 @@ -// RUN: imex-opt -convert-xegpu-to-vc --cse %s | FileCheck %s --check-prefixes=CHECK,LSC +// RUN: imex-opt --split-input-file -convert-xegpu-to-vc --cse %s | FileCheck %s --check-prefixes=CHECK,LSC + +// ----- module @gemm attributes {gpu.container_module} { gpu.module @test_kernel { - - //RAW: func.func private @llvm.genx.raw.sends2.noresult.i1.v16i32.v128f32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<128xf32>) attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>} - //RAW: func.func private @llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32(vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>} - //RAW: func.func private @llvm.genx.raw.send2.v128i32.i1.v16i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<128xi32>) -> vector<128xi32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>} - //RAW: func.func private @llvm.genx.raw.send2.v64i32.i1.v16i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<64xi32>) -> vector<64xi32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>} - //RAW: func.func private @llvm.genx.raw.send2.noresult.i1.v16i32(i8, i8, i1, i8, i8, i32, i32, vector<16xi32>) attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes>} - gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %{{.*}} : memref<8x16xf16> -> index @@ -55,15 +50,14 @@ module @gemm attributes {gpu.container_module} { //CHECK: %[[r26:.*]] = vector.insert %[[c1807_i32]], %[[r25]] [7] : i32 into vector<16xi32> %2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - //LSC: %[[cst_2:.*]] = arith.constant 0.000000e+00 : f16 //LSC: %[[true:.*]] = arith.constant true //LSC: %[[c0_i8:.*]] = arith.constant 0 : i8 //LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8> //LSC: %[[c1_i8:.*]] = arith.constant 1 : i8 - //LSC: %[[c16_i16:.*]] = arith.constant 16 : i16 //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16 - //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> () - //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> () + //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> () + //LSC: %[[c16_i16:.*]] = arith.constant 16 : i16 + //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> () xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16> xegpu.prefetch_nd %1 : !xegpu.tensor_desc<16x16xf16> @@ -89,3 +83,34 @@ module @gemm attributes {gpu.container_module} { } } + +// ----- +module @two_type attributes {gpu.container_module} { + + gpu.module @test_kernel { + gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf32>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + + //LSC: %[[c0_i32:.*]] = arith.constant 0 : i32 + //LSC: %[[true:.*]] = arith.constant true + //LSC: %[[c0_i8:.*]] = arith.constant 0 : i8 + //LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8> + //LSC: %[[c1_i8:.*]] = arith.constant 1 : i8 + //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16 + //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> () + //LSC: %[[c16_i16:.*]] = arith.constant 16 : i16 + //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> () + xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16> + xegpu.prefetch_nd %1 : !xegpu.tensor_desc<8x16xf32> + + %3 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %4 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + + xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.return + } + + } +}