
Commit 1cf86a8

Updates to XeGPU f16 GEMM 4kx4kx4k performance test case
This PR introduces the following optimizations:
1. Larger loads for A and B
2. f16 stores to C instead of f32 stores
3. Periodic barrier syncing instead of syncing every K iteration
4. Avoiding signed div/rem ops
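Items 3 and 4 can be illustrated with a minimal MLIR sketch. It is not taken from this diff: the surrounding gpu.func, the K step of 32, and the sync period of 8 are assumptions. The idea is to guard gpu.barrier with an unsigned-remainder check so the workgroup synchronizes only every few K iterations, using arith.divui/arith.remui instead of their signed counterparts.

    // Hypothetical K loop inside the GEMM kernel's gpu.func body (illustration only).
    %c0    = arith.constant 0 : index
    %c8    = arith.constant 8 : index
    %c32   = arith.constant 32 : index
    %c4096 = arith.constant 4096 : index
    scf.for %k = %c0 to %c4096 step %c32 {
      // ... A/B tile loads and dpas accumulation for this K tile ...
      %iter = arith.divui %k, %c32 : index    // unsigned div instead of arith.divsi
      %rem  = arith.remui %iter, %c8 : index  // unsigned rem instead of arith.remsi
      %sync = arith.cmpi eq, %rem, %c0 : index
      scf.if %sync {
        gpu.barrier   // sync the workgroup every 8th K iteration instead of every iteration
      }
    }

The actual kernel in this commit organizes its loop differently; the sketch only shows the control-flow pattern the commit message refers to.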
1 parent 180dc90 commit 1cf86a8

7 files changed: +652 −25 lines

lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp

Lines changed: 3 additions & 0 deletions
@@ -59,6 +59,9 @@ encodeVectorType(ConversionPatternRewriter &rewriter, VectorType type,
   case 128:
     str += "v128";
     break;
+  case 256:
+    str += "v256";
+    break;
   default:
     assert(0 && "add more support");
     break;

test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_f16_f16_f32.mlir renamed to test/Integration/Dialect/XeGPU/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,38 @@ module @gemm attributes {gpu.container_module} {
383383
gpu.return
384384
}
385385
}
386+
387+
// compute CPU reference (takes minutes)
388+
func.func @cpu_reference(%A : memref<4096x4096xf16>, %B : memref<4096x4096xf16>, %C : memref<4096x4096xf32>) {
389+
%c4096 = arith.constant 4096 : index
390+
%c16 = arith.constant 16 : index
391+
%c1 = arith.constant 1 : index
392+
%c0 = arith.constant 0 : index
393+
scf.for %i = %c0 to %c4096 step %c1 {
394+
scf.for %j = %c0 to %c4096 step %c1 {
395+
%c_curr = memref.load %C[%i, %j] : memref<4096x4096xf32>
396+
%c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
397+
%c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
398+
%k_dpas = arith.addi %k_tile, %k : index
399+
%a_val = memref.load %A[%i, %k_dpas] : memref<4096x4096xf16>
400+
%b_val = memref.load %B[%k_dpas, %j] : memref<4096x4096xf16>
401+
%a_cast = arith.extf %a_val : f16 to f32
402+
%b_cast = arith.extf %b_val : f16 to f32
403+
%t = arith.mulf %a_cast, %b_cast : f32
404+
// %t_cast = arith.extf %t : f16 to f16
405+
%c_sum = arith.addf %t, %c_dpas_partial : f32
406+
scf.yield %c_sum : f32
407+
}
408+
scf.yield %c_val_dpas : f32
409+
}
410+
// %c_val_f16 = arith.truncf %c_val : f32 to f16
411+
// %c_val_ = arith.extf %c_val_f16 : f16 to f32
412+
memref.store %c_val , %C[%i, %j] : memref<4096x4096xf32>
413+
}
414+
}
415+
return
416+
}
417+
386418
func.func @main() attributes {llvm.emit_c_interface} {
387419
%c0 = arith.constant 0 : index
388420
%c1 = arith.constant 1 : index
@@ -448,23 +480,8 @@ module @gemm attributes {gpu.container_module} {
448480
// run GPU
449481
%2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32>
450482

451-
// compute CPU reference (takes minutes)
452-
scf.for %i = %c0 to %c4096 step %c1 {
453-
scf.for %j = %c0 to %c4096 step %c1 {
454-
%c_curr = memref.load %C_ref[%i, %j] : memref<4096x4096xf32>
455-
%c_val = scf.for %k = %c0 to %c4096 step %c1 iter_args(%c_partial = %c_curr) -> f32 {
456-
%a_val = memref.load %A[%i, %k] : memref<4096x4096xf16>
457-
%b_val = memref.load %B[%k, %j] : memref<4096x4096xf16>
458-
%a_cast = arith.extf %a_val : f16 to f32
459-
%b_cast = arith.extf %b_val : f16 to f32
460-
%t = arith.mulf %a_cast, %b_cast : f32
461-
// %t_cast = arith.extf %t : f16 to f32
462-
%c_sum = arith.addf %t, %c_partial : f32
463-
scf.yield %c_sum : f32
464-
}
465-
memref.store %c_val , %C_ref[%i, %j] : memref<4096x4096xf32>
466-
}
467-
}
483+
// run CPU
484+
call @cpu_reference(%A, %B, %C_ref) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> ()
468485

469486
// %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16>
470487
// call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+Benchmark name : gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32
+Platform : Intel(R) Data Center GPU Max 1550
+Requirements : doubleGRF
+
+Kernel test_kernel : 250 registers
+the kernel execution time is (ms, on L0 runtime):avg: 0.7909, min: 0.5862, max: 2.3459 (over 1000 runs)
+TFlops : avg:173.775, min:58.587, max:234.457
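As a sanity check, and assuming the harness derives TFlops from the standard 2·M·N·K FLOP count (an assumption; the diff does not state the formula): 2 × 4096³ ≈ 137.44 GFLOP, and 137.44 GFLOP / 0.7909 ms ≈ 173.8 TFlops, matching the reported average. The min and max TFlops likewise correspond to the max and min kernel times: 137.44 GFLOP / 2.3459 ms ≈ 58.6 TFlops and 137.44 GFLOP / 0.5862 ms ≈ 234.5 TFlops.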
