Skip to content

Commit

Permalink
add bruck's method to allgather and allreduce; add knomial to allreduce
Browse files Browse the repository at this point in the history
  • Loading branch information
juntangc committed Feb 11, 2024
1 parent 7fc4535 commit 7f8a97b
Show file tree
Hide file tree
Showing 7 changed files with 445 additions and 3 deletions.
210 changes: 210 additions & 0 deletions ompi/mca/coll/base/coll_base_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -866,4 +866,214 @@ ompi_coll_base_allgather_intra_basic_linear(const void *sbuf, int scount,
return err;
}

/*
* ompi_coll_base_allgather_intra_k_bruck
*
* Function: allgather using O(logk(N)) steps.
* Accepts: Same arguments as MPI_Allgather
* Returns: MPI_SUCCESS or error code
*
 * Description: This method extends ompi_coll_base_allgather_intra_bruck to handle any
 *              radix k; it uses non-blocking communication to take advantage of multiple ports.
* The algorithm detail is described in Bruck et al. (1997),
* "Efficient Algorithms for All-to-all Communications
* in Multiport Message-Passing Systems"
* Memory requirements: non-zero ranks require temporary buffer to perform final
* step in the algorithm.
*
* Example on 10 nodes with k=3:
* Initialization: everyone has its own buffer at location 0 in rbuf
* This means if user specified MPI_IN_PLACE for sendbuf
* we must copy our block from recvbuf to beginning!
* # 0 1 2 3 4 5 6 7 8 9
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* Step 0: send message to (rank - k^0 * i), receive message from (rank + k^0 * i)
* message size is k^0 * block size and i is between [1, k-1]
* # 0 1 2 3 4 5 6 7 8 9
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [1] [2] [3] [4] [5] [6] [7] [8] [9] [0]
* [2] [3] [4] [5] [6] [7] [8] [9] [0] [1]
* Step 1: send message to (rank - k^1 * i), receive message from (rank + k^1 * i)
* message size is k^1 * block size
* # 0 1 2 3 4 5 6 7 8 9
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [1] [2] [3] [4] [5] [6] [7] [8] [9] [0]
* [2] [3] [4] [5] [6] [7] [8] [9] [0] [1]
* [3] [4] [5] [6] [7] [8] [9] [0] [1] [2]
* [4] [5] [6] [7] [8] [9] [0] [1] [2] [3]
* [5] [6] [7] [8] [9] [0] [1] [2] [3] [4]
* [6] [7] [8] [9] [0] [1] [2] [3] [4] [5]
* [7] [8] [9] [0] [1] [2] [3] [4] [5] [6]
* [8] [9] [0] [1] [2] [3] [4] [5] [6] [7]
* Step 2: send message to (rank - k^2 * i), receive message from (rank + k^2 * i)
* message size is k^2 * block size or "all remaining blocks" for each exchange
* # 0 1 2 3 4 5 6 7 8 9
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [1] [2] [3] [4] [5] [6] [7] [8] [9] [0]
* [2] [3] [4] [5] [6] [7] [8] [9] [0] [1]
* [3] [4] [5] [6] [7] [8] [9] [0] [1] [2]
* [4] [5] [6] [7] [8] [9] [0] [1] [2] [3]
* [5] [6] [7] [8] [9] [0] [1] [2] [3] [4]
* [6] [7] [8] [9] [0] [1] [2] [3] [4] [5]
* [7] [8] [9] [0] [1] [2] [3] [4] [5] [6]
* [8] [9] [0] [1] [2] [3] [4] [5] [6] [7]
* [9] [0] [1] [2] [3] [4] [5] [6] [7] [8]
* Finalization: Do a local shift (except rank 0) to get data in correct place
* # 0 1 2 3 4 5 6 7 8 9
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
* [0] [1] [2] [3] [4] [5] [6] [7] [8] [9]
*/
int ompi_coll_base_allgather_intra_k_bruck(const void *sbuf, int scount,
                                           struct ompi_datatype_t *sdtype,
                                           void* rbuf, int rcount,
                                           struct ompi_datatype_t *rdtype,
                                           struct ompi_communicator_t *comm,
                                           mca_coll_base_module_t *module)
{
    int line = -1, rank, size, dst, src, err = MPI_SUCCESS;
    int recvcount, distance;
    int k = 8;   /* communication radix; one send + one recv per port per phase */
    ptrdiff_t rlb, rextent;
    ptrdiff_t rsize, rgap = 0;
    ompi_request_t **reqs = NULL;
    int num_reqs, max_reqs = 0;

    char *tmpsend = NULL;
    char *tmprecv = NULL;
    char *tmp_buf = NULL;
    char *tmp_buf_start = NULL;

    rank = ompi_comm_rank(comm);
    size = ompi_comm_size(comm);

    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
                 "coll:base:allgather_intra_k_bruck radix %d rank %d", k, rank));
    err = ompi_datatype_get_extent (rdtype, &rlb, &rextent);
    if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }

    if (0 != rank) {
        /* Compute the temporary buffer size, including datatypes empty gaps.
         * Non-zero ranks need it to rotate the result into place at the end. */
        rsize = opal_datatype_span(&rdtype->super, (int64_t)rcount * (size - rank), &rgap);
        tmp_buf = (char *) malloc(rsize);
        if (NULL == tmp_buf) { err = OMPI_ERR_OUT_OF_RESOURCE; line = __LINE__; goto err_hndl; }
        tmp_buf_start = tmp_buf - rgap;
    }

    // tmprecv points to the data initially on this rank, handle mpi_in_place case
    tmprecv = (char*) rbuf;
    if (MPI_IN_PLACE != sbuf) {
        tmpsend = (char*) sbuf;
        err = ompi_datatype_sndrcv(tmpsend, scount, sdtype, tmprecv, rcount, rdtype);
        if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
    } else if (0 != rank) {
        // root data placement is at the correct position; other ranks must move
        // their in-place block from offset `rank` to the front of rbuf
        tmpsend = ((char*)rbuf) + (ptrdiff_t)rank * (ptrdiff_t)rcount * rextent;
        err = ompi_datatype_copy_content_same_ddt(rdtype, rcount, tmprecv, tmpsend);
        if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
    }
    /*
       Maximum number of communication phases logk(n)
       For each phase i, rank r:
       - increase the distance and recvcount by k times
       - sends (k - 1) messages which starts at beginning of rbuf and has size
       (recvcount) to rank (r - distance * j)
       - receives (k - 1) messages of size recvcount from rank (r + distance * j)
       at location (rbuf + distance * j * rcount * rext)
       - calculate the remaining data for each of the (k - 1) messages in the last
       phase to complete all transactions
    */
    max_reqs = 2 * (k - 1);
    reqs = ompi_coll_base_comm_get_reqs(module->base_data, max_reqs);
    if (NULL == reqs) { err = OMPI_ERR_OUT_OF_RESOURCE; line = __LINE__; goto err_hndl; }
    tmpsend = (char*) rbuf;
    for (distance = 1; distance < size; distance *= k) {
        num_reqs = 0;
        for (int j = 1; j < k; j++)
        {
            if (distance * j >= size) {
                break;
            }
            src = (rank + distance * j) % size;
            dst = (rank - distance * j + size) % size;

            tmprecv = tmpsend + (ptrdiff_t)distance * j * rcount * rextent;

            /* Last phase may carry a partial message: only the blocks that
             * remain, instead of the full k^i blocks */
            if (distance <= (size / k)) {
                recvcount = distance;
            } else {
                recvcount = (distance < (size - distance * j)?
                            distance:(size - distance * j));
            }

            err = MCA_PML_CALL(irecv(tmprecv,
                                     recvcount * rcount,
                                     rdtype,
                                     src,
                                     MCA_COLL_BASE_TAG_ALLGATHER,
                                     comm,
                                     &reqs[num_reqs++]));
            if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
            err = MCA_PML_CALL(isend(tmpsend,
                                     recvcount * rcount,
                                     rdtype,
                                     dst,
                                     MCA_COLL_BASE_TAG_ALLGATHER,
                                     MCA_PML_BASE_SEND_STANDARD,
                                     comm,
                                     &reqs[num_reqs++]));
            if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
        }
        err = ompi_request_wait_all(num_reqs, reqs, MPI_STATUSES_IGNORE);
        if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
    }

    // Finalization step: On all ranks except 0, data needs to be shifted locally:
    // rotate rbuf left by `rank` blocks through tmp_buf so block r ends at offset r
    if (0 != rank) {
        /* 1) stash the leading (size - rank) blocks */
        err = ompi_datatype_copy_content_same_ddt(rdtype,
                                                  ((ptrdiff_t)(size - rank) * rcount),
                                                  tmp_buf_start,
                                                  rbuf);
        if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }

        /* 2) move the trailing `rank` blocks to the front */
        tmpsend = (char*) rbuf + (ptrdiff_t)(size - rank) * rcount * rextent;
        err = ompi_datatype_copy_content_same_ddt(rdtype,
                                                  (ptrdiff_t)rank * rcount,
                                                  rbuf,
                                                  tmpsend);
        if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }

        /* 3) restore the stashed blocks after them */
        tmprecv = (char*) rbuf + (ptrdiff_t)rank * rcount * rextent;
        err = ompi_datatype_copy_content_same_ddt(rdtype,
                                                  (ptrdiff_t)(size - rank) * rcount,
                                                  tmprecv,
                                                  tmp_buf_start);
        if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
    }

    free(tmp_buf);

    return OMPI_SUCCESS;

 err_hndl:
    OPAL_OUTPUT((ompi_coll_base_framework.framework_output, "%s:%4d\tError occurred %d, rank %2d",
                 __FILE__, line, err, rank));
    /* Release any outstanding non-blocking requests before bailing out */
    if (NULL != reqs) {
        ompi_coll_base_free_reqs(reqs, max_reqs);
    }
    free(tmp_buf);
    (void)line;  // silence compiler warning
    return err;
}
/* copied function (with appropriate renaming) ends here */
70 changes: 70 additions & 0 deletions ompi/mca/coll/base/coll_base_allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -1376,4 +1376,74 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
return err;
}

/*
 * ompi_coll_base_allreduce_intra_k_bruck
 *
 * Function:    allreduce built as allgather + local reduction
 * Accepts:     Same arguments as MPI_Allreduce
 * Returns:     MPI_SUCCESS or error code
 *
 * Description: Each rank contributes its input to an allgather (which may be
 *              served by the k-ary Bruck allgather), so every rank ends up
 *              holding all `size` contributions; the reduction is then done
 *              locally over the gathered blocks and the result copied to rbuf.
 * Memory:      Requires a temporary buffer of size*count elements plus a
 *              count-element staging buffer on every rank.
 */
int ompi_coll_base_allreduce_intra_k_bruck(const void *sbuf, void *rbuf, int count,
                                           struct ompi_datatype_t *dtype,
                                           struct ompi_op_t *op,
                                           struct ompi_communicator_t *comm,
                                           mca_coll_base_module_t *module)
{
    int line = -1;
    char *partial_buf = NULL;
    char *partial_buf_start = NULL;
    char *sendtmpbuf = NULL;
    char *buffer1 = NULL;
    char *buffer1_start = NULL;
    int err = OMPI_SUCCESS;

    ptrdiff_t extent, lb;
    ompi_datatype_get_extent(dtype, &lb, &extent);

    int rank = ompi_comm_rank(comm);
    int size = ompi_comm_size(comm);

    /* With MPI_IN_PLACE the caller's input lives in rbuf */
    sendtmpbuf = (char*) sbuf;
    if( sbuf == MPI_IN_PLACE ) {
        sendtmpbuf = (char *)rbuf;
    }

    /* Allocate the gathered-data buffer (size * count elements) and a staging
     * copy of the local contribution, accounting for datatype gaps */
    ptrdiff_t buf_size, gap = 0;
    buf_size = opal_datatype_span(&dtype->super, (int64_t)count * size, &gap);
    partial_buf = (char *) malloc(buf_size);
    if (NULL == partial_buf) { err = OMPI_ERR_OUT_OF_RESOURCE; line = __LINE__; goto err_hndl; }
    partial_buf_start = partial_buf - gap;
    buf_size = opal_datatype_span(&dtype->super, (int64_t)count, &gap);
    buffer1 = (char *) malloc(buf_size);
    if (NULL == buffer1) { err = OMPI_ERR_OUT_OF_RESOURCE; line = __LINE__; goto err_hndl; }
    buffer1_start = buffer1 - gap;

    err = ompi_datatype_copy_content_same_ddt(dtype, count,
                                              (char*)buffer1_start,
                                              (char*)sendtmpbuf);
    if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }

    /* Gather every rank's contribution; rank r's block lands at offset r */
    err = comm->c_coll->coll_allgather(buffer1_start, count, dtype,
                                       partial_buf_start, count, dtype,
                                       comm, comm->c_coll->coll_allgather_module);
    if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }

    /* Reduce all gathered blocks locally into block 0.
     * NOTE(review): this accumulates in rank order (size-1)..0, i.e. the
     * reversed order — fine for commutative ops; confirm against the MPI
     * requirement on canonical ordering for non-commutative ops. */
    for(int target = 1; target < size; target++)
    {
        ompi_op_reduce(op,
                       partial_buf_start + (ptrdiff_t)target * count * extent,
                       partial_buf_start,
                       count,
                       dtype);
    }

    // move the reduced data to rbuf
    err = ompi_datatype_copy_content_same_ddt(dtype, count,
                                              (char*)rbuf,
                                              (char*)partial_buf_start);
    if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }

 err_hndl:
    /* Common cleanup path: reached both on success and on error */
    free(partial_buf);
    free(buffer1);   /* was leaked in the original implementation */
    if (MPI_SUCCESS != err) {
        /* Log only on an actual failure, not on the success fall-through */
        OPAL_OUTPUT((ompi_coll_base_framework.framework_output, "%s:%4d\tError occurred %d, rank %2d",
                     __FILE__, line, err, rank));
    }
    (void)line;  // silence compiler warning
    return err;
}
/* copied function (with appropriate renaming) ends here */
3 changes: 3 additions & 0 deletions ompi/mca/coll/base/coll_base_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ int ompi_coll_base_allgather_intra_ring(ALLGATHER_ARGS);
int ompi_coll_base_allgather_intra_neighborexchange(ALLGATHER_ARGS);
int ompi_coll_base_allgather_intra_basic_linear(ALLGATHER_ARGS);
int ompi_coll_base_allgather_intra_two_procs(ALLGATHER_ARGS);
int ompi_coll_base_allgather_intra_k_bruck(ALLGATHER_ARGS);

/* All GatherV */
int ompi_coll_base_allgatherv_intra_bruck(ALLGATHERV_ARGS);
Expand All @@ -211,6 +212,7 @@ int ompi_coll_base_allreduce_intra_ring_segmented(ALLREDUCE_ARGS, uint32_t segsi
int ompi_coll_base_allreduce_intra_basic_linear(ALLREDUCE_ARGS);
int ompi_coll_base_allreduce_intra_redscat_allgather(ALLREDUCE_ARGS);
int ompi_coll_base_allreduce_intra_allgather_reduce(ALLREDUCE_ARGS);
int ompi_coll_base_allreduce_intra_k_bruck(ALLREDUCE_ARGS);

/* AlltoAll */
int ompi_coll_base_alltoall_intra_pairwise(ALLTOALL_ARGS);
Expand Down Expand Up @@ -274,6 +276,7 @@ int ompi_coll_base_reduce_intra_binary(REDUCE_ARGS, uint32_t segsize, int max_ou
int ompi_coll_base_reduce_intra_binomial(REDUCE_ARGS, uint32_t segsize, int max_outstanding_reqs );
int ompi_coll_base_reduce_intra_in_order_binary(REDUCE_ARGS, uint32_t segsize, int max_outstanding_reqs );
int ompi_coll_base_reduce_intra_redscat_gather(REDUCE_ARGS);
int ompi_coll_base_reduce_intra_knomial(REDUCE_ARGS, uint32_t segsize, int max_outstanding_reqs );

/* Reduce_scatter */
int ompi_coll_base_reduce_scatter_intra_nonoverlapping(REDUCESCATTER_ARGS);
Expand Down
Loading

0 comments on commit 7f8a97b

Please sign in to comment.