Skip to content

Commit

Permalink
prov/sharp: Implement sharp_domain2() using sharp_coll API
Browse files Browse the repository at this point in the history
Signed-off-by: Lukasz Dorau <[email protected]>
  • Loading branch information
ldorau committed Dec 21, 2022
1 parent f6743af commit 3803925
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 13 deletions.
2 changes: 2 additions & 0 deletions prov/sharp/src/sharp.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ int sharp_oob_barrier(void* context);

int sharp_oob_gather(void * context, int root, void *sbuf, void *rbuf, int len);

int sharp_oob_progress(void);

/*
int sharp_coll_init(struct sharp_coll_init_spec *sharp_coll_spec,
struct sharp_coll_context **sharp_coll_context);
Expand Down
23 changes: 22 additions & 1 deletion prov/sharp/src/sharp_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,25 @@ ssize_t sharp_peer_xfer_error(struct fid_ep *ep_fid, struct fi_cq_err_entry *cqe
FI_WARN(ep->util_ep.domain->fabric->prov, FI_LOG_DOMAIN,
"collective - cq write failed\n");
return 0;
}
}

int sharp_oob_bcast(void* context, void* buffer, int len, int root)
{
return 0;
}

int sharp_oob_barrier(void* context)
{
return 0;
}

int sharp_oob_gather(void * context, int root, void *sbuf, void *rbuf, int len)
{
return 0;
}

int
sharp_oob_progress(void)
{
return 0;
}
86 changes: 74 additions & 12 deletions prov/sharp/src/sharp_domain.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <string.h>

#include "sharp.h"
#include "mocks/api/sharp.h"

#include "../../coll/src/coll.h" /* for coll_av_open */

Expand Down Expand Up @@ -156,7 +157,8 @@ fid_domain_init(struct fid_domain **domain_fid,
(*domain_fid)->mr = mr;
}

int sharp_domain2(struct fid_fabric *fabric, struct fi_info *info,
int
sharp_domain2(struct fid_fabric *fabric, struct fi_info *info,
struct fid_domain **domain_fid, uint64_t flags, void *context)
{
int ret;
Expand Down Expand Up @@ -185,12 +187,8 @@ int sharp_domain2(struct fid_fabric *fabric, struct fi_info *info,

ret = ofi_domain_init(fabric, info, &domain->util_domain, context,
OFI_LOCK_MUTEX);


if (ret) {
free(domain);
return ret;
}
if (ret)
goto err_free_domain;

ofi_atomic_initialize32(&domain->ref, 0);
domain->util_domain.threading = FI_THREAD_UNSPEC;
Expand All @@ -208,11 +206,6 @@ int sharp_domain2(struct fid_fabric *fabric, struct fi_info *info,
&sharp_domain_ops, &sharp_domain_mr_ops);


/*
XXX maped to
int sharp_coll_init(struct sharp_coll_init_spec *sharp_coll_spec,
struct sharp_coll_context **sharp_coll_context);
*/
#if 0
struct sharp_coll_init_spec {
uint64_t job_id; /**< Job unique ID */
Expand All @@ -228,5 +221,74 @@ struct sharp_coll_init_spec {
int reserved[4]; /**< Reserved */
};
#endif

struct sharp_coll_out_of_band_colls oob_colls = {
.barrier = sharp_oob_barrier,
.bcast = sharp_oob_bcast,
.gather = sharp_oob_gather
};

struct sharp_coll_config config = { /* XXX */
/* ??? */ .ib_dev_list = NULL, /**< IB device name, port list. (const char *) */
.user_progress_num_polls = -1, /**< Number of polls to do before calling user progress. (int) */
/* ??? */ .coll_timeout = 0, /**< Timeout (msec) for collective operation, -1 - infinite (int) */
};

struct sharp_coll_init_spec sharp_coll_spec = {
/* ??? */ .job_id = 0, /**< Job unique ID */
.world_rank = 0, /**< Global unique process id. */
.world_size = 0, /**< Num of processes in the job. */
/* ??? */ .progress_func = sharp_oob_progress, /**< External progress function. */
/* ??? */ .group_channel_idx = 0, /**< local group channel index(0 .. (max - 1))*/
/* ??? */ .config = config, /**< @ref sharp_coll_config "SHARP COLL Configuration". */
/* ??? */ .oob_colls = oob_colls, /**< @ref sharp_coll_out_of_band_colls "List of OOB collectives". */
.world_local_rank = 0, /**< relative rank of this process on this node within its job. */
/* ??? */ .enable_thread_support = 0, /**< enable multi threaded support. */
.oob_ctx = NULL, /**< context for OOB functions in sharp_coll_init */
};

char *e;

/* set sharp_coll_spec.world_size */
if ((((e = getenv("PMI_SIZE")) && *e)) // MPICH & IMPI
|| (((e = getenv("OMPI_COMM_WORLD_SIZE")) && *e)) // OMPI
|| (((e = getenv("MPI_NRANKS")) && *e)) // Platform MPI
|| (((e = getenv("MPIRUN_NPROCS")) && *e)) // older MPICH
|| (((e = getenv("SLURM_NTASKS")) && *e)) // SLURM
|| (((e = getenv("SLURM_NPROCS")) && *e))) // older SLURM
{
sharp_coll_spec.world_size = atoi(e);
}

/* set sharp_coll_spec.world_rank */
if ((((e = getenv("PMI_RANK")) && *e)) // MPICH and *_SIZE
|| (((e = getenv("OMPI_COMM_WORLD_RANK")) && *e)) // OMPI and *_SIZE
|| (((e = getenv("MPI_RANKID")) && *e)) // Platform MPI and *_NRANKS
|| (((e = getenv("MPIRUN_RANK")) && *e)) // older MPICH and *_NPROCS
|| (((e = getenv("PSC_MPI_RANK")) && *e)) // pathscale MPI
|| (((e = getenv("SLURM_TASKID")) && *e)) // SLURM
|| (((e = getenv("SLURM_PROCID")) && *e))) // older SLURM
{
sharp_coll_spec.world_rank = atoi(e);
}

/* set sharp_coll_spec.world_local_rank */
if ((((e = getenv("MPI_LOCALRANKID")) && *e)) // MPICH and IMPI
|| (((e = getenv("OMPI_COMM_WORLD_LOCAL_RANK")) && *e)) // OMPI
|| (((e = getenv("MPI_LOCALRANKID")) && *e)) // Platform MPI
|| (((e = getenv("PSC_MPI_NODE_RANK")) && *e)) // pathscale MPI
|| (((e = getenv("SLURM_LOCALID")) && *e))) // SLURM
{
sharp_coll_spec.world_local_rank = atoi(e);
}

ret = sharp_coll_init(&sharp_coll_spec, (struct sharp_coll_context **)&domain->sharp_context);
if (ret)
goto err_free_domain;

return 0;

err_free_domain:
free(domain);
return ret;
}

0 comments on commit 3803925

Please sign in to comment.