54
54
#if !defined(ELEM_TYPE )
55
55
# define ELEM_TYPE double
56
56
#endif
57
- #if !defined(MAX_KERNEL_DIM )
58
- # define MAX_KERNEL_DIM 80
59
- #endif
60
57
#if !defined(ALIGNMENT )
61
58
# define ALIGNMENT 64
62
59
#endif
@@ -291,6 +288,7 @@ int main(int argc, char* argv[]) {
291
288
#else
292
289
const int mn = m * n , mk = m * k , kn = k * n ;
293
290
#endif
291
+ const int max_kernel_dim = ceil (sqrt (m * n ));
294
292
int * stack_hst = NULL , * stack_dev = NULL , * trans_hst = NULL , * trans_dev = NULL ;
295
293
ELEM_TYPE * amat_hst = NULL , * bmat_hst = NULL , * cmat_hst = NULL ;
296
294
ELEM_TYPE * amat_dev = NULL , * bmat_dev = NULL , * cmat_dev = NULL ;
@@ -353,7 +351,7 @@ int main(int argc, char* argv[]) {
353
351
PRINTF (
354
352
"%s%s%i %i %i %i %i %i %i %i\n" , 0 < argc ? argv [0 ] : "" , 0 < argc ? " " : "" , nrepeat , stack_size , m , n , k , nc , na , nb );
355
353
PRINTF ("typename (id=%i): %s\n" , DBCSR_TYPE (ELEM_TYPE ), DBCSR_STRINGIFY (ELEM_TYPE ));
356
- if (MAX_KERNEL_DIM < m || MAX_KERNEL_DIM < n || MAX_KERNEL_DIM < k ) {
354
+ if (MAX_KERNEL_DIM < max_kernel_dim ) {
357
355
fprintf (stderr , "ERROR: Matrix shape exceeds MAX_KERNEL_DIM!\n" );
358
356
result = EXIT_FAILURE ;
359
357
}
@@ -364,7 +362,7 @@ int main(int argc, char* argv[]) {
364
362
CHECK (c_dbcsr_acc_host_mem_allocate ((void * * )(void * )& stack_hst , sizeof (int ) * 3 * stack_size , stream ), & result , check );
365
363
CHECK (c_dbcsr_acc_host_mem_allocate ((void * * )(void * )& trans_hst , sizeof (int ) * nb , stream ), & result , check );
366
364
CHECK (c_dbcsr_acc_stream_sync (stream ), & result , check ); /* ensure host-data is allocated */
367
- if (NULL != amat_hst && NULL != bmat_hst && NULL != trans_hst && NULL != stack_hst ) {
365
+ if (NULL != amat_hst && NULL != bmat_hst && NULL != trans_hst && NULL != stack_hst && EXIT_SUCCESS == result ) {
368
366
init_stack (stack_hst , stack_size , NRAND , rnd , mn , mk , kn , nc , na , nb );
369
367
#if defined(_OPENMP )
370
368
# pragma omp parallel
@@ -404,7 +402,7 @@ int main(int argc, char* argv[]) {
404
402
}
405
403
#if defined(USE_LIBXSMM )
406
404
CHECK (c_dbcsr_acc_stream_sync (stream ), & result , check );
407
- if (NULL != amat_hst && NULL != bmat_hst && NULL != stack_hst ) {
405
+ if (NULL != amat_hst && NULL != bmat_hst && NULL != stack_hst && EXIT_SUCCESS == result ) {
408
406
const size_t size = (sizeof (ELEM_TYPE ) * (mk * na + kn * nb ) + sizeof (int ) * 3 * stack_size ) * nrepeat_h2d ;
409
407
duration = libxsmm_timer_duration (start , libxsmm_timer_tick ());
410
408
perf_h2d = size / (duration * (1ULL << 30 ));
@@ -414,17 +412,17 @@ int main(int argc, char* argv[]) {
414
412
#if defined(TRANSPOSE ) && defined(VALIDATE )
415
413
/* warmup execution and prebuild transpose-kernel */
416
414
for (r = 0 ; r < warmup / 2 ; ++ r ) {
417
- CHECK (libsmm_acc_transpose (trans_dev , 0 /*offset*/ , nb , bmat_dev , DBCSR_TYPE (ELEM_TYPE ), k , n , MAX_KERNEL_DIM , stream ),
415
+ CHECK (libsmm_acc_transpose (trans_dev , 0 /*offset*/ , nb , bmat_dev , DBCSR_TYPE (ELEM_TYPE ), k , n , max_kernel_dim , stream ),
418
416
& result , check );
419
- CHECK (libsmm_acc_transpose (trans_dev , 0 /*offset*/ , nb , bmat_dev , DBCSR_TYPE (ELEM_TYPE ), n , k , MAX_KERNEL_DIM , stream ),
417
+ CHECK (libsmm_acc_transpose (trans_dev , 0 /*offset*/ , nb , bmat_dev , DBCSR_TYPE (ELEM_TYPE ), n , k , max_kernel_dim , stream ),
420
418
& result , check );
421
419
}
422
420
# if defined(USE_LIBXSMM )
423
421
CHECK (c_dbcsr_acc_stream_sync (stream ), & result , check );
424
422
start = libxsmm_timer_tick ();
425
423
# endif
426
424
/* to perform NN-SMMs on the device, all B-matrices are transposed upfront (SMM-kernel is limited to NT) */
427
- CHECK (libsmm_acc_transpose (trans_dev , 0 /*offset*/ , nb , bmat_dev , DBCSR_TYPE (ELEM_TYPE ), k , n , MAX_KERNEL_DIM , stream ),
425
+ CHECK (libsmm_acc_transpose (trans_dev , 0 /*offset*/ , nb , bmat_dev , DBCSR_TYPE (ELEM_TYPE ), k , n , max_kernel_dim , stream ),
428
426
& result , check );
429
427
# if defined(USE_LIBXSMM )
430
428
CHECK (c_dbcsr_acc_stream_sync (stream ), & result , check );
@@ -434,7 +432,7 @@ int main(int argc, char* argv[]) {
434
432
/* warmup execution and prebuild SMM-kernel */
435
433
for (r = 0 ; r < warmup ; ++ r ) {
436
434
CHECK (libsmm_acc_process (stack_hst , stack_dev , stack_size , DBCSR_TYPE (ELEM_TYPE ), amat_dev , bmat_dev , cmat_dev , m , n , k ,
437
- MAX_KERNEL_DIM , 1 /*homogeneous*/ , stream , stream ),
435
+ max_kernel_dim , 1 /*homogeneous*/ , stream , stream ),
438
436
& result , check );
439
437
}
440
438
CHECK (c_dbcsr_acc_memset_zero (cmat_dev , 0 /*offset*/ , sizeof (ELEM_TYPE ) * mn * nc , stream ), & result , check );
@@ -445,28 +443,30 @@ int main(int argc, char* argv[]) {
445
443
for (r = 0 ; r < nrepeat ; ++ r ) {
446
444
/* GPU-kernel is limited to C += Ai * Bi^T, i.e., NT (for NN, all Bi must be transposed upfront) */
447
445
CHECK (libsmm_acc_process (stack_hst , stack_dev , stack_size , DBCSR_TYPE (ELEM_TYPE ), amat_dev , bmat_dev , cmat_dev , m , n , k ,
448
- MAX_KERNEL_DIM , 1 /*homogeneous*/ , stream , stream ),
446
+ max_kernel_dim , 1 /*homogeneous*/ , stream , stream ),
449
447
& result , check );
450
448
}
451
449
#if defined(USE_LIBXSMM )
452
450
CHECK (c_dbcsr_acc_stream_sync (stream ), & result , check );
453
451
duration = libxsmm_timer_duration (start , libxsmm_timer_tick ());
454
- if (0 < duration && EXIT_SUCCESS == result ) {
452
+ if (EXIT_SUCCESS == result ) {
453
+ if (0 < duration ) {
455
454
# if defined(TRANSPOSE ) && defined(VALIDATE )
456
- PRINTF ("transpose: %.2g ms %.1f GFLOPS/s\n" , 1000.0 * (duration + transpose ) / (nrepeat * nrepeat_smm ),
457
- 1E-9 * ((size_t )2 * m * n * k * stack_size * nrepeat * nrepeat_smm ) / (duration + transpose ));
455
+ PRINTF ("transpose: %.2g ms %.1f GFLOPS/s\n" , 1000.0 * (duration + transpose ) / (nrepeat * nrepeat_smm ),
456
+ 1E-9 * ((size_t )2 * m * n * k * stack_size * nrepeat * nrepeat_smm ) / (duration + transpose ));
458
457
# endif
459
- perf_dev = 1E-9 * ((size_t )2 * m * n * k * stack_size * nrepeat * nrepeat_smm ) / duration ;
460
- PRINTF ("device: %.2g ms %.1f GFLOPS/s\n" , 1000.0 * duration / (nrepeat * nrepeat_smm ), perf_dev );
461
- }
462
- else {
458
+ perf_dev = 1E-9 * ((size_t )2 * m * n * k * stack_size * nrepeat * nrepeat_smm ) / duration ;
459
+ PRINTF ("device: %.2g ms %.1f GFLOPS/s\n" , 1000.0 * duration / (nrepeat * nrepeat_smm ), perf_dev );
460
+ }
461
+ else {
463
462
# if defined(TRANSPOSE )
464
- PRINTF ("transpose: 0 ms 0 GFLOPS/s\n" );
463
+ PRINTF ("transpose: 0 ms 0 GFLOPS/s\n" );
465
464
# endif
466
- PRINTF ("device: 0 ms 0 GFLOPS/s\n" );
465
+ PRINTF ("device: 0 ms 0 GFLOPS/s\n" );
466
+ }
467
467
}
468
468
# if defined(VALIDATE )
469
- {
469
+ if ( EXIT_SUCCESS == result ) {
470
470
ELEM_TYPE * const gold_hst = (ELEM_TYPE * )(0 != check ? libxsmm_malloc (sizeof (ELEM_TYPE ) * mn * nc ) : NULL );
471
471
/* determine host's performance independent of current result code/status */
472
472
if (NULL != gold_hst && NULL != amat_hst && NULL != bmat_hst && NULL != stack_hst ) {
0 commit comments