|
| 1 | +// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ |
| 2 | +// RUN: --runner imex-cpu-runner -e main \ |
| 3 | +// RUN: --entry-point-result=void \ |
| 4 | +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck |
| 5 | +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ |
| 6 | +// RUN: --runner imex-cpu-runner -e main \ |
| 7 | +// RUN: --entry-point-result=void \ |
| 8 | +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck |
| 9 | + |
| 10 | +#wg_map_a = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 32]> |
| 11 | +#tile_attr_a = #xetile.tile_attr<wg_map = #wg_map_a> |
| 12 | + |
| 13 | +#wg_map_b = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]> |
| 14 | +#tile_attr_b = #xetile.tile_attr<wg_map = #wg_map_b> |
| 15 | + |
| 16 | +#wg_map_c = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]> |
| 17 | +#tile_attr_c = #xetile.tile_attr<wg_map = #wg_map_c> |
| 18 | + |
| 19 | +module @gemm attributes {gpu.container_module} { |
| 20 | + func.func @test(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { |
| 21 | + %c1 = arith.constant 1 : index |
| 22 | + %c2 = arith.constant 2 : index |
| 23 | + %c4 = arith.constant 4 : index |
| 24 | + %c8 = arith.constant 8 : index |
| 25 | + %c16 = arith.constant 16 : index |
| 26 | + %c32 = arith.constant 32 : index |
| 27 | + %c64 = arith.constant 64 : index |
| 28 | + %c128 = arith.constant 128 : index |
| 29 | + %c512 = arith.constant 512 : index |
| 30 | + %A_gpu = gpu.alloc host_shared () : memref<4096x4096xbf16> |
| 31 | + memref.copy %A, %A_gpu : memref<4096x4096xbf16> to memref<4096x4096xbf16> |
| 32 | + %B_gpu = gpu.alloc host_shared () : memref<4096x4096xbf16> |
| 33 | + memref.copy %B, %B_gpu : memref<4096x4096xbf16> to memref<4096x4096xbf16> |
| 34 | + %C_gpu = gpu.alloc host_shared () : memref<4096x4096xf32> |
| 35 | + memref.copy %C, %C_gpu : memref<4096x4096xf32> to memref<4096x4096xf32> |
| 36 | + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xbf16>, %B_gpu : memref<4096x4096xbf16>, %C_gpu : memref<4096x4096xf32>) |
| 37 | + gpu.dealloc %A_gpu : memref<4096x4096xbf16> |
| 38 | + gpu.dealloc %B_gpu : memref<4096x4096xbf16> |
| 39 | + return %C_gpu : memref<4096x4096xf32> |
| 40 | + } |
| 41 | + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} { |
| 42 | + gpu.func @test_kernel(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { |
| 43 | + %c0 = arith.constant 0 : index |
| 44 | + %c1 = arith.constant 1 : index |
| 45 | + %c32 = arith.constant 32 : index |
| 46 | + %c64 = arith.constant 64 : index |
| 47 | + %c256 = arith.constant 256 : index |
| 48 | + %c4096 = arith.constant 4096 : index |
| 49 | + %block_id_x = gpu.block_id x |
| 50 | + %block_id_y = gpu.block_id y |
| 51 | + %m = arith.muli %block_id_x, %c256 : index |
| 52 | + %n = arith.muli %block_id_y, %c256 : index |
| 53 | + // intialize C tile and load it |
| 54 | + // %prefetch_c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> |
| 55 | + // -> !xetile.tile<256x256xf32, #tile_attr_c> |
| 56 | + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> |
| 57 | + -> !xetile.tile<256x256xf32, #tile_attr_c> |
| 58 | + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<256x256xf32, #tile_attr_c> |
| 59 | + -> vector<256x256xf32> |
| 60 | + |
| 61 | + // initalize A and B tiles |
| 62 | + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xbf16> |
| 63 | + -> !xetile.tile<256x32xbf16, #tile_attr_a> |
| 64 | + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xbf16> |
| 65 | + -> !xetile.tile<32x256xbf16, #tile_attr_b> |
| 66 | + |
| 67 | + // prefetch first 32 slice |
| 68 | + %prefetch_a_init_tile_1 = xetile.init_tile %A[%m, %c0] : memref<4096x4096xbf16> |
| 69 | + -> !xetile.tile<256x32xbf16, #tile_attr_a> |
| 70 | + %prefetch_b_init_tile_1 = xetile.init_tile %B[%c0, %n] : memref<4096x4096xbf16> |
| 71 | + -> !xetile.tile<32x256xbf16, #tile_attr_b> |
| 72 | + xetile.prefetch_tile %prefetch_a_init_tile_1 : !xetile.tile<256x32xbf16, #tile_attr_a> |
| 73 | + xetile.prefetch_tile %prefetch_b_init_tile_1 : !xetile.tile<32x256xbf16, #tile_attr_b> |
| 74 | + |
| 75 | + // prefetch second 32 slice |
| 76 | + %prefetch_a_init_tile_2 = xetile.init_tile %A[%m, %c32] : memref<4096x4096xbf16> |
| 77 | + -> !xetile.tile<256x32xbf16, #tile_attr_a> |
| 78 | + %prefetch_b_init_tile_2 = xetile.init_tile %B[%c32, %n] : memref<4096x4096xbf16> |
| 79 | + -> !xetile.tile<32x256xbf16, #tile_attr_b> |
| 80 | + xetile.prefetch_tile %prefetch_a_init_tile_2 : !xetile.tile<256x32xbf16, #tile_attr_a> |
| 81 | + xetile.prefetch_tile %prefetch_b_init_tile_2 : !xetile.tile<32x256xbf16, #tile_attr_b> |
| 82 | + |
| 83 | + |
| 84 | + // prefetch third 32 slice |
| 85 | + %prefetch_a_init_tile_3 = xetile.init_tile %A[%m, %c64] : memref<4096x4096xbf16> |
| 86 | + -> !xetile.tile<256x32xbf16, #tile_attr_a> |
| 87 | + %prefetch_b_init_tile_3 = xetile.init_tile %B[%c64, %n] : memref<4096x4096xbf16> |
| 88 | + -> !xetile.tile<32x256xbf16, #tile_attr_b> |
| 89 | + |
| 90 | + xegpu.alloc_nbarrier 1 |
| 91 | + %nbarrier_id = arith.constant 0 : i8 |
| 92 | + %num_threads = arith.constant 32 : i8 |
| 93 | + %nbarrier = xegpu.init_nbarrier %nbarrier_id, %num_threads : i8, i8 -> !xegpu.nbarrier |
| 94 | + %c0_i32 = arith.constant 0 : i32 |
| 95 | + |
| 96 | + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas |
| 97 | + %out:5 = scf.for %k = %c0 to %c4096 step %c32 |
| 98 | + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value, |
| 99 | + %prefetch_a_tile = %prefetch_a_init_tile_3, |
| 100 | + %prefetch_b_tile = %prefetch_b_init_tile_3 |
| 101 | + ) |
| 102 | + -> (!xetile.tile<256x32xbf16, #tile_attr_a>, |
| 103 | + !xetile.tile<32x256xbf16, #tile_attr_b>, |
| 104 | + vector<256x256xf32>, |
| 105 | + !xetile.tile<256x32xbf16, #tile_attr_a>, |
| 106 | + !xetile.tile<32x256xbf16, #tile_attr_b> |
| 107 | + ) { |
| 108 | + |
| 109 | + // all SGs must arrive here first |
| 110 | + // %every_8th_iter = arith.remui %k, %c256 : index |
| 111 | + // %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32 |
| 112 | + // %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32 |
| 113 | + // scf.if %every_8th_iter_cond { |
| 114 | + xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier |
| 115 | + // } |
| 116 | + |
| 117 | + |
| 118 | + // load A and B tiles |
| 119 | + %a_value = xetile.load_tile %a_tile : !xetile.tile<256x32xbf16, #tile_attr_a> |
| 120 | + -> vector<256x32xbf16> |
| 121 | + %b_value = xetile.load_tile %b_tile : !xetile.tile<32x256xbf16, #tile_attr_b> |
| 122 | + -> vector<32x256xbf16> |
| 123 | + |
| 124 | + xegpu.compile_hint |
| 125 | + |
| 126 | + // prefetch next A and B tiles |
| 127 | + xetile.prefetch_tile %prefetch_a_tile : !xetile.tile<256x32xbf16, #tile_attr_a> |
| 128 | + xetile.prefetch_tile %prefetch_b_tile : !xetile.tile<32x256xbf16, #tile_attr_b> |
| 129 | + |
| 130 | + xegpu.compile_hint |
| 131 | + |
| 132 | + // update prefetch tile offsets |
| 133 | + %15 = xetile.update_tile_offset %prefetch_a_tile, [%c0, %c32] : !xetile.tile<256x32xbf16, #tile_attr_a> |
| 134 | + %16 = xetile.update_tile_offset %prefetch_b_tile, [%c32, %c0] : !xetile.tile<32x256xbf16, #tile_attr_b> |
| 135 | + // update the offsets for A and B tiles |
| 136 | + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] |
| 137 | + : !xetile.tile<256x32xbf16, #tile_attr_a> |
| 138 | + %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] |
| 139 | + : !xetile.tile<32x256xbf16, #tile_attr_b> |
| 140 | + |
| 141 | + xegpu.compile_hint |
| 142 | + |
| 143 | + // perform dpas and accumulate |
| 144 | + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value {wg_map_a = #wg_map_a, wg_map_b = #wg_map_b, wg_map_c = #wg_map_c} |
| 145 | + : vector<256x32xbf16>, vector<32x256xbf16>, vector<256x256xf32> -> vector<256x256xf32> |
| 146 | + |
| 147 | + xegpu.compile_hint |
| 148 | + // barrier wait |
| 149 | + // scf.if %every_8th_iter_cond { |
| 150 | + xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier |
| 151 | + // } |
| 152 | + // partial C tile result |
| 153 | + scf.yield %a_next_tile, %b_next_tile, %c_new_value, %15, %16 |
| 154 | + : !xetile.tile<256x32xbf16, #tile_attr_a>, |
| 155 | + !xetile.tile<32x256xbf16, #tile_attr_b>, vector<256x256xf32>, |
| 156 | + !xetile.tile<256x32xbf16, #tile_attr_a>, |
| 157 | + !xetile.tile<32x256xbf16, #tile_attr_b> |
| 158 | + } |
| 159 | + // store the final accumulated C tile result back to memory |
| 160 | + %c_init_tile_1 = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> |
| 161 | + -> !xetile.tile<256x256xf32, #tile_attr_c> |
| 162 | + xetile.store_tile %out#2, %c_init_tile_1 : vector<256x256xf32>, |
| 163 | + !xetile.tile<256x256xf32, #tile_attr_c> |
| 164 | + xegpu.compile_hint |
| 165 | + gpu.return |
| 166 | + } |
| 167 | + } |
| 168 | + |
| 169 | + // compute CPU reference (takes minutes) |
| 170 | + func.func @cpu_reference(%A : memref<4096x4096xbf16>, %B : memref<4096x4096xbf16>, %C : memref<4096x4096xf32>) { |
| 171 | + %c4096 = arith.constant 4096 : index |
| 172 | + %c16 = arith.constant 16 : index |
| 173 | + %c1 = arith.constant 1 : index |
| 174 | + %c0 = arith.constant 0 : index |
| 175 | + scf.for %i = %c0 to %c4096 step %c1 { |
| 176 | + scf.for %j = %c0 to %c4096 step %c1 { |
| 177 | + %c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32> |
| 178 | + %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 { |
| 179 | + %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 { |
| 180 | + %k_dpas = arith.addi %k_tile, %k : index |
| 181 | + %a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xbf16> |
| 182 | + %b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xbf16> |
| 183 | + %a_cast = arith.extf %a_val : bf16 to f32 |
| 184 | + %b_cast = arith.extf %b_val : bf16 to f32 |
| 185 | + %t = arith.mulf %a_cast, %b_cast : f32 |
| 186 | + %c_sum = arith.addf %t, %c_dpas_partial : f32 |
| 187 | + scf.yield %c_sum : f32 |
| 188 | + } |
| 189 | + scf.yield %c_val_dpas : f32 |
| 190 | + } |
| 191 | + memref.store %c_val , %C[%i, %j] : memref<4096x4096xf32> |
| 192 | + } |
| 193 | + } |
| 194 | + return |
| 195 | + } |
| 196 | + |
| 197 | + func.func @main() attributes {llvm.emit_c_interface} { |
| 198 | + %c0 = arith.constant 0 : index |
| 199 | + %c1 = arith.constant 1 : index |
| 200 | + %c1_f16 = arith.constant 1.0 : bf16 |
| 201 | + %c2_f16 = arith.constant 2.0 : bf16 |
| 202 | + %c4096 = arith.constant 4096 : index |
| 203 | + %cf_0 = arith.constant 0.0 : bf16 |
| 204 | + %cf_1 = arith.constant 1.0 : bf16 |
| 205 | + %c_gen_int = arith.constant 0 : i1 |
| 206 | + %cf_lower = arith.constant 0.0 : f32 |
| 207 | + %cf_upper = arith.constant 1.0 : f32 |
| 208 | + |
| 209 | + %A = memref.alloc() : memref<4096x4096xbf16> |
| 210 | + %B = memref.alloc() : memref<4096x4096xbf16> |
| 211 | + %C = memref.alloc() : memref<4096x4096xf32> |
| 212 | + %C_ref = memref.alloc() : memref<4096x4096xf32> |
| 213 | + |
| 214 | + // convert the memref to 1D and fill with random values in (0.0, 1.0) |
| 215 | + %A_random = memref.cast %A : memref<4096x4096xbf16> to memref<*xbf16> |
| 216 | + call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () |
| 217 | + |
| 218 | + // convert the memref to 1D and fill with random values in (0.0, 1.0) |
| 219 | + %B_random = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16> |
| 220 | + call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () |
| 221 | + |
| 222 | + // intialize matrix C and C_ref ; C[i, j] = 0 |
| 223 | + %c0_f16 = arith.constant 0.0 : bf16 |
| 224 | + %c0_f32 = arith.constant 0.0 : f32 |
| 225 | + scf.for %i = %c0 to %c4096 step %c1 { |
| 226 | + scf.for %j = %c0 to %c4096 step %c1 { |
| 227 | + memref.store %c0_f32, %C[%i, %j] : memref<4096x4096xf32> |
| 228 | + memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> |
| 229 | + } |
| 230 | + } |
| 231 | + |
| 232 | + // run GPU |
| 233 | + %2 = call @test(%A, %B, %C) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> |
| 234 | + |
| 235 | + // run CPU |
| 236 | + call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> () |
| 237 | + |
| 238 | + %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> |
| 239 | + %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> |
| 240 | + // CHECK: [ALLCLOSE: TRUE] |
| 241 | + call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () |
| 242 | + memref.dealloc %A : memref<4096x4096xbf16> |
| 243 | + memref.dealloc %B : memref<4096x4096xbf16> |
| 244 | + memref.dealloc %C : memref<4096x4096xf32> |
| 245 | + memref.dealloc %C_ref : memref<4096x4096xf32> |
| 246 | + return |
| 247 | + } |
| 248 | + func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} |
| 249 | + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} |
| 250 | + func.func private @printAllcloseBF16(memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} |
| 251 | + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} |
| 252 | + func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface} |
| 253 | +} |
0 commit comments