* Add Wg level 4K gemm
* Fix precommit

Showing 1 changed file with 253 additions and 0 deletions.
test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir
@@ -0,0 +1,253 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \
// RUN:   --runner imex-cpu-runner -e main \
// RUN:   --entry-point-result=void \
// RUN:   --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \
// RUN:   --runner imex-cpu-runner -e main \
// RUN:   --entry-point-result=void \
// RUN:   --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck

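// Workgroup-level 4K GEMM: C (4096x4096, f32) += A (4096x4096, bf16) x B (4096x4096, bf16).
// Each workgroup computes one 256x256 tile of C using an 8x4 grid of subgroups
// (sg_layout); each subgroup owns a 32x64 sub-tile of C (sg_data). The A and B
// maps below describe the matching per-subgroup distribution of the 32-wide k-slices.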
#wg_map_a = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 32]>
#tile_attr_a = #xetile.tile_attr<wg_map = #wg_map_a>

#wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
#tile_attr_b = #xetile.tile_attr<wg_map = #wg_map_b>

#wg_map_c = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>
#tile_attr_c = #xetile.tile_attr<wg_map = #wg_map_c>

module @gemm attributes {gpu.container_module} {
  func.func @test(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c8 = arith.constant 8 : index
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %c128 = arith.constant 128 : index
    %c512 = arith.constant 512 : index
    %A_gpu = gpu.alloc host_shared () : memref<4096x4096xbf16>
    memref.copy %A, %A_gpu : memref<4096x4096xbf16> to memref<4096x4096xbf16>
    %B_gpu = gpu.alloc host_shared () : memref<4096x4096xbf16>
    memref.copy %B, %B_gpu : memref<4096x4096xbf16> to memref<4096x4096xbf16>
    %C_gpu = gpu.alloc host_shared () : memref<4096x4096xf32>
    memref.copy %C, %C_gpu : memref<4096x4096xf32> to memref<4096x4096xf32>
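    // Launch a 16x16 grid of workgroups (16 * 256 = 4096 rows/cols of C), with
    // 8x4 threads per workgroup to match the sg_layout above.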
    gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xbf16>, %B_gpu : memref<4096x4096xbf16>, %C_gpu : memref<4096x4096xf32>)
    gpu.dealloc %A_gpu : memref<4096x4096xbf16>
    gpu.dealloc %B_gpu : memref<4096x4096xbf16>
    return %C_gpu : memref<4096x4096xf32>
  }
  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
    gpu.func @test_kernel(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      %c32 = arith.constant 32 : index
      %c64 = arith.constant 64 : index
      %c256 = arith.constant 256 : index
      %c4096 = arith.constant 4096 : index
      %block_id_x = gpu.block_id x
      %block_id_y = gpu.block_id y
      %m = arith.muli %block_id_x, %c256 : index
      %n = arith.muli %block_id_y, %c256 : index
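      // (%m, %n) is the origin of this workgroup's 256x256 tile of C.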
      // initialize C tile and load it
      // %prefetch_c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32>
      //   -> !xetile.tile<256x256xf32, #tile_attr_c>
      %c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32>
        -> !xetile.tile<256x256xf32, #tile_attr_c>
      %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<256x256xf32, #tile_attr_c>
        -> vector<256x256xf32>

      // initialize A and B tiles
      %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xbf16>
        -> !xetile.tile<256x32xbf16, #tile_attr_a>
      %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xbf16>
        -> !xetile.tile<32x256xbf16, #tile_attr_b>

      // prefetch the first 32-wide k-slice
      %prefetch_a_init_tile_1 = xetile.init_tile %A[%m, %c0] : memref<4096x4096xbf16>
        -> !xetile.tile<256x32xbf16, #tile_attr_a>
      %prefetch_b_init_tile_1 = xetile.init_tile %B[%c0, %n] : memref<4096x4096xbf16>
        -> !xetile.tile<32x256xbf16, #tile_attr_b>
      xetile.prefetch_tile %prefetch_a_init_tile_1 : !xetile.tile<256x32xbf16, #tile_attr_a>
      xetile.prefetch_tile %prefetch_b_init_tile_1 : !xetile.tile<32x256xbf16, #tile_attr_b>

      // prefetch the second 32-wide k-slice
      %prefetch_a_init_tile_2 = xetile.init_tile %A[%m, %c32] : memref<4096x4096xbf16>
        -> !xetile.tile<256x32xbf16, #tile_attr_a>
      %prefetch_b_init_tile_2 = xetile.init_tile %B[%c32, %n] : memref<4096x4096xbf16>
        -> !xetile.tile<32x256xbf16, #tile_attr_b>
      xetile.prefetch_tile %prefetch_a_init_tile_2 : !xetile.tile<256x32xbf16, #tile_attr_a>
      xetile.prefetch_tile %prefetch_b_init_tile_2 : !xetile.tile<32x256xbf16, #tile_attr_b>

      // create tiles for the third 32-wide k-slice (prefetched inside the loop)
      %prefetch_a_init_tile_3 = xetile.init_tile %A[%m, %c64] : memref<4096x4096xbf16>
        -> !xetile.tile<256x32xbf16, #tile_attr_a>
      %prefetch_b_init_tile_3 = xetile.init_tile %B[%c64, %n] : memref<4096x4096xbf16>
        -> !xetile.tile<32x256xbf16, #tile_attr_b>
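      // Software-pipelined prefetching: slices 1 and 2 are already in flight, and
      // the slice-3 tiles are carried into the loop, keeping prefetches two
      // k-iterations ahead of the loads that consume them. (The final iterations
      // prefetch past the 4096 boundary; prefetches are hints, so this is
      // typically harmless.)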

      xegpu.alloc_nbarrier 1
      %nbarrier_id = arith.constant 0 : i8
      %num_threads = arith.constant 32 : i8
      %nbarrier = xegpu.init_nbarrier %nbarrier_id, %num_threads : i8, i8 -> !xegpu.nbarrier
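      // One named barrier spanning all 32 (8x4) threads of the workgroup: arriving
      // at the top of each k-iteration and waiting after the dpas keeps subgroups
      // in lockstep, so their A/B loads stay coordinated in the cache.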
      %c0_i32 = arith.constant 0 : i32

      // compute the C tile by iterating over the k-dimension and accumulating with dpas
      %out:5 = scf.for %k = %c0 to %c4096 step %c32
        iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value,
                  %prefetch_a_tile = %prefetch_a_init_tile_3,
                  %prefetch_b_tile = %prefetch_b_init_tile_3)
        -> (!xetile.tile<256x32xbf16, #tile_attr_a>,
            !xetile.tile<32x256xbf16, #tile_attr_b>,
            vector<256x256xf32>,
            !xetile.tile<256x32xbf16, #tile_attr_a>,
            !xetile.tile<32x256xbf16, #tile_attr_b>) {

        // all SGs must arrive here first
        // %every_8th_iter = arith.remui %k, %c256 : index
        // %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32
        // %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32
        // scf.if %every_8th_iter_cond {
        xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier
        // }
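        // (The commented-out guard would synchronize only every 8th iteration,
        // i.e. every 256 k-elements; as written, every iteration synchronizes.)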

        // load A and B slices
        %a_value = xetile.load_tile %a_tile : !xetile.tile<256x32xbf16, #tile_attr_a>
          -> vector<256x32xbf16>
        %b_value = xetile.load_tile %b_tile : !xetile.tile<32x256xbf16, #tile_attr_b>
          -> vector<32x256xbf16>

        xegpu.compile_hint

        // prefetch the A and B slices two iterations ahead
        xetile.prefetch_tile %prefetch_a_tile : !xetile.tile<256x32xbf16, #tile_attr_a>
        xetile.prefetch_tile %prefetch_b_tile : !xetile.tile<32x256xbf16, #tile_attr_b>

        xegpu.compile_hint

        // advance the prefetch tiles along k
        %prefetch_a_next_tile = xetile.update_tile_offset %prefetch_a_tile, [%c0, %c32] : !xetile.tile<256x32xbf16, #tile_attr_a>
        %prefetch_b_next_tile = xetile.update_tile_offset %prefetch_b_tile, [%c32, %c0] : !xetile.tile<32x256xbf16, #tile_attr_b>
        // advance the A and B tiles along k
        %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32]
          : !xetile.tile<256x32xbf16, #tile_attr_a>
        %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0]
          : !xetile.tile<32x256xbf16, #tile_attr_b>

        xegpu.compile_hint

        // perform dpas and accumulate
        %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value {wg_map_a = #wg_map_a, wg_map_b = #wg_map_b, wg_map_c = #wg_map_c}
          : vector<256x32xbf16>, vector<32x256xbf16>, vector<256x256xf32> -> vector<256x256xf32>
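        // The wg_map attributes guide the workgroup-to-subgroup lowering, which
        // distributes this 256x256 MMA across the 8x4 subgroups per the maps above.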

        xegpu.compile_hint

        // barrier wait
        // scf.if %every_8th_iter_cond {
        xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier
        // }

        // yield the partial C result and the advanced tiles
        scf.yield %a_next_tile, %b_next_tile, %c_new_value, %prefetch_a_next_tile, %prefetch_b_next_tile
          : !xetile.tile<256x32xbf16, #tile_attr_a>,
            !xetile.tile<32x256xbf16, #tile_attr_b>, vector<256x256xf32>,
            !xetile.tile<256x32xbf16, #tile_attr_a>,
            !xetile.tile<32x256xbf16, #tile_attr_b>
      }
      // store the final accumulated C tile back to memory
      %c_init_tile_1 = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32>
        -> !xetile.tile<256x256xf32, #tile_attr_c>
      xetile.store_tile %out#2, %c_init_tile_1 : vector<256x256xf32>,
        !xetile.tile<256x256xf32, #tile_attr_c>
      xegpu.compile_hint
      gpu.return
    }
  }

  // compute CPU reference (takes minutes)
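  // The nested k loops (outer step 16) mirror the 16-wide k-chunking of bf16 dpas;
  // since accumulation stays in f32 throughout, the chunking is purely structural.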
  func.func @cpu_reference(%A : memref<4096x4096xbf16>, %B : memref<4096x4096xbf16>, %C : memref<4096x4096xf32>) {
    %c4096 = arith.constant 4096 : index
    %c16 = arith.constant 16 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    scf.for %i = %c0 to %c4096 step %c1 {
      scf.for %j = %c0 to %c4096 step %c1 {
        %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32>
        %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
          %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
            %k_dpas = arith.addi %k_tile, %k : index
            %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xbf16>
            %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xbf16>
            %a_cast = arith.extf %a_val : bf16 to f32
            %b_cast = arith.extf %b_val : bf16 to f32
            %t = arith.mulf %a_cast, %b_cast : f32
            %c_sum = arith.addf %t, %c_dpas_partial : f32
            scf.yield %c_sum : f32
          }
          scf.yield %c_val_dpas : f32
        }
        memref.store %c_val, %C[%i, %j] : memref<4096x4096xf32>
      }
    }
    return
  }

  func.func @main() attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c1_f16 = arith.constant 1.0 : bf16
    %c2_f16 = arith.constant 2.0 : bf16
    %c4096 = arith.constant 4096 : index
    %cf_0 = arith.constant 0.0 : bf16
    %cf_1 = arith.constant 1.0 : bf16
    %c_gen_int = arith.constant 0 : i1
    %cf_lower = arith.constant 0.0 : f32
    %cf_upper = arith.constant 1.0 : f32

    %A = memref.alloc() : memref<4096x4096xbf16>
    %B = memref.alloc() : memref<4096x4096xbf16>
    %C = memref.alloc() : memref<4096x4096xf32>
    %C_ref = memref.alloc() : memref<4096x4096xf32>

    // view A as a 1D resource and fill with random values in (0.0, 1.0)
    %A_random = memref.cast %A : memref<4096x4096xbf16> to memref<*xbf16>
    call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> ()

    // view B as a 1D resource and fill with random values in (0.0, 1.0)
    %B_random = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16>
    call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> ()

    // initialize matrices C and C_ref: C[i, j] = 0
    %c0_f16 = arith.constant 0.0 : bf16
    %c0_f32 = arith.constant 0.0 : f32
    scf.for %i = %c0 to %c4096 step %c1 {
      scf.for %j = %c0 to %c4096 step %c1 {
        memref.store %c0_f32, %C[%i, %j] : memref<4096x4096xf32>
        memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32>
      }
    }

    // run GPU
    %2 = call @test(%A, %B, %C) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32>

    // run CPU
    call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> ()

    %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32>
    %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
    // CHECK: [ALLCLOSE: TRUE]
    call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
    memref.dealloc %A : memref<4096x4096xbf16>
    memref.dealloc %B : memref<4096x4096xbf16>
    memref.dealloc %C : memref<4096x4096xf32>
    memref.dealloc %C_ref : memref<4096x4096xf32>
    return
  }
  func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface}
  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
}