LSC 2D prefetch: Unify type of dummy data arg to i32 #940

Merged: 1 commit, Oct 23, 2024
lib/Conversion/XeGPUToVC/LSCPatterns.cpp: 22 additions & 7 deletions
@@ -688,16 +688,31 @@ gen2DPrefetchIntrinsicCall(ConversionPatternRewriter &rewriter, Location &loc,
  auto intrinsicStr = getBlockIntrinsicStr("prefetch");
  auto nblks = tdescTy.getArrayLength();
  auto shape = tdescTy.getShape();
- auto elemTy = tdescTy.getElementType();
  auto noRetTy = TypeRange({});
+ auto bitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+
+ // Sub 32bit data types are packed into 32bit data types (i32).
+ auto packFactor = 32 / bitWidth;
+
+ // If packing is needed, the innermost dimensions gets scaled by the packing
+ // factor. In such case, the shape[1] must be a multiple of the pack factor.
+ // Otherwise, packing cannot be done correctly
+ if (packFactor > 1) {
+   assert(
+       shape[1] % packFactor == 0 &&
+       "shape[1] must be a multiple of pack factor (32 / element bitwidth)");
+ }
+
- // for arg8: dummy value
- auto attr = elemTy.isInteger()
-                 ? (TypedAttr)rewriter.getIntegerAttr(elemTy, 0)
-                 : (TypedAttr)rewriter.getFloatAttr(elemTy, 0.0);
+ // for arg8: dummy value, type has to be always the same since intrinsic
+ // func name for prefetch is the same regardless of the element type.
+ // Different type used for dummy causes type conflict in case of multiple
+ // calls with different dummy arg type.
+ auto attr = (TypedAttr)rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
  auto dummy = constant_val(attr);
- return gen2DBlockIntrinsicCall(rewriter, loc, intrinsicStr, noRetTy, l1, l3,
-                                nblks, shape, payload, dummy);
+ return gen2DBlockIntrinsicCall(
+     rewriter, loc, intrinsicStr, noRetTy, l1, l3, nblks,
+     {shape[0], bitWidth == 64 ? shape[1] * 2 : shape[1] / packFactor},
+     payload, dummy);
}

// generate a call to lsc.store.2d.ugm.* intrinsic for 2D block store, which is
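For reference, the shape adjustment introduced above can be sanity-checked with a small standalone sketch. The helper below is illustrative only (packedPrefetchShape is not code from this patch); it mirrors the {shape[0], shape[1]} rewrite that the new code feeds into gen2DBlockIntrinsicCall.

// Illustrative sketch of the block-shape adjustment above; the helper name
// packedPrefetchShape is hypothetical and not part of the patch.
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

std::array<int64_t, 2> packedPrefetchShape(std::array<int64_t, 2> shape,
                                           unsigned bitWidth) {
  // 64-bit elements cannot use 32 / bitWidth (it truncates to 0), so they are
  // expressed as twice as many 32-bit lanes instead.
  if (bitWidth == 64)
    return {shape[0], shape[1] * 2};
  unsigned packFactor = 32 / bitWidth; // f16 -> 2, i8 -> 4, f32/i32 -> 1
  assert(shape[1] % packFactor == 0 &&
         "shape[1] must be a multiple of the pack factor");
  return {shape[0], shape[1] / packFactor};
}

int main() {
  auto f16 = packedPrefetchShape({8, 16}, 16); // {8, 8}: 16 halves pack into 8 i32s
  auto f32 = packedPrefetchShape({8, 16}, 32); // {8, 16}: unchanged
  auto f64 = packedPrefetchShape({8, 16}, 64); // {8, 32}: innermost dim doubled
  std::printf("f16 %lldx%lld, f32 %lldx%lld, f64 %lldx%lld\n",
              (long long)f16[0], (long long)f16[1],
              (long long)f32[0], (long long)f32[1],
              (long long)f64[0], (long long)f64[1]);
  return 0;
}

This matches the updated CHECK lines in the test below: the f16 prefetches now advertise an innermost size of 8 i32 elements instead of 16 f16 elements.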
test/Conversion/XeGPUToVC/prefetchnd.mlir: 37 additions & 12 deletions
@@ -1,14 +1,9 @@
-// RUN: imex-opt -convert-xegpu-to-vc --cse %s | FileCheck %s --check-prefixes=CHECK,LSC
+// RUN: imex-opt --split-input-file -convert-xegpu-to-vc --cse %s | FileCheck %s --check-prefixes=CHECK,LSC

+// -----
module @gemm attributes {gpu.container_module} {

gpu.module @test_kernel {

-//RAW: func.func private @llvm.genx.raw.sends2.noresult.i1.v16i32.v128f32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<128xf32>) attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes<linkage_name = "llvm.genx.raw.sends2.noresult.i1.v16i32.v128f32", linkage_type = <Import>>}
-//RAW: func.func private @llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32(vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes<linkage_name = "llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32", linkage_type = <Import>>}
-//RAW: func.func private @llvm.genx.raw.send2.v128i32.i1.v16i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<128xi32>) -> vector<128xi32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes<linkage_name = "llvm.genx.raw.send2.v128i32.i1.v16i32", linkage_type = <Import>>}
-//RAW: func.func private @llvm.genx.raw.send2.v64i32.i1.v16i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<16xi32>, vector<64xi32>) -> vector<64xi32> attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes<linkage_name = "llvm.genx.raw.send2.v64i32.i1.v16i32", linkage_type = <Import>>}
-//RAW: func.func private @llvm.genx.raw.send2.noresult.i1.v16i32(i8, i8, i1, i8, i8, i32, i32, vector<16xi32>) attributes {VectorComputeFunctionINTEL, linkage_attributes = #spirv.linkage_attributes<linkage_name = "llvm.genx.raw.send2.noresult.i1.v16i32", linkage_type = <Import>>}

gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %{{.*}} : memref<8x16xf16> -> index
@@ -55,15 +50,14 @@ module @gemm attributes {gpu.container_module} {
//CHECK: %[[r26:.*]] = vector.insert %[[c1807_i32]], %[[r25]] [7] : i32 into vector<16xi32>
%2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>

-//LSC: %[[cst_2:.*]] = arith.constant 0.000000e+00 : f16
//LSC: %[[true:.*]] = arith.constant true
//LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
//LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
//LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
-//LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
//LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
-//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
-//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
+//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+//LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
+//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
xegpu.prefetch_nd %1 : !xegpu.tensor_desc<16x16xf16>

@@ -89,3 +83,34 @@ module @gemm attributes {gpu.container_module} {

}
}

+// -----
+module @two_type attributes {gpu.container_module} {
+
+gpu.module @test_kernel {
+gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<8x16xf32>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+%1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+%2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+//LSC: %[[c0_i32:.*]] = arith.constant 0 : i32
+//LSC: %[[true:.*]] = arith.constant true
+//LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
+//LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
+//LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
+//LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
+//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+//LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
+//LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
+xegpu.prefetch_nd %1 : !xegpu.tensor_desc<8x16xf32>
+
+%3 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+%4 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+
+xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+gpu.return
+}
+
+}
+}
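The two_type module above pins down the conflict described in the new comment in LSCPatterns.cpp: with a per-element-type dummy argument, one kernel prefetching both f16 and f32 descriptors would have needed two declarations of the same intrinsic name with different trailing argument types. The sketch below is illustrative only (it is not IMEX code; declare() and the signature strings are made up for the example) and just models that name/signature clash.

// Illustrative model of the declaration conflict the unified i32 dummy avoids;
// declare() and the signature strings are invented for this example.
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Register a function declaration; return false if the name was already
// declared with a different signature.
bool declare(std::map<std::string, std::string> &symbols,
             const std::string &name, const std::string &signature) {
  auto [it, inserted] = symbols.emplace(name, signature);
  return inserted || it->second == signature;
}

int main() {
  const std::string name = "llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8";
  const std::vector<std::string> elemTypes = {"f16", "f32"};

  std::map<std::string, std::string> oldSyms, newSyms;
  bool oldOk = true, newOk = true;
  for (const auto &ty : elemTypes) {
    // Pre-patch: the dummy argument took the descriptor's element type.
    oldOk &= declare(oldSyms, name, "(..., " + ty + ") -> ()");
    // Post-patch: the dummy argument is always i32.
    newOk &= declare(newSyms, name, "(..., i32) -> ()");
  }
  std::cout << "per-element-type dummy consistent: " << oldOk << "\n"; // 0
  std::cout << "unified i32 dummy consistent:      " << newOk << "\n"; // 1
  return 0;
}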