Skip to content

Commit f191671

Browse files
authored
Add an e2e test for XeTile gather/scatter on SLM (#988)
1 parent 601a5f5 commit f191671

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \
2+
// RUN: --runner imex-cpu-runner -e main \
3+
// RUN: --entry-point-result=void \
4+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
5+
// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \
6+
// RUN: --runner imex-cpu-runner -e main \
7+
// RUN: --entry-point-result=void \
8+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
9+
10+
// NOTES :
11+
// This example launches a single workgroup of 32 subgroups (threads in 4x8x1); the kernel specifies
12+
// the computation done by a single subgroup.
13+
14+
module @gemm attributes {gpu.container_module} {
15+
// A test case that returns the transpose of A, which is viewed as memref<32x32xf16>.
16+
// It uses one workgroup containing 32 subgroups, organized as (8x4), so each subgroup
17+
// works on a 4x8 tile of A. It uses SLM to do the transpose, to exercise the functionality
18+
// of the SLM operations.
19+
// Host-side driver: copies %A into a host_shared buffer, launches @trans_kernel on it, and
// returns the host_shared output buffer. The caller receives %B_gpu still allocated.
func.func @test(%A: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} {
20+
%c1 = arith.constant 1 : index
21+
%c4 = arith.constant 4 : index
22+
%c8 = arith.constant 8 : index
23+
// Stage the input in a host_shared allocation so it can be passed to the kernel.
%A_gpu = gpu.alloc host_shared () : memref<32x32xf16>
24+
memref.copy %A, %A_gpu : memref<32x32xf16> to memref<32x32xf16>
25+
// Output buffer; written by the kernel and returned to the caller below.
%B_gpu = gpu.alloc host_shared () : memref<32x32xf16>
26+
// One block (1x1x1) with threads in (4, 8, 1), i.e. 4 * 8 = 32 launched threads.
gpu.launch_func @test_kernel::@trans_kernel blocks in (%c1, %c1, %c1) threads in (%c4, %c8, %c1) args(%A_gpu : memref<32x32xf16>, %B_gpu : memref<32x32xf16>)
27+
// Only the input staging buffer is released here; %B_gpu is handed back to the caller.
gpu.dealloc %A_gpu : memref<32x32xf16>
28+
return %B_gpu : memref<32x32xf16>
29+
}
30+
// Device module holding the SLM transpose kernel.
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
31+
// Writes the transpose of the 32x32 f16 matrix %A into %B, going through SLM (memory space 3).
// Per subgroup: block-load a 4x8 tile of %A, scatter it into SLM keeping the original row-major
// layout, synchronize with a barrier, gather a 4x8 tile from SLM through transposed indices,
// and block-store it to %B at the same tile position the input tile was read from.
gpu.func @trans_kernel(%A: memref<32x32xf16>, %B: memref<32x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
33+
%c2 = arith.constant 2 : index
34+
%c3 = arith.constant 3 : index
35+
%c4 = arith.constant 4 : index
36+
%c8 = arith.constant 8 : index
37+
%c128 = arith.constant 128 : index
38+
%c256 = arith.constant 256 : index
39+
40+
%sgid = gpu.subgroup_id : index
41+
// Split the linear subgroup id into (tid_y, tid_x) with 4 subgroups per row. The shift/mask
// pair below computes the same values as the commented-out divide/remainder by 4.
// %tid_y = arith.divui %sgid, %c4 : index
42+
// %tid_x = arith.remui %sgid, %c4 : index
43+
%tid_y = arith.shrui %sgid, %c2 : index
44+
%tid_x = arith.andi %sgid, %c3 : index
45+
46+
// Top-left corner (row off_y, column off_x) of this subgroup's 4x8 tile.
%off_y = arith.muli %tid_y, %c4 : index
47+
%off_x = arith.muli %tid_x, %c8 : index
48+
49+
// Load this subgroup's 4x8 tile from global memory using a block load.
50+
%a_tile = xetile.init_tile %A[%off_y, %off_x] : memref<32x32xf16> -> !xetile.tile<4x8xf16>
51+
%data = xetile.load_tile %a_tile : !xetile.tile<4x8xf16> -> vector<4x8xf16>
52+
53+
// SLM buffer, viewed as a flat 1024-element memref so it can be addressed by linear indices.
%slm = memref.alloc() : memref<32x32xf16, 3>
54+
%cast = memref.reinterpret_cast %slm to offset: [0], sizes: [1024], strides: [1] : memref<32x32xf16, 3> to memref<1024xf16, 3>
55+
// All-true mask: every lane of the scattered store/load below is enabled.
%mask = arith.constant dense<true>: vector<4x8xi1>
56+
57+
// Store the tile to SLM using the original layout: element [i][j] goes to linear index
// (off_y + i) * 32 + (off_x + j). The constant holds i * 32 + j for the 4x8 tile.
58+
%base_indices = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7],
59+
[32, 33, 34, 35, 36, 37, 38, 39],
60+
[64, 65, 66, 67, 68, 69, 70, 71],
61+
[96, 97, 98, 99, 100, 101, 102, 103]]>: vector<4x8xindex>
62+
// tid_y * 128 == off_y * 32, so %offset is the linear index of the tile's top-left element.
%off_y2 = arith.muli %tid_y, %c128 : index
63+
%offset = arith.addi %off_y2, %off_x : index
64+
%offsets = vector.splat %offset: vector<4x8xindex>
65+
%indices = arith.addi %base_indices, %offsets : vector<4x8xindex>
66+
%st_tile = xetile.init_tile %cast, %indices : memref<1024xf16, 3>, vector<4x8xindex> -> !xetile.tile<4x8xf16, #xetile.tile_attr<scattered = true, memory_space=3>>
67+
xetile.store %data, %st_tile, %mask : vector<4x8xf16>, !xetile.tile<4x8xf16, #xetile.tile_attr<scattered = true, memory_space=3>>, vector<4x8xi1>
68+
69+
// The gather below reads elements written by other subgroups, so all stores must land first.
gpu.barrier
70+
71+
// Load from SLM using indices with a transpose effect: element [i][j] reads linear index
// (off_x + j) * 32 + (off_y + i). The constant holds i + j * 32 for the 4x8 tile.
72+
%trans_base_indices = arith.constant dense<[[0, 32, 64, 96, 128, 160, 192, 224],
73+
[1, 33, 65, 97, 129, 161, 193, 225],
74+
[2, 34, 66, 98, 130, 162, 194, 226],
75+
[3, 35, 67, 99, 131, 163, 195, 227]]>: vector<4x8xindex>
76+
77+
// tid_x * 256 == off_x * 32 and tid_y * 4 == off_y, giving the base of the transposed read.
%trans_off_x = arith.muli %tid_x, %c256 : index
78+
%trans_off_y = arith.muli %tid_y, %c4 : index
79+
%trans_off = arith.addi %trans_off_x, %trans_off_y : index
80+
%trans_offsets = vector.splat %trans_off: vector<4x8xindex>
81+
%trans_indices = arith.addi %trans_base_indices, %trans_offsets : vector<4x8xindex>
82+
%ld_tile = xetile.init_tile %cast, %trans_indices : memref<1024xf16, 3>, vector<4x8xindex> -> !xetile.tile<4x8xf16, #xetile.tile_attr<scattered = true, memory_space=3>>
83+
%d = xetile.load %ld_tile, %mask : !xetile.tile<4x8xf16, #xetile.tile_attr<scattered = true, memory_space=3>>, vector<4x8xi1> -> vector<4x8xf16>
84+
85+
// Block-store the gathered tile to %B at [off_y, off_x], so B[r][c] receives A[c][r].
%b_tile = xetile.init_tile %B[%off_y, %off_x] : memref<32x32xf16> -> !xetile.tile<4x8xf16>
86+
xetile.store_tile %d, %b_tile: vector<4x8xf16>, !xetile.tile<4x8xf16>
87+
gpu.return
88+
}
89+
}
90+
// Test entry point: fills A with A[i][j] = i * 32 + j, builds the expected transpose in f32,
// runs @test, and compares the result against the reference with @printAllcloseF16.
func.func @main() attributes {llvm.emit_c_interface} {
91+
%c0 = arith.constant 0 : index
92+
%c1 = arith.constant 1 : index
93+
%c32 = arith.constant 32 : index
96+
%A = memref.alloc() : memref<32x32xf16>
97+
%Ref = memref.alloc() : memref<32x32xf32>
98+
// Initialize matrix A and, in the same pass, the f32 reference holding its transpose.
99+
scf.for %i = %c0 to %c32 step %c1 {
100+
scf.for %j = %c0 to %c32 step %c1 {
101+
// Linear element number i * 32 + j (0..1023) is used as the element value.
%m = arith.muli %i, %c32 : index
102+
%a = arith.addi %m, %j : index
103+
%v = index.castu %a : index to i16
104+
%val = arith.uitofp %v : i16 to f16
105+
memref.store %val, %A[%i, %j] : memref<32x32xf16>
106+
%v32 = index.castu %a : index to i32
107+
%val32 = arith.uitofp %v32 : i32 to f32
108+
// Note the swapped subscripts: Ref[j][i] = A[i][j], i.e. Ref is the transpose of A.
memref.store %val32, %Ref[%j, %i] : memref<32x32xf32>
109+
}
110+
}
111+
%B = call @test(%A) : (memref<32x32xf16>) -> memref<32x32xf16>
112+
%cast = memref.cast %B : memref<32x32xf16> to memref<*xf16>
113+
%Ref_cast = memref.cast %Ref : memref<32x32xf32> to memref<*xf32>
114+
//CHECK: [ALLCLOSE: TRUE]
115+
call @printAllcloseF16(%cast, %Ref_cast) : (memref<*xf16>, memref<*xf32>) -> ()
116+
// NOTE(review): %B comes from a gpu.alloc inside @test and is never released here; confirm
// whether a gpu.dealloc is wanted or the leak is acceptable for this short-lived test.
memref.dealloc %A : memref<32x32xf16>
117+
memref.dealloc %Ref : memref<32x32xf32>
118+
return
119+
}
120+
// External helper declarations (no bodies in this module).
// NOTE(review): presumably resolved from the utility libraries passed via --shared-libs on the
// RUN lines — confirm. @printMemrefF16 is not called in the visible module.
func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
121+
// Called from @main with the kernel result and the f32 reference; the CHECK line in @main
// expects "[ALLCLOSE: TRUE]", presumably printed by this helper.
func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
122+
}

0 commit comments

Comments
 (0)