Skip to content

Commit

Permalink
[stubmpi] implement Gather, Gatherv, and Type_get_extent
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Dec 6, 2024
1 parent 2f1ccaa commit 018a999
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions src/madness/world/stubmpi.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <madness/world/madness_exception.h>
#include <madness/world/timers.h>

typedef int MPI_Group;
Expand All @@ -28,6 +29,8 @@ typedef int MPI_Errhandler;

typedef int MPI_Info;

typedef std::ptrdiff_t MPI_Aint;

/* MPI's error classes */
/* these constants are consistent with MPICH2 mpi.h */
#define MPI_SUCCESS 0 /* Successful return code */
Expand Down Expand Up @@ -86,6 +89,48 @@ typedef int MPI_Datatype;
#define MPI_UNSIGNED_LONG_LONG ((MPI_Datatype)0x4c000819)
#define MPI_LONG_LONG ((MPI_Datatype)0x4c000809)

inline int MPI_Type_get_extent(MPI_Datatype datatype, MPI_Aint *lb,
MPI_Aint *extent) {
switch(datatype) {
case MPI_CHAR:
*extent = sizeof(char); break;
case MPI_SIGNED_CHAR:
*extent = sizeof(signed char); break;
case MPI_UNSIGNED_CHAR:
*extent = sizeof(unsigned char); break;
case MPI_BYTE:
*extent = 1; break;
case MPI_WCHAR:
*extent = sizeof(wchar_t); break;
case MPI_SHORT:
*extent = sizeof(short); break;
case MPI_UNSIGNED_SHORT:
*extent = sizeof(unsigned short); break;
case MPI_INT:
*extent = sizeof(int); break;
case MPI_UNSIGNED:
*extent = sizeof(unsigned); break;
case MPI_LONG:
*extent = sizeof(long); break;
case MPI_UNSIGNED_LONG:
*extent = sizeof(unsigned long); break;
case MPI_FLOAT:
*extent = sizeof(float); break;
case MPI_DOUBLE:
*extent = sizeof(double); break;
case MPI_LONG_DOUBLE:
*extent = sizeof(long double); break;
case MPI_LONG_LONG_INT: // same as MPI_LONG_LONG
*extent = sizeof(long long int); break;
case MPI_UNSIGNED_LONG_LONG:
*extent = sizeof(unsigned long long); break;
default:
*extent = MPI_UNDEFINED;
}
*lb = 0;
return MPI_SUCCESS;
}

/* MPI Reduction operation */
/* these constants are consistent with MPICH2 mpi.h */
typedef int MPI_Op;
Expand Down Expand Up @@ -169,6 +214,29 @@ inline int MPI_Bsend(void*, int, MPI_Datatype, int, int, MPI_Comm) { return MPI_
inline int MPI_Irecv(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Request*) { return MPI_ERR_COMM; }
inline int MPI_Recv(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*) { return MPI_ERR_COMM; }

// Gather = copy
inline int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
void *recvbuf, const int recvcounts[], const int displs[], MPI_Datatype recvtype,
int root, MPI_Comm) {
MPI_Aint recvtype_extent;
MPI_Aint recvtype_lb;
MPI_Type_get_extent(recvtype, &recvtype_lb, &recvtype_extent);
MADNESS_ASSERT(recvtype_lb == 0);
MPI_Aint sendtype_extent;
MPI_Aint sendtype_lb;
MPI_Type_get_extent(sendtype, &sendtype_lb, &sendtype_extent);
MADNESS_ASSERT(sendtype_lb == 0);
MADNESS_ASSERT(sendcount * sendtype_extent <= recvcounts[0] * recvtype_extent);
std::memcpy(recvbuf, sendbuf, sendcount * sendtype_extent);
return MPI_SUCCESS;
}
inline int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm) {
const int recvcounts[1] = {recvcount};
const int displs[1] = {0};
return MPI_Gatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm);
}

// Bcast does nothing but return MPI_SUCCESS
inline int MPI_Bcast(void*, int, MPI_Datatype, int, MPI_Comm) { return MPI_SUCCESS; }

Expand Down

0 comments on commit 018a999

Please sign in to comment.