Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

comm: create local_group/remote_group beform comm commit #7237

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/include/mpir_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ struct MPIR_Comm {
int rank; /* Value of MPI_Comm_rank */
MPIR_Attribute *attributes; /* List of attributes */
int local_size; /* Value of MPI_Comm_size for local group */
MPIR_Group *local_group, /* Groups in communicator. */
*remote_group; /* The local and remote groups are the
* same for intra communicators */
MPIR_Group *local_group; /* Groups in communicator. */
MPIR_Group *remote_group; /* The remote group in a inter communicator.
* Must be NULL in a intra communicator. */
MPIR_Comm_kind_t comm_kind; /* MPIR_COMM_KIND__INTRACOMM or MPIR_COMM_KIND__INTERCOMM */
char name[MPI_MAX_OBJECT_NAME]; /* Required for MPI-2 */
MPIR_Errhandler *errhandler; /* Pointer to the error handler structure */
Expand Down
9 changes: 9 additions & 0 deletions src/mpi/comm/builtin_comms.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ int MPIR_init_comm_world(void)
MPIR_Process.comm_world->remote_size = MPIR_Process.size;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the remote_group is NULL, the remote_size should probably be 0.
Or, if we want to keep local_size == remote_size for now to lessen the impact on existing codes, we should update the comment in the struct definition and maybe add a TODO for future cleanup.

MPIR_Process.comm_world->local_size = MPIR_Process.size;

MPIR_Process.comm_world->local_group = MPIR_GROUP_WORLD_PTR;
MPIR_Group_add_ref(MPIR_GROUP_WORLD_PTR);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe explicitly set remote_group to NULL to avoid uninitialized value.


mpi_errno = MPIR_Comm_commit(MPIR_Process.comm_world);
MPIR_ERR_CHECK(mpi_errno);

Expand Down Expand Up @@ -59,6 +62,9 @@ int MPIR_init_comm_self(void)
MPIR_Process.comm_self->remote_size = 1;
MPIR_Process.comm_self->local_size = 1;

MPIR_Process.comm_self->local_group = MPIR_GROUP_SELF_PTR;
MPIR_Group_add_ref(MPIR_GROUP_SELF_PTR);

mpi_errno = MPIR_Comm_commit(MPIR_Process.comm_self);
MPIR_ERR_CHECK(mpi_errno);

Expand Down Expand Up @@ -91,6 +97,9 @@ int MPIR_init_icomm_world(void)
MPIR_Process.icomm_world->remote_size = MPIR_Process.size;
MPIR_Process.icomm_world->local_size = MPIR_Process.size;

MPIR_Process.icomm_world->local_group = MPIR_GROUP_WORLD_PTR;
MPIR_Group_add_ref(MPIR_GROUP_WORLD_PTR);

mpi_errno = MPIR_Comm_commit(MPIR_Process.icomm_world);
MPIR_ERR_CHECK(mpi_errno);

Expand Down
93 changes: 81 additions & 12 deletions src/mpi/comm/comm_impl.c
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,7 @@ int MPIR_Comm_create_intra(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co
(*newcomm_ptr)->local_group = group_ptr;
MPIR_Group_add_ref(group_ptr);

(*newcomm_ptr)->remote_group = group_ptr;
MPIR_Group_add_ref(group_ptr);
(*newcomm_ptr)->remote_group = NULL;
(*newcomm_ptr)->context_id = (*newcomm_ptr)->recvcontext_id;
(*newcomm_ptr)->remote_size = (*newcomm_ptr)->local_size = n;

Expand Down Expand Up @@ -382,15 +381,12 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co
int mpi_errno = MPI_SUCCESS;
int new_context_id;
int *mapping = NULL;
int *remote_mapping = NULL;
MPIR_Comm *mapping_comm = NULL;
int remote_size = -1;
int rinfo[2];
MPIR_CHKLMEM_DECL(1);

MPIR_FUNC_ENTER;

MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM);
MPIR_Session *session_ptr = comm_ptr->session_ptr;

/* Create a new communicator from the specified group members */

Expand All @@ -409,6 +405,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co
MPIR_Assert(new_context_id != 0);
MPIR_Assert(new_context_id != comm_ptr->recvcontext_id);

MPIR_Comm *mapping_comm;
mpi_errno = MPII_Comm_create_calculate_mapping(group_ptr, comm_ptr, &mapping, &mapping_comm);
MPIR_ERR_CHECK(mpi_errno);

Expand All @@ -434,7 +431,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co

(*newcomm_ptr)->is_low_group = comm_ptr->is_low_group;

MPIR_Comm_set_session_ptr(*newcomm_ptr, comm_ptr->session_ptr);
MPIR_Comm_set_session_ptr(*newcomm_ptr, session_ptr);
}

/* There is an additional step. We must communicate the
Expand All @@ -445,6 +442,11 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co
* in the remote group, from which the remote network address
* mapping can be constructed. We need to use the "collective"
* context in the original intercommunicator */

int remote_size = -1;
int *remote_mapping; /* a list of remote ranks */
int rinfo[2];

if (comm_ptr->rank == 0) {
int info[2];
info[0] = new_context_id;
Expand Down Expand Up @@ -494,6 +496,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co

MPIR_Assert(remote_size >= 0);


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove blank line.

if (group_ptr->rank != MPI_UNDEFINED) {
(*newcomm_ptr)->remote_size = remote_size;
/* Now, everyone has the remote_mapping, and can apply that to
Expand All @@ -505,6 +508,23 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co
mapping, remote_mapping, mapping_comm, *newcomm_ptr);
MPIR_ERR_CHECK(mpi_errno);

/* create remote_group.
* FIXME: we can directly exchange group maps once we get rid of comm mappers */
MPIR_Group *remote_group;

MPIR_Lpid *remote_map;
remote_map = MPL_malloc(remote_size * sizeof(MPIR_Lpid), MPL_MEM_GROUP);
MPIR_ERR_CHKANDJUMP(!remote_map, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPIR_Group *mapping_group = mapping_comm->remote_group;
MPIR_Assert(mapping_group);
for (int i = 0; i < remote_size; i++) {
remote_map[i] = MPIR_Group_rank_to_lpid(mapping_group, remote_mapping[i]);
}
mpi_errno = MPIR_Group_create_map(remote_size, MPI_UNDEFINED, session_ptr, remote_map,
&remote_group);
(*newcomm_ptr)->remote_group = remote_group;

(*newcomm_ptr)->tainted = comm_ptr->tainted;
mpi_errno = MPIR_Comm_commit(*newcomm_ptr);
MPIR_ERR_CHECK(mpi_errno);
Expand Down Expand Up @@ -605,8 +625,7 @@ int MPIR_Comm_create_group_impl(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, in
(*newcomm_ptr)->local_group = group_ptr;
MPIR_Group_add_ref(group_ptr);

(*newcomm_ptr)->remote_group = group_ptr;
MPIR_Group_add_ref(group_ptr);
(*newcomm_ptr)->remote_group = NULL;
(*newcomm_ptr)->context_id = (*newcomm_ptr)->recvcontext_id;
(*newcomm_ptr)->remote_size = (*newcomm_ptr)->local_size = n;

Expand Down Expand Up @@ -913,6 +932,9 @@ int MPIR_Comm_remote_group_impl(MPIR_Comm * comm_ptr, MPIR_Group ** group_ptr)
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

/* FIXME: remove the following remote_group creation once this assertion passes */
MPIR_Assert(comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM && comm_ptr->remote_group);

/* Create a group and populate it with the local process ids */
if (!comm_ptr->remote_group) {
int n = comm_ptr->remote_size;
Expand Down Expand Up @@ -965,6 +987,7 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader,
uint64_t *remote_lpids = NULL;
int comm_info[3];
int is_low_group = 0;
MPIR_Session *session_ptr = local_comm_ptr->session_ptr;

MPIR_FUNC_ENTER;

Expand Down Expand Up @@ -1042,7 +1065,14 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader,
(*new_intercomm_ptr)->local_comm = 0;
(*new_intercomm_ptr)->is_low_group = is_low_group;

MPIR_Comm_set_session_ptr(*new_intercomm_ptr, local_comm_ptr->session_ptr);
(*new_intercomm_ptr)->local_group = local_comm_ptr->local_group;
MPIR_Group_add_ref(local_comm_ptr->local_group);

/* construct remote_group */
mpi_errno = MPIR_Group_create_map(remote_size, MPI_UNDEFINED, session_ptr, remote_lpids,
&(*new_intercomm_ptr)->remote_group);

MPIR_Comm_set_session_ptr(*new_intercomm_ptr, session_ptr);

mpi_errno = MPID_Create_intercomm_from_lpids(*new_intercomm_ptr, remote_size, remote_lpids);
if (mpi_errno)
Expand All @@ -1064,8 +1094,6 @@ int MPIR_Intercomm_create_impl(MPIR_Comm * local_comm_ptr, int local_leader,


fn_exit:
MPL_free(remote_lpids);
remote_lpids = NULL;
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
Expand Down Expand Up @@ -1106,6 +1134,15 @@ int MPIR_peer_intercomm_create(int context_id, int recvcontext_id,
}
MPID_THREAD_CS_EXIT(VCI, comm_self->mutex);

MPIR_Session *session_ptr = NULL; /* Can we just use NULL session since peer_intercomm is always temporary? */
MPIR_Lpid my_lpid = MPIR_Group_rank_to_lpid(comm_self->local_group, 0);
mpi_errno = MPIR_Group_create_stride(1, 0, session_ptr, my_lpid, 1, 1,
&(*newcomm)->local_group);
MPIR_ERR_CHECK(mpi_errno);
mpi_errno = MPIR_Group_create_stride(1, 0, session_ptr, remote_lpid, 1, 1,
&(*newcomm)->remote_group);
MPIR_ERR_CHECK(mpi_errno);

(*newcomm)->tainted = 1;
mpi_errno = MPIR_Comm_commit(*newcomm);
MPIR_ERR_CHECK(mpi_errno);
Expand Down Expand Up @@ -1222,6 +1259,37 @@ int MPIR_Intercomm_merge_impl(MPIR_Comm * comm_ptr, int high, MPIR_Comm ** new_i

MPIR_Comm_set_session_ptr(*new_intracomm_ptr, comm_ptr->session_ptr);

/* construct local_group */
MPIR_Group *new_local_group;

MPIR_Lpid *map;
map = MPL_malloc(new_size * sizeof(MPIR_Lpid), MPL_MEM_GROUP);
MPIR_ERR_CHKANDJUMP(!map, mpi_errno, MPI_ERR_OTHER, "**nomem");

int myrank;
MPIR_Group *group1, *group2;
if (local_high) {
group1 = comm_ptr->remote_group;
group2 = comm_ptr->local_group;
myrank = group1->size + group2->rank;
} else {
group1 = comm_ptr->local_group;
group2 = comm_ptr->remote_group;
myrank = group1->rank;
}
for (int i = 0; i < group1->size; i++) {
map[i] = MPIR_Group_rank_to_lpid(group1, i);
}
for (int i = 0; i < group2->size; i++) {
map[group1->size + i] = MPIR_Group_rank_to_lpid(group2, i);
}

mpi_errno = MPIR_Group_create_map(new_size, myrank, comm_ptr->session_ptr, map,
&new_local_group);

(*new_intracomm_ptr)->local_group = new_local_group;
MPIR_Group_add_ref(new_local_group);

/* Now we know which group comes first. Build the new mapping
* from the existing comm */
mpi_errno = create_and_map(comm_ptr, local_high, (*new_intracomm_ptr));
Expand Down Expand Up @@ -1260,6 +1328,7 @@ int MPIR_Intercomm_merge_impl(MPIR_Comm * comm_ptr, int high, MPIR_Comm ** new_i
(*new_intracomm_ptr)->recvcontext_id = new_context_id;

MPIR_Comm_set_session_ptr(*new_intracomm_ptr, comm_ptr->session_ptr);
(*new_intracomm_ptr)->local_group = new_local_group;

mpi_errno = create_and_map(comm_ptr, local_high, (*new_intracomm_ptr));
MPIR_ERR_CHECK(mpi_errno);
Expand Down
13 changes: 13 additions & 0 deletions src/mpi/comm/comm_split.c
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm **
(*newcomm_ptr)->rank = i;
}

mpi_errno = MPIR_Group_incl_impl(comm_ptr->local_group, new_size, mapper->src_mapping,
&(*newcomm_ptr)->local_group);
MPIR_ERR_CHECK(mpi_errno);

/* For the remote group, the situation is more complicated.
* We need to find the size of our "partner" group in the
* remote comm. The easiest way (in terms of code) is for
Expand All @@ -313,6 +317,11 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm **
for (i = 0; i < new_remote_size; i++)
mapper->src_mapping[i] = remotekeytable[i].color;

mpi_errno = MPIR_Group_incl_impl(comm_ptr->remote_group,
new_remote_size, mapper->src_mapping,
&(*newcomm_ptr)->remote_group);
MPIR_ERR_CHECK(mpi_errno);

(*newcomm_ptr)->context_id = remote_context_id;
(*newcomm_ptr)->remote_size = new_remote_size;
(*newcomm_ptr)->local_comm = 0;
Expand All @@ -331,6 +340,10 @@ int MPIR_Comm_split_impl(MPIR_Comm * comm_ptr, int color, int key, MPIR_Comm **
if (keytable[i].color == comm_ptr->rank)
(*newcomm_ptr)->rank = i;
}

mpi_errno = MPIR_Group_incl_impl(comm_ptr->local_group, new_size, mapper->src_mapping,
&(*newcomm_ptr)->local_group);
MPIR_ERR_CHECK(mpi_errno);
}

/* Inherit the error handler (if any) */
Expand Down
34 changes: 34 additions & 0 deletions src/mpi/comm/commutil.c
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ int MPII_Setup_intercomm_localcomm(MPIR_Comm * intercomm_ptr)
mpi_errno = MPII_Comm_init(localcomm_ptr);
MPIR_ERR_CHECK(mpi_errno);

MPIR_Assert(intercomm_ptr->local_group);
localcomm_ptr->local_group = intercomm_ptr->local_group;
MPIR_Group_add_ref(intercomm_ptr->local_group);

MPIR_Comm_set_session_ptr(localcomm_ptr, intercomm_ptr->session_ptr);

/* use the parent intercomm's recv ctx as the basis for our ctx */
Expand Down Expand Up @@ -687,6 +691,14 @@ int MPIR_Comm_create_subcomms(MPIR_Comm * comm)
/* Copy relevant hints to node_comm */
propagate_hints_to_subcomm(comm, comm->node_comm);

/* construct local_group */
MPIR_Group *parent_group = comm->local_group;
MPIR_Assert(parent_group);
mpi_errno = MPIR_Group_incl_impl(parent_group, num_local, local_procs,
&comm->node_comm->local_group);
MPIR_ERR_CHECK(mpi_errno);

/* mapper */
MPIR_Comm_map_irregular(comm->node_comm, comm, local_procs, num_local,
MPIR_COMM_MAP_DIR__L2L, NULL);
mpi_errno = MPIR_Comm_commit_internal(comm->node_comm);
Expand Down Expand Up @@ -714,6 +726,14 @@ int MPIR_Comm_create_subcomms(MPIR_Comm * comm)
/* Copy relevant hints to node_roots_comm */
propagate_hints_to_subcomm(comm, comm->node_roots_comm);

/* construct local_group */
MPIR_Group *parent_group = comm->local_group;
MPIR_Assert(parent_group);
mpi_errno = MPIR_Group_incl_impl(parent_group, num_external, external_procs,
&comm->node_roots_comm->local_group);
MPIR_ERR_CHECK(mpi_errno);

/* mapper */
MPIR_Comm_map_irregular(comm->node_roots_comm, comm, external_procs, num_external,
MPIR_COMM_MAP_DIR__L2L, NULL);
mpi_errno = MPIR_Comm_commit_internal(comm->node_roots_comm);
Expand Down Expand Up @@ -961,6 +981,13 @@ int MPII_Comm_copy(MPIR_Comm * comm_ptr, int size, MPIR_Info * info, MPIR_Comm *
newcomm_ptr->comm_kind = comm_ptr->comm_kind;
newcomm_ptr->local_comm = 0;

newcomm_ptr->local_group = comm_ptr->local_group;
MPIR_Group_add_ref(comm_ptr->local_group);
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) {
newcomm_ptr->remote_group = comm_ptr->remote_group;
MPIR_Group_add_ref(comm_ptr->remote_group);
}

MPIR_Comm_set_session_ptr(newcomm_ptr, comm_ptr->session_ptr);

/* There are two cases here - size is the same as the old communicator,
Expand Down Expand Up @@ -1059,6 +1086,13 @@ int MPII_Comm_copy_data(MPIR_Comm * comm_ptr, MPIR_Info * info, MPIR_Comm ** out
newcomm_ptr->comm_kind = comm_ptr->comm_kind;
newcomm_ptr->local_comm = 0;

newcomm_ptr->local_group = comm_ptr->local_group;
MPIR_Group_add_ref(comm_ptr->local_group);
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM) {
newcomm_ptr->remote_group = comm_ptr->remote_group;
MPIR_Group_add_ref(comm_ptr->remote_group);
}

if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM)
MPIR_Comm_map_dup(newcomm_ptr, comm_ptr, MPIR_COMM_MAP_DIR__L2L);
else
Expand Down
24 changes: 22 additions & 2 deletions src/mpid/ch3/src/ch3u_port.c
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,13 @@ static int MPIDI_CH3I_Initialize_tmp_comm(MPIR_Comm **comm_pptr,

MPIR_Coll_comm_init(tmp_comm);

MPIR_Lpid local_lpid = tmp_comm->dev.local_vcrt->vcr_table[0]->lpid;
MPIR_Lpid remote_lpid = tmp_comm->dev.vcrt->vcr_table[0]->lpid;
mpi_errno = MPIR_Group_create_stride(1, 0, commself_ptr->session_ptr, local_lpid, 1, 1,
&tmp_comm->local_group);
mpi_errno = MPIR_Group_create_stride(1, 0, commself_ptr->session_ptr, remote_lpid, 1, 1,
&tmp_comm->remote_group);

/* Even though this is a tmp comm and we don't call
MPI_Comm_commit, we still need to call the creation hook
because the destruction hook will be called in comm_release */
Expand Down Expand Up @@ -1337,8 +1344,6 @@ static int SetupNewIntercomm( MPIR_Comm *comm_ptr, int remote_comm_size,
intercomm->remote_size = remote_comm_size;
intercomm->local_size = comm_ptr->local_size;
intercomm->rank = comm_ptr->rank;
intercomm->local_group = NULL;
intercomm->remote_group = NULL;
intercomm->comm_kind = MPIR_COMM_KIND__INTERCOMM;
intercomm->local_comm = NULL;

Expand All @@ -1356,6 +1361,21 @@ static int SetupNewIntercomm( MPIR_Comm *comm_ptr, int remote_comm_size,
remote_translation[i].pg_rank, &intercomm->dev.vcrt->vcr_table[i]);
}

intercomm->local_group = comm_ptr->local_group;
MPIR_Group_add_ref(comm_ptr->local_group);

MPIR_Lpid *remote_map;
remote_map = MPL_malloc(remote_comm_size * sizeof(MPIR_Lpid), MPL_MEM_GROUP);
MPIR_ERR_CHKANDJUMP(!remote_map, mpi_errno, MPI_ERR_OTHER, "**nomem");
for (i=0; i < intercomm->remote_size; i++) {
MPIDI_PG_t *pg = remote_pg[remote_translation[i].pg_index];
int rank = remote_translation[i].pg_rank;
remote_map[i] = pg->vct[rank].lpid;
}
mpi_errno = MPIR_Group_create_map(remote_comm_size, MPI_UNDEFINED, comm_ptr->session_ptr,
remote_map, &intercomm->remote_group);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIR_Comm_commit(intercomm);
MPIR_ERR_CHECK(mpi_errno);

Expand Down
Loading