@@ -1376,4 +1376,86 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
1376
1376
return err ;
1377
1377
}
1378
1378
1379
+ int ompi_coll_base_allreduce_intra_k_bruck (const void * sbuf , void * rbuf , int count ,
1380
+ struct ompi_datatype_t * dtype ,
1381
+ struct ompi_op_t * op ,
1382
+ struct ompi_communicator_t * comm ,
1383
+ mca_coll_base_module_t * module )
1384
+ {
1385
+ int line = -1 ;
1386
+ char * partial_buf = NULL ;
1387
+ char * partial_buf_start = NULL ;
1388
+ char * sendtmpbuf = NULL ;
1389
+ char * buffer1 = NULL ;
1390
+ char * buffer1_start = NULL ;
1391
+ int err = OMPI_SUCCESS ;
1392
+
1393
+ ptrdiff_t extent , lb ;
1394
+ ompi_datatype_get_extent (dtype , & lb , & extent );
1395
+
1396
+ int rank = ompi_comm_rank (comm );
1397
+ int size = ompi_comm_size (comm );
1398
+
1399
+ sendtmpbuf = (char * ) sbuf ;
1400
+ if ( sbuf == MPI_IN_PLACE ) {
1401
+ sendtmpbuf = (char * )rbuf ;
1402
+ }
1403
+ ptrdiff_t buf_size , gap = 0 ;
1404
+ buf_size = opal_datatype_span (& dtype -> super , (int64_t )count * size , & gap );
1405
+ partial_buf = (char * ) malloc (buf_size );
1406
+ partial_buf_start = partial_buf - gap ;
1407
+ buf_size = opal_datatype_span (& dtype -> super , (int64_t )count , & gap );
1408
+ buffer1 = (char * ) malloc (buf_size );
1409
+ buffer1_start = buffer1 - gap ;
1410
+
1411
+ err = ompi_datatype_copy_content_same_ddt (dtype , count ,
1412
+ (char * )buffer1_start ,
1413
+ (char * )sendtmpbuf );
1414
+ if (MPI_SUCCESS != err ) { line = __LINE__ ; goto err_hndl ; }
1415
+
1416
+ // apply allgather data so that each rank has a full copy to do reduce (trade bandwidth for better latency)
1417
+ err = comm -> c_coll -> coll_allgather (buffer1_start , count , dtype ,
1418
+ partial_buf_start , count , dtype ,
1419
+ comm , comm -> c_coll -> coll_allgather_module );
1420
+ if (MPI_SUCCESS != err ) { line = __LINE__ ; goto err_hndl ; }
1421
+
1422
+ for (int target = 1 ; target < size ; target ++ )
1423
+ {
1424
+ ompi_op_reduce (op ,
1425
+ partial_buf_start + (ptrdiff_t )target * count * extent ,
1426
+ partial_buf_start ,
1427
+ count ,
1428
+ dtype );
1429
+ }
1430
+
1431
+ // move data to rbuf
1432
+ err = ompi_datatype_copy_content_same_ddt (dtype , count ,
1433
+ (char * )rbuf ,
1434
+ (char * )partial_buf_start );
1435
+ if (MPI_SUCCESS != err ) { line = __LINE__ ; goto err_hndl ; }
1436
+
1437
+ if (NULL != buffer1 ) {
1438
+ free (buffer1 );
1439
+ buffer1 = NULL ;
1440
+ buffer1_start = NULL ;
1441
+ }
1442
+ return OMPI_SUCCESS ;
1443
+
1444
+ err_hndl :
1445
+ if (NULL != partial_buf ) {
1446
+ free (partial_buf );
1447
+ partial_buf = NULL ;
1448
+ partial_buf_start = NULL ;
1449
+ }
1450
+ if (NULL != buffer1 ) {
1451
+ free (buffer1 );
1452
+ buffer1 = NULL ;
1453
+ buffer1_start = NULL ;
1454
+ }
1455
+ OPAL_OUTPUT ((ompi_coll_base_framework .framework_output , "%s:%4d\tError occurred %d, rank %2d" ,
1456
+ __FILE__ , line , err , rank ));
1457
+ (void )line ; // silence compiler warning
1458
+ return err ;
1459
+
1460
+ }
1379
1461
/* copied function (with appropriate renaming) ends here */
0 commit comments