Skip to content

Commit fc16bd4

Browse files
authored
Add Wg level 4K gemm (#989)
* Add Wg level 4K gemm * Fix precommit
1 parent f191671 commit fc16bd4

File tree

1 file changed

+253
-0
lines changed

1 file changed

+253
-0
lines changed
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \
2+
// RUN: --runner imex-cpu-runner -e main \
3+
// RUN: --entry-point-result=void \
4+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
5+
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \
6+
// RUN: --runner imex-cpu-runner -e main \
7+
// RUN: --entry-point-result=void \
8+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
9+
10+
#wg_map_a = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 32]>
11+
#tile_attr_a = #xetile.tile_attr<wg_map = #wg_map_a>
12+
13+
#wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
14+
#tile_attr_b = #xetile.tile_attr<wg_map = #wg_map_b>
15+
16+
#wg_map_c = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
17+
#tile_attr_c = #xetile.tile_attr<wg_map = #wg_map_c>
18+
19+
module @gemm attributes {gpu.container_module} {
20+
func.func @test(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} {
21+
%c1 = arith.constant 1 : index
22+
%c2 = arith.constant 2 : index
23+
%c4 = arith.constant 4 : index
24+
%c8 = arith.constant 8 : index
25+
%c16 = arith.constant 16 : index
26+
%c32 = arith.constant 32 : index
27+
%c64 = arith.constant 64 : index
28+
%c128 = arith.constant 128 : index
29+
%c512 = arith.constant 512 : index
30+
%A_gpu = gpu.alloc host_shared () : memref<4096x4096xbf16>
31+
memref.copy %A, %A_gpu : memref<4096x4096xbf16> to memref<4096x4096xbf16>
32+
%B_gpu = gpu.alloc host_shared () : memref<4096x4096xbf16>
33+
memref.copy %B, %B_gpu : memref<4096x4096xbf16> to memref<4096x4096xbf16>
34+
%C_gpu = gpu.alloc host_shared () : memref<4096x4096xf32>
35+
memref.copy %C, %C_gpu : memref<4096x4096xf32> to memref<4096x4096xf32>
36+
gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xbf16>, %B_gpu : memref<4096x4096xbf16>, %C_gpu : memref<4096x4096xf32>)
37+
gpu.dealloc %A_gpu : memref<4096x4096xbf16>
38+
gpu.dealloc %B_gpu : memref<4096x4096xbf16>
39+
return %C_gpu : memref<4096x4096xf32>
40+
}
41+
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
42+
gpu.func @test_kernel(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
43+
%c0 = arith.constant 0 : index
44+
%c1 = arith.constant 1 : index
45+
%c32 = arith.constant 32 : index
46+
%c64 = arith.constant 64 : index
47+
%c256 = arith.constant 256 : index
48+
%c4096 = arith.constant 4096 : index
49+
%block_id_x = gpu.block_id x
50+
%block_id_y = gpu.block_id y
51+
%m = arith.muli %block_id_x, %c256 : index
52+
%n = arith.muli %block_id_y, %c256 : index
53+
// initialize the C tile and load it
54+
// %prefetch_c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32>
55+
// -> !xetile.tile<256x256xf32, #tile_attr_c>
56+
%c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32>
57+
-> !xetile.tile<256x256xf32, #tile_attr_c>
58+
%c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<256x256xf32, #tile_attr_c>
59+
-> vector<256x256xf32>
60+
61+
// initialize A and B tiles
62+
%a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xbf16>
63+
-> !xetile.tile<256x32xbf16, #tile_attr_a>
64+
%b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xbf16>
65+
-> !xetile.tile<32x256xbf16, #tile_attr_b>
66+
67+
// prefetch the first 32-wide K slice of A and B (K offset 0)
68+
%prefetch_a_init_tile_1 = xetile.init_tile %A[%m, %c0] : memref<4096x4096xbf16>
69+
-> !xetile.tile<256x32xbf16, #tile_attr_a>
70+
%prefetch_b_init_tile_1 = xetile.init_tile %B[%c0, %n] : memref<4096x4096xbf16>
71+
-> !xetile.tile<32x256xbf16, #tile_attr_b>
72+
xetile.prefetch_tile %prefetch_a_init_tile_1 : !xetile.tile<256x32xbf16, #tile_attr_a>
73+
xetile.prefetch_tile %prefetch_b_init_tile_1 : !xetile.tile<32x256xbf16, #tile_attr_b>
74+
75+
// prefetch the second 32-wide K slice of A and B (K offset 32)
76+
%prefetch_a_init_tile_2 = xetile.init_tile %A[%m, %c32] : memref<4096x4096xbf16>
77+
-> !xetile.tile<256x32xbf16, #tile_attr_a>
78+
%prefetch_b_init_tile_2 = xetile.init_tile %B[%c32, %n] : memref<4096x4096xbf16>
79+
-> !xetile.tile<32x256xbf16, #tile_attr_b>
80+
xetile.prefetch_tile %prefetch_a_init_tile_2 : !xetile.tile<256x32xbf16, #tile_attr_a>
81+
xetile.prefetch_tile %prefetch_b_init_tile_2 : !xetile.tile<32x256xbf16, #tile_attr_b>
82+
83+
84+
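// create prefetch tiles for the third 32-wide K slice (K offset 64); they are not prefetched here but passed into the loop, which prefetches and then advances them each iteration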
85+
%prefetch_a_init_tile_3 = xetile.init_tile %A[%m, %c64] : memref<4096x4096xbf16>
86+
-> !xetile.tile<256x32xbf16, #tile_attr_a>
87+
%prefetch_b_init_tile_3 = xetile.init_tile %B[%c64, %n] : memref<4096x4096xbf16>
88+
-> !xetile.tile<32x256xbf16, #tile_attr_b>
89+
90+
xegpu.alloc_nbarrier 1
91+
%nbarrier_id = arith.constant 0 : i8
92+
%num_threads = arith.constant 32 : i8
93+
%nbarrier = xegpu.init_nbarrier %nbarrier_id, %num_threads : i8, i8 -> !xegpu.nbarrier
94+
%c0_i32 = arith.constant 0 : i32
95+
96+
// compute the value of C tile by iterating over tiles in k-dimension and doing dpas
97+
%out:5 = scf.for %k = %c0 to %c4096 step %c32
98+
iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value,
99+
%prefetch_a_tile = %prefetch_a_init_tile_3,
100+
%prefetch_b_tile = %prefetch_b_init_tile_3
101+
)
102+
-> (!xetile.tile<256x32xbf16, #tile_attr_a>,
103+
!xetile.tile<32x256xbf16, #tile_attr_b>,
104+
vector<256x256xf32>,
105+
!xetile.tile<256x32xbf16, #tile_attr_a>,
106+
!xetile.tile<32x256xbf16, #tile_attr_b>
107+
) {
108+
109+
// all SGs must arrive here first
110+
// %every_8th_iter = arith.remui %k, %c256 : index
111+
// %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32
112+
// %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32
113+
// scf.if %every_8th_iter_cond {
114+
xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier
115+
// }
116+
117+
118+
// load A and B tiles
119+
%a_value = xetile.load_tile %a_tile : !xetile.tile<256x32xbf16, #tile_attr_a>
120+
-> vector<256x32xbf16>
121+
%b_value = xetile.load_tile %b_tile : !xetile.tile<32x256xbf16, #tile_attr_b>
122+
-> vector<32x256xbf16>
123+
124+
xegpu.compile_hint
125+
126+
// prefetch next A and B tiles
127+
xetile.prefetch_tile %prefetch_a_tile : !xetile.tile<256x32xbf16, #tile_attr_a>
128+
xetile.prefetch_tile %prefetch_b_tile : !xetile.tile<32x256xbf16, #tile_attr_b>
129+
130+
xegpu.compile_hint
131+
132+
// update prefetch tile offsets
133+
%15 = xetile.update_tile_offset %prefetch_a_tile, [%c0, %c32] : !xetile.tile<256x32xbf16, #tile_attr_a>
134+
%16 = xetile.update_tile_offset %prefetch_b_tile, [%c32, %c0] : !xetile.tile<32x256xbf16, #tile_attr_b>
135+
// update the offsets for A and B tiles
136+
%a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32]
137+
: !xetile.tile<256x32xbf16, #tile_attr_a>
138+
%b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0]
139+
: !xetile.tile<32x256xbf16, #tile_attr_b>
140+
141+
xegpu.compile_hint
142+
143+
// perform dpas and accumulate
144+
%c_new_value = xetile.tile_mma %a_value, %b_value, %c_value {wg_map_a = #wg_map_a, wg_map_b = #wg_map_b, wg_map_c = #wg_map_c}
145+
: vector<256x32xbf16>, vector<32x256xbf16>, vector<256x256xf32> -> vector<256x256xf32>
146+
147+
xegpu.compile_hint
148+
// barrier wait
149+
// scf.if %every_8th_iter_cond {
150+
xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier
151+
// }
152+
// partial C tile result
153+
scf.yield %a_next_tile, %b_next_tile, %c_new_value, %15, %16
154+
: !xetile.tile<256x32xbf16, #tile_attr_a>,
155+
!xetile.tile<32x256xbf16, #tile_attr_b>, vector<256x256xf32>,
156+
!xetile.tile<256x32xbf16, #tile_attr_a>,
157+
!xetile.tile<32x256xbf16, #tile_attr_b>
158+
}
159+
// store the final accumulated C tile result back to memory
160+
%c_init_tile_1 = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32>
161+
-> !xetile.tile<256x256xf32, #tile_attr_c>
162+
xetile.store_tile %out#2, %c_init_tile_1 : vector<256x256xf32>,
163+
!xetile.tile<256x256xf32, #tile_attr_c>
164+
xegpu.compile_hint
165+
gpu.return
166+
}
167+
}
168+
169+
// compute CPU reference (takes minutes)
170+
func.func @cpu_reference(%A : memref<4096x4096xbf16>, %B : memref<4096x4096xbf16>, %C : memref<4096x4096xf32>) {
171+
%c4096 = arith.constant 4096 : index
172+
%c16 = arith.constant 16 : index
173+
%c1 = arith.constant 1 : index
174+
%c0 = arith.constant 0 : index
175+
scf.for %i = %c0 to %c4096 step %c1 {
176+
scf.for %j = %c0 to %c4096 step %c1 {
177+
%c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32>
178+
%c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
179+
%c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
180+
%k_dpas = arith.addi %k_tile, %k : index
181+
%a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xbf16>
182+
%b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xbf16>
183+
%a_cast = arith.extf %a_val : bf16 to f32
184+
%b_cast = arith.extf %b_val : bf16 to f32
185+
%t = arith.mulf %a_cast, %b_cast : f32
186+
%c_sum = arith.addf %t, %c_dpas_partial : f32
187+
scf.yield %c_sum : f32
188+
}
189+
scf.yield %c_val_dpas : f32
190+
}
191+
memref.store %c_val , %C[%i, %j] : memref<4096x4096xf32>
192+
}
193+
}
194+
return
195+
}
196+
197+
func.func @main() attributes {llvm.emit_c_interface} {
198+
%c0 = arith.constant 0 : index
199+
%c1 = arith.constant 1 : index
200+
%c1_f16 = arith.constant 1.0 : bf16
201+
%c2_f16 = arith.constant 2.0 : bf16
202+
%c4096 = arith.constant 4096 : index
203+
%cf_0 = arith.constant 0.0 : bf16
204+
%cf_1 = arith.constant 1.0 : bf16
205+
%c_gen_int = arith.constant 0 : i1
206+
%cf_lower = arith.constant 0.0 : f32
207+
%cf_upper = arith.constant 1.0 : f32
208+
209+
%A = memref.alloc() : memref<4096x4096xbf16>
210+
%B = memref.alloc() : memref<4096x4096xbf16>
211+
%C = memref.alloc() : memref<4096x4096xf32>
212+
%C_ref = memref.alloc() : memref<4096x4096xf32>
213+
214+
// cast A to an unranked memref and fill it with random values in (0.0, 1.0)
215+
%A_random = memref.cast %A : memref<4096x4096xbf16> to memref<*xbf16>
216+
call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> ()
217+
218+
// cast B to an unranked memref and fill it with random values in (0.0, 1.0)
219+
%B_random = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16>
220+
call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> ()
221+
222+
// initialize matrices C and C_ref; C[i, j] = C_ref[i, j] = 0
223+
%c0_f16 = arith.constant 0.0 : bf16
224+
%c0_f32 = arith.constant 0.0 : f32
225+
scf.for %i = %c0 to %c4096 step %c1 {
226+
scf.for %j = %c0 to %c4096 step %c1 {
227+
memref.store %c0_f32, %C[%i, %j] : memref<4096x4096xf32>
228+
memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32>
229+
}
230+
}
231+
232+
// run GPU
233+
%2 = call @test(%A, %B, %C) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32>
234+
235+
// run CPU
236+
call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> ()
237+
238+
%cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32>
239+
%cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
240+
// CHECK: [ALLCLOSE: TRUE]
241+
call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
242+
memref.dealloc %A : memref<4096x4096xbf16>
243+
memref.dealloc %B : memref<4096x4096xbf16>
244+
memref.dealloc %C : memref<4096x4096xf32>
245+
memref.dealloc %C_ref : memref<4096x4096xf32>
246+
return
247+
}
248+
func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface}
249+
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
250+
func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
251+
func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
252+
func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
253+
}

0 commit comments

Comments
 (0)