Skip to content

Commit 51298c7

Browse files
committed
mca/coll: add allreduce method based on allgather (bruck)
Use allgather to implement allreduce: every rank gathers all contributions and performs the reduction locally. This reduces latency compared to gathering and reducing on the root followed by a broadcast to the other ranks, at the cost of additional memory usage and message exchanges. Signed-off-by: Jun Tang <[email protected]>
1 parent aca2938 commit 51298c7

File tree

3 files changed

+87
-1
lines changed

3 files changed

+87
-1
lines changed

ompi/mca/coll/base/coll_base_allreduce.c

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,4 +1376,86 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
13761376
return err;
13771377
}
13781378

1379+
int ompi_coll_base_allreduce_intra_k_bruck(const void *sbuf, void *rbuf, int count,
1380+
struct ompi_datatype_t *dtype,
1381+
struct ompi_op_t *op,
1382+
struct ompi_communicator_t *comm,
1383+
mca_coll_base_module_t *module)
1384+
{
1385+
int line = -1;
1386+
char *partial_buf = NULL;
1387+
char *partial_buf_start = NULL;
1388+
char *sendtmpbuf = NULL;
1389+
char *buffer1 = NULL;
1390+
char *buffer1_start = NULL;
1391+
int err = OMPI_SUCCESS;
1392+
1393+
ptrdiff_t extent, lb;
1394+
ompi_datatype_get_extent(dtype, &lb, &extent);
1395+
1396+
int rank = ompi_comm_rank(comm);
1397+
int size = ompi_comm_size(comm);
1398+
1399+
sendtmpbuf = (char*) sbuf;
1400+
if( sbuf == MPI_IN_PLACE ) {
1401+
sendtmpbuf = (char *)rbuf;
1402+
}
1403+
ptrdiff_t buf_size, gap = 0;
1404+
buf_size = opal_datatype_span(&dtype->super, (int64_t)count * size, &gap);
1405+
partial_buf = (char *) malloc(buf_size);
1406+
partial_buf_start = partial_buf - gap;
1407+
buf_size = opal_datatype_span(&dtype->super, (int64_t)count, &gap);
1408+
buffer1 = (char *) malloc(buf_size);
1409+
buffer1_start = buffer1 - gap;
1410+
1411+
err = ompi_datatype_copy_content_same_ddt(dtype, count,
1412+
(char*)buffer1_start,
1413+
(char*)sendtmpbuf);
1414+
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
1415+
1416+
// apply allgather data so that each rank has a full copy to do reduce (trade bandwidth for better latency)
1417+
err = comm->c_coll->coll_allgather(buffer1_start, count, dtype,
1418+
partial_buf_start, count, dtype,
1419+
comm, comm->c_coll->coll_allgather_module);
1420+
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
1421+
1422+
for(int target = 1; target < size; target++)
1423+
{
1424+
ompi_op_reduce(op,
1425+
partial_buf_start + (ptrdiff_t)target * count * extent,
1426+
partial_buf_start,
1427+
count,
1428+
dtype);
1429+
}
1430+
1431+
// move data to rbuf
1432+
err = ompi_datatype_copy_content_same_ddt(dtype, count,
1433+
(char*)rbuf,
1434+
(char*)partial_buf_start);
1435+
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
1436+
1437+
if (NULL != buffer1) {
1438+
free(buffer1);
1439+
buffer1 = NULL;
1440+
buffer1_start = NULL;
1441+
}
1442+
return OMPI_SUCCESS;
1443+
1444+
err_hndl:
1445+
if (NULL != partial_buf) {
1446+
free(partial_buf);
1447+
partial_buf = NULL;
1448+
partial_buf_start = NULL;
1449+
}
1450+
if (NULL != buffer1) {
1451+
free(buffer1);
1452+
buffer1 = NULL;
1453+
buffer1_start = NULL;
1454+
}
1455+
OPAL_OUTPUT((ompi_coll_base_framework.framework_output, "%s:%4d\tError occurred %d, rank %2d",
1456+
__FILE__, line, err, rank));
1457+
(void)line; // silence compiler warning
1458+
return err;
1459+
1460+
}
13791461
/* copied function (with appropriate renaming) ends here */

ompi/mca/coll/base/coll_base_functions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ int ompi_coll_base_allreduce_intra_ring_segmented(ALLREDUCE_ARGS, uint32_t segsi
212212
int ompi_coll_base_allreduce_intra_basic_linear(ALLREDUCE_ARGS);
213213
int ompi_coll_base_allreduce_intra_redscat_allgather(ALLREDUCE_ARGS);
214214
int ompi_coll_base_allreduce_intra_allgather_reduce(ALLREDUCE_ARGS);
215+
int ompi_coll_base_allreduce_intra_k_bruck(ALLREDUCE_ARGS);
215216

216217
/* AlltoAll */
217218
int ompi_coll_base_alltoall_intra_pairwise(ALLTOALL_ARGS);

ompi/mca/coll/tuned/coll_tuned_allreduce_decision.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ static const mca_base_var_enum_value_t allreduce_algorithms[] = {
4343
{5, "segmented_ring"},
4444
{6, "rabenseifner"},
4545
{7, "allgather_reduce"},
46+
{8, "allreduce_bruck"},
4647
{0, NULL}
4748
};
4849

@@ -78,7 +79,7 @@ int ompi_coll_tuned_allreduce_intra_check_forced_init (coll_tuned_force_algorith
7879
mca_param_indices->algorithm_param_index =
7980
mca_base_component_var_register(&mca_coll_tuned_component.super.collm_version,
8081
"allreduce_algorithm",
81-
"Which allreduce algorithm is used. Can be locked down to any of: 0 ignore, 1 basic linear, 2 nonoverlapping (tuned reduce + tuned bcast), 3 recursive doubling, 4 ring, 5 segmented ring. "
82+
"Which allreduce algorithm is used. Can be locked down to any of: 0 ignore, 1 basic linear, 2 nonoverlapping (tuned reduce + tuned bcast), 3 recursive doubling, 4 ring, 5 segmented ring, 6 rabenseifner, 7 allgather_reduce, 8 allreduce_bruck. "
8283
"Only relevant if coll_tuned_use_dynamic_rules is true.",
8384
MCA_BASE_VAR_TYPE_INT, new_enum, 0, MCA_BASE_VAR_FLAG_SETTABLE,
8485
OPAL_INFO_LVL_5,
@@ -149,6 +150,8 @@ int ompi_coll_tuned_allreduce_intra_do_this(const void *sbuf, void *rbuf, int co
149150
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm, module);
150151
case (7):
151152
return ompi_coll_base_allreduce_intra_allgather_reduce(sbuf, rbuf, count, dtype, op, comm, module);
153+
case (8):
154+
return ompi_coll_base_allreduce_intra_k_bruck(sbuf, rbuf, count, dtype, op, comm, module);
152155
} /* switch */
153156
OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned:allreduce_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
154157
algorithm, ompi_coll_tuned_forced_max_algorithms[ALLREDUCE]));

0 commit comments

Comments
 (0)