diff --git a/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir b/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir new file mode 100644 index 000000000..b088561f4 --- /dev/null +++ b/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir @@ -0,0 +1,253 @@ +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: --runner imex-cpu-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +#wg_map_a = #xetile.wg_map +#tile_attr_a = #xetile.tile_attr + +#wg_map_b = #xetile.wg_map +#tile_attr_b = #xetile.tile_attr + +#wg_map_c = #xetile.wg_map +#tile_attr_c = #xetile.tile_attr + +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %A_gpu = gpu.alloc host_shared () : memref<4096x4096xbf16> + memref.copy %A, %A_gpu : memref<4096x4096xbf16> to memref<4096x4096xbf16> + %B_gpu = gpu.alloc host_shared () : memref<4096x4096xbf16> + memref.copy %B, %B_gpu : memref<4096x4096xbf16> to memref<4096x4096xbf16> + %C_gpu = gpu.alloc host_shared () : memref<4096x4096xf32> + memref.copy %C, %C_gpu : memref<4096x4096xf32> to memref<4096x4096xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xbf16>, %B_gpu : memref<4096x4096xbf16>, %C_gpu : memref<4096x4096xf32>) + gpu.dealloc %A_gpu : memref<4096x4096xbf16> + gpu.dealloc %B_gpu : memref<4096x4096xbf16> + return %C_gpu : memref<4096x4096xf32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c256 : index + %n = arith.muli %block_id_y, %c256 : index + // intialize C tile and load it + // %prefetch_c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> + // -> !xetile.tile<256x256xf32, #tile_attr_c> + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> + -> !xetile.tile<256x256xf32, #tile_attr_c> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<256x256xf32, #tile_attr_c> + -> vector<256x256xf32> + + // initalize A and B tiles + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xbf16> + -> !xetile.tile<256x32xbf16, #tile_attr_a> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xbf16> + -> !xetile.tile<32x256xbf16, #tile_attr_b> + + // prefetch first 32 slice + %prefetch_a_init_tile_1 = xetile.init_tile %A[%m, %c0] : memref<4096x4096xbf16> + -> !xetile.tile<256x32xbf16, #tile_attr_a> + %prefetch_b_init_tile_1 = xetile.init_tile %B[%c0, %n] : memref<4096x4096xbf16> + -> !xetile.tile<32x256xbf16, #tile_attr_b> + xetile.prefetch_tile %prefetch_a_init_tile_1 : !xetile.tile<256x32xbf16, #tile_attr_a> + xetile.prefetch_tile %prefetch_b_init_tile_1 : !xetile.tile<32x256xbf16, #tile_attr_b> + + // prefetch second 32 slice + %prefetch_a_init_tile_2 = xetile.init_tile %A[%m, %c32] : memref<4096x4096xbf16> + -> !xetile.tile<256x32xbf16, #tile_attr_a> + %prefetch_b_init_tile_2 = xetile.init_tile %B[%c32, %n] : memref<4096x4096xbf16> + -> !xetile.tile<32x256xbf16, #tile_attr_b> + xetile.prefetch_tile %prefetch_a_init_tile_2 : !xetile.tile<256x32xbf16, #tile_attr_a> + xetile.prefetch_tile %prefetch_b_init_tile_2 : !xetile.tile<32x256xbf16, #tile_attr_b> + + + // prefetch third 32 slice + %prefetch_a_init_tile_3 = xetile.init_tile %A[%m, %c64] : memref<4096x4096xbf16> + -> !xetile.tile<256x32xbf16, #tile_attr_a> + %prefetch_b_init_tile_3 = xetile.init_tile %B[%c64, %n] : memref<4096x4096xbf16> + -> !xetile.tile<32x256xbf16, #tile_attr_b> + + xegpu.alloc_nbarrier 1 + %nbarrier_id = arith.constant 0 : i8 + %num_threads = arith.constant 32 : i8 + %nbarrier = xegpu.init_nbarrier %nbarrier_id, %num_threads : i8, i8 -> !xegpu.nbarrier + %c0_i32 = arith.constant 0 : i32 + + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + %out:5 = scf.for %k = %c0 to %c4096 step %c32 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value, + %prefetch_a_tile = %prefetch_a_init_tile_3, + %prefetch_b_tile = %prefetch_b_init_tile_3 + ) + -> (!xetile.tile<256x32xbf16, #tile_attr_a>, + !xetile.tile<32x256xbf16, #tile_attr_b>, + vector<256x256xf32>, + !xetile.tile<256x32xbf16, #tile_attr_a>, + !xetile.tile<32x256xbf16, #tile_attr_b> + ) { + + // all SGs must arrive here first + // %every_8th_iter = arith.remui %k, %c256 : index + // %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32 + // %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32 + // scf.if %every_8th_iter_cond { + xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier + // } + + + // load A and B tiles + %a_value = xetile.load_tile %a_tile : !xetile.tile<256x32xbf16, #tile_attr_a> + -> vector<256x32xbf16> + %b_value = xetile.load_tile %b_tile : !xetile.tile<32x256xbf16, #tile_attr_b> + -> vector<32x256xbf16> + + xegpu.compile_hint + + // prefetch next A and B tiles + xetile.prefetch_tile %prefetch_a_tile : !xetile.tile<256x32xbf16, #tile_attr_a> + xetile.prefetch_tile %prefetch_b_tile : !xetile.tile<32x256xbf16, #tile_attr_b> + + xegpu.compile_hint + + // update prefetch tile offsets + %15 = xetile.update_tile_offset %prefetch_a_tile, [%c0, %c32] : !xetile.tile<256x32xbf16, #tile_attr_a> + %16 = xetile.update_tile_offset %prefetch_b_tile, [%c32, %c0] : !xetile.tile<32x256xbf16, #tile_attr_b> + // update the offsets for A and B tiles + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] + : !xetile.tile<256x32xbf16, #tile_attr_a> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] + : !xetile.tile<32x256xbf16, #tile_attr_b> + + xegpu.compile_hint + + // perform dpas and accumulate + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value {wg_map_a = #wg_map_a, wg_map_b = #wg_map_b, wg_map_c = #wg_map_c} + : vector<256x32xbf16>, vector<32x256xbf16>, vector<256x256xf32> -> vector<256x256xf32> + + xegpu.compile_hint + // barrier wait + // scf.if %every_8th_iter_cond { + xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier + // } + // partial C tile result + scf.yield %a_next_tile, %b_next_tile, %c_new_value, %15, %16 + : !xetile.tile<256x32xbf16, #tile_attr_a>, + !xetile.tile<32x256xbf16, #tile_attr_b>, vector<256x256xf32>, + !xetile.tile<256x32xbf16, #tile_attr_a>, + !xetile.tile<32x256xbf16, #tile_attr_b> + } + // store the final accumulated C tile result back to memory + %c_init_tile_1 = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> + -> !xetile.tile<256x256xf32, #tile_attr_c> + xetile.store_tile %out#2, %c_init_tile_1 : vector<256x256xf32>, + !xetile.tile<256x256xf32, #tile_attr_c> + xegpu.compile_hint + gpu.return + } + } + + // compute CPU reference (takes minutes) + func.func @cpu_reference(%A : memref<4096x4096xbf16>, %B : memref<4096x4096xbf16>, %C : memref<4096x4096xf32>) { + %c4096 = arith.constant 4096 : index + %c16 = arith.constant 16 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + scf.for %i = %c0 to %c4096 step %c1 { + scf.for %j = %c0 to %c4096 step %c1 { + %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> + %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { + %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { + %k_dpas = arith.addi %k_tile, %k : index + %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xbf16> + %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xbf16> + %a_cast = arith.extf %a_val : bf16 to f32 + %b_cast = arith.extf %b_val : bf16 to f32 + %t = arith.mulf %a_cast, %b_cast : f32 + %c_sum = arith.addf %t, %c_dpas_partial : f32 + scf.yield %c_sum : f32 + } + scf.yield %c_val_dpas : f32 + } + memref.store %c_val , %C[%i, %j] : memref<4096x4096xf32> + } + } + return + } + + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_f16 = arith.constant 1.0 : bf16 + %c2_f16 = arith.constant 2.0 : bf16 + %c4096 = arith.constant 4096 : index + %cf_0 = arith.constant 0.0 : bf16 + %cf_1 = arith.constant 1.0 : bf16 + %c_gen_int = arith.constant 0 : i1 + %cf_lower = arith.constant 0.0 : f32 + %cf_upper = arith.constant 1.0 : f32 + + %A = memref.alloc() : memref<4096x4096xbf16> + %B = memref.alloc() : memref<4096x4096xbf16> + %C = memref.alloc() : memref<4096x4096xf32> + %C_ref = memref.alloc() : memref<4096x4096xf32> + + // convert the memref to 1D and fill with random values in (0.0, 1.0) + %A_random = memref.cast %A : memref<4096x4096xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () + + // convert the memref to 1D and fill with random values in (0.0, 1.0) + %B_random = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () + + // intialize matrix C and C_ref ; C[i, j] = 0 + %c0_f16 = arith.constant 0.0 : bf16 + %c0_f32 = arith.constant 0.0 : f32 + scf.for %i = %c0 to %c4096 step %c1 { + scf.for %j = %c0 to %c4096 step %c1 { + memref.store %c0_f32, %C[%i, %j] : memref<4096x4096xf32> + memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> + } + } + + // run GPU + %2 = call @test(%A, %B, %C) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> + + // run CPU + call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> () + + %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> + %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %A : memref<4096x4096xbf16> + memref.dealloc %B : memref<4096x4096xbf16> + memref.dealloc %C : memref<4096x4096xf32> + memref.dealloc %C_ref : memref<4096x4096xf32> + return + } + func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface} +}