@@ -383,6 +383,38 @@ module @gemm attributes {gpu.container_module} {
383
383
gpu.return
384
384
}
385
385
}
386
+
387
+ // compute CPU reference (takes minutes)
388
+ func.func @cpu_reference (%A : memref <4096 x4096 xf16 >, %B : memref <4096 x4096 xf16 >, %C : memref <4096 x4096 xf32 >) {
389
+ %c4096 = arith.constant 4096 : index
390
+ %c16 = arith.constant 16 : index
391
+ %c1 = arith.constant 1 : index
392
+ %c0 = arith.constant 0 : index
393
+ scf.for %i = %c0 to %c4096 step %c1 {
394
+ scf.for %j = %c0 to %c4096 step %c1 {
395
+ %c_curr = memref.load %C [%i , %j ] : memref <4096 x4096 xf32 >
396
+ %c_val = scf.for %k_tile = %c0 to %c4096 step %c16 iter_args (%c_partial = %c_curr ) -> f32 {
397
+ %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args (%c_dpas_partial = %c_partial ) -> f32 {
398
+ %k_dpas = arith.addi %k_tile , %k : index
399
+ %a_val = memref.load %A [%i , %k_dpas ] : memref <4096 x4096 xf16 >
400
+ %b_val = memref.load %B [%k_dpas , %j ] : memref <4096 x4096 xf16 >
401
+ %a_cast = arith.extf %a_val : f16 to f32
402
+ %b_cast = arith.extf %b_val : f16 to f32
403
+ %t = arith.mulf %a_cast , %b_cast : f32
404
+ // %t_cast = arith.extf %t : f16 to f16
405
+ %c_sum = arith.addf %t , %c_dpas_partial : f32
406
+ scf.yield %c_sum : f32
407
+ }
408
+ scf.yield %c_val_dpas : f32
409
+ }
410
+ // %c_val_f16 = arith.truncf %c_val : f32 to f16
411
+ // %c_val_ = arith.extf %c_val_f16 : f16 to f32
412
+ memref.store %c_val , %C [%i , %j ] : memref <4096 x4096 xf32 >
413
+ }
414
+ }
415
+ return
416
+ }
417
+
386
418
func.func @main () attributes {llvm.emit_c_interface } {
387
419
%c0 = arith.constant 0 : index
388
420
%c1 = arith.constant 1 : index
@@ -448,23 +480,8 @@ module @gemm attributes {gpu.container_module} {
448
480
// run GPU
449
481
%2 = call @test (%A , %B , %C ) : (memref <4096 x4096 xf16 >, memref <4096 x4096 xf16 >, memref <4096 x4096 xf32 >) -> memref <4096 x4096 xf32 >
450
482
451
- // compute CPU reference (takes minutes)
452
- scf.for %i = %c0 to %c4096 step %c1 {
453
- scf.for %j = %c0 to %c4096 step %c1 {
454
- %c_curr = memref.load %C_ref [%i , %j ] : memref <4096 x4096 xf32 >
455
- %c_val = scf.for %k = %c0 to %c4096 step %c1 iter_args (%c_partial = %c_curr ) -> f32 {
456
- %a_val = memref.load %A [%i , %k ] : memref <4096 x4096 xf16 >
457
- %b_val = memref.load %B [%k , %j ] : memref <4096 x4096 xf16 >
458
- %a_cast = arith.extf %a_val : f16 to f32
459
- %b_cast = arith.extf %b_val : f16 to f32
460
- %t = arith.mulf %a_cast , %b_cast : f32
461
- // %t_cast = arith.extf %t : f16 to f32
462
- %c_sum = arith.addf %t , %c_partial : f32
463
- scf.yield %c_sum : f32
464
- }
465
- memref.store %c_val , %C_ref [%i , %j ] : memref <4096 x4096 xf32 >
466
- }
467
- }
483
+ // run CPU
484
+ call @cpu_reference (%A , %B , %C_ref ) : (memref <4096 x4096 xf16 >, memref <4096 x4096 xf16 >, memref <4096 x4096 xf32 >) -> ()
468
485
469
486
// %cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16>
470
487
// call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
0 commit comments