Prefetch 2D: encode dummy data type as part of intrinsic name as the … #949

Merged 1 commit on Oct 31, 2024
40 changes: 14 additions & 26 deletions lib/Conversion/XeGPUToVC/LSCPatterns.cpp
@@ -199,7 +199,7 @@ static std::string getLSCIntrinsicStr(llvm::StringRef opName, int simd_lanes,
// lsc.load/store/prefetch.2d.ugm. The fullname is in format of
// 1. lsc.load.2d.ugm.desc.<transform>.<retType>.<cache_controls>
// 2. lsc.store.2d.ugm.desc.<cacheCtrType>.<dataType>
- // 3. lsc.prefetch.2d.ugm.desc.<predType>
+ // 3. lsc.prefetch.2d.ugm.desc.<predType>.<dataType>
// All the types are encoded as vN[i/f]M, where N is the number of elements,
// and M is the bit width. So for vector<16xf32>, it will be v16f32, and for
// vector<16xi1>, it will be v16i1. cacheCtrType is fixed to vNi8, where N is
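As a concrete illustration of the vN[i/f]M encoding described in that comment, a minimal standalone sketch (encodeVectorType is a hypothetical helper, not a function in this file):

#include <string>
#include "llvm/Support/FormatVariadic.h"

// N elements of M bits: "i" prefix for integer element types, "f" for
// floating-point ones, e.g. (16, 32, false) -> "v16f32" and
// (16, 1, true) -> "v16i1"; the cache-control operand is always vNi8.
static std::string encodeVectorType(unsigned numElems, unsigned bitWidth,
                                    bool isInteger) {
  return llvm::formatv("v{0}{1}{2}", numElems, isInteger ? "i" : "f",
                       bitWidth)
      .str();
}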
@@ -228,8 +228,8 @@ static std::string getBlockIntrinsicStr(llvm::StringRef opName,
cache_levels, dataTyStr)
.str();
} else if (opName == "prefetch") {
- return llvm::formatv("llvm.genx.lsc.prefetch.2d.ugm.desc.v{0}i8",
- cache_levels)
+ return llvm::formatv("llvm.genx.lsc.prefetch.2d.ugm.desc.v{0}i8.{1}",
+ cache_levels, dataTyStr)
.str();
}
llvm_unreachable("unsupported opName");
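To see the effect of the renaming, a hedged usage sketch (hypothetical call sites; assumes two cache levels, hence the v2i8 cache-control type, matching the tests below):

// Before this change both calls produced the single name
// "llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8"; with the type suffix they
// resolve to distinct intrinsics, one per element type.
std::string f16Name = getBlockIntrinsicStr("prefetch", "f16");
// -> "llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16"
std::string f32Name = getBlockIntrinsicStr("prefetch", "f32");
// -> "llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f32"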
@@ -675,34 +675,22 @@ gen2DPrefetchIntrinsicCall(ConversionPatternRewriter &rewriter, Location &loc,
assert(tdescTy.getRank() == 2 && !tdescTy.isScattered() &&
"Only works on 2D block TensorDesc.");

- auto intrinsicStr = getBlockIntrinsicStr("prefetch");
auto nblks = tdescTy.getArrayLength();
auto shape = tdescTy.getShape();
+ auto elemTy = tdescTy.getElementType();
auto noRetTy = TypeRange({});
- auto bitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
-
- // Sub 32bit data types are packed into 32bit data types (i32).
- auto packFactor = 32 / bitWidth;
-
- // If packing is needed, the innermost dimensions gets scaled by the packing
- // factor. In such case, the shape[1] must be a multiple of the pack factor.
- // Otherwise, packing cannot be done correctly
- if (packFactor > 1) {
- assert(
- shape[1] % packFactor == 0 &&
- "shape[1] must be a multiple of pack factor (32 / element bitwidth)");
- }
+ auto bitWidth = elemTy.getIntOrFloatBitWidth();
+ auto prefix = elemTy.isInteger() ? "i" : elemTy.isBF16() ? "bf" : "f";
+ auto typeStr = llvm::formatv("{0}{1}", prefix, bitWidth).str();
+ auto intrinsicStr = getBlockIntrinsicStr("prefetch", typeStr);

- // for arg8: dummy value, type has to be always the same since intrinsic
- // func name for prefetch is the same regardless of the element type.
- // Different type used for dummy causes type conflict in case of multiple
- // calls with different dummy arg type.
- auto attr = (TypedAttr)rewriter.getIntegerAttr(rewriter.getI32Type(), 0);
+ // for arg8: dummy value
+ auto attr = elemTy.isInteger()
+ ? (TypedAttr)rewriter.getIntegerAttr(elemTy, 0)
+ : (TypedAttr)rewriter.getFloatAttr(elemTy, 0.0);
auto dummy = constant_val(attr);
- return gen2DBlockIntrinsicCall(
- rewriter, loc, intrinsicStr, noRetTy, l1, l3, nblks,
- {shape[0], bitWidth == 64 ? shape[1] * 2 : shape[1] / packFactor},
- payload, dummy);
+ return gen2DBlockIntrinsicCall(rewriter, loc, intrinsicStr, noRetTy, l1, l3,
+ nblks, shape, payload, dummy);
}

// generate a call to lsc.store.2d.ugm.* intrinsic for 2D block store, which is
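Why the typed dummy is now safe to emit: all calls to a given intrinsic name share one declared function, so the trailing dummy operand must have the same type at every call site of that name; encoding the element type into the name gives each element type its own declaration. A minimal sketch of the new dummy construction (makeDummyAttr is a hypothetical free function mirroring the ternary above; assumes the usual MLIR builder API):

#include "mlir/IR/Builders.h"

// Zero-valued constant of the descriptor's element type, used only to
// fill the unused ninth operand of the prefetch intrinsic.
static mlir::TypedAttr makeDummyAttr(mlir::OpBuilder &rewriter,
                                     mlir::Type elemTy) {
  return elemTy.isInteger()
             ? (mlir::TypedAttr)rewriter.getIntegerAttr(elemTy, 0)
             : (mlir::TypedAttr)rewriter.getFloatAttr(elemTy, 0.0);
}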
16 changes: 10 additions & 6 deletions test/Conversion/XeGPUToVC/prefetchnd.mlir
@@ -4,6 +4,7 @@
module @gemm attributes {gpu.container_module} {

gpu.module @test_kernel {

gpu.func @test_prefetch(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %{{.*}} : memref<8x16xf16> -> index
@@ -50,14 +51,15 @@ module @gemm attributes {gpu.container_module} {
//CHECK: %[[r26:.*]] = vector.insert %[[c1807_i32]], %[[r25]] [7] : i32 into vector<16xi32>
%2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>

+ //LSC: %[[cst_2:.*]] = arith.constant 0.000000e+00 : f16
//LSC: %[[true:.*]] = arith.constant true
//LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
//LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
//LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
- //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
- //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
//LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
- //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+ //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
+ //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
+ //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c16_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[cst_2]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
xegpu.prefetch_nd %1 : !xegpu.tensor_desc<16x16xf16>

@@ -94,14 +96,16 @@ module @two_type attributes {gpu.container_module} {
%2 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>

//LSC: %[[c0_i32:.*]] = arith.constant 0 : i32
+ //LSC: %[[c0_f16:.*]] = arith.constant 0.000000e+00 : f16
//LSC: %[[true:.*]] = arith.constant true
//LSC: %[[c0_i8:.*]] = arith.constant 0 : i8
//LSC: %[[r27:.*]] = vector.from_elements %[[c0_i8]], %[[c0_i8]] : vector<2xi8>
//LSC: %[[c1_i8:.*]] = arith.constant 1 : i8
- //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
- //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c8_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
//LSC: %[[c16_i16:.*]] = arith.constant 16 : i16
- //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_i32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, i32) -> ()
+ //LSC: %[[c8_i16:.*]] = arith.constant 8 : i16
+ //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f16(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r8]], %[[c0_i32]], %[[c0_i32]], %[[c0_f16]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f16) -> ()
+ //LSC: %[[c0_f32:.*]] = arith.constant 0.000000e+00 : f32
+ //LSC: func.call @llvm.genx.lsc.prefetch.2d.ugm.desc.v2i8.f32(%[[true]], %[[r27]], %[[c1_i8]], %[[c16_i16]], %[[c8_i16]], %[[r17]], %[[c0_i32]], %[[c0_i32]], %[[c0_f32]]) : (i1, vector<2xi8>, i8, i16, i16, vector<16xi32>, i32, i32, f32) -> ()
xegpu.prefetch_nd %0 : !xegpu.tensor_desc<8x16xf16>
xegpu.prefetch_nd %1 : !xegpu.tensor_desc<8x16xf32>
