Skip to content

Commit

Permalink
Merge branch 'prov-rxm-mc' into devel
Browse files Browse the repository at this point in the history
  • Loading branch information
grom72 committed Dec 15, 2022
2 parents 19b0e50 + 401e416 commit 4f1e5ab
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 23 deletions.
8 changes: 8 additions & 0 deletions include/ofi_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ struct util_coll_mc {
uint16_t group_id;
uint16_t seq;
ofi_atomic32_t ref;
struct fid_mc *peer_mc;
};

struct util_av_set {
Expand Down Expand Up @@ -1134,6 +1135,13 @@ enum {
OFI_OPT_TCP_FI_ADDR = -FI_PROV_SPECIFIC_TCP
};

/*
* Peer mc support.
*
* Passed through the 'context' argument of fi_join() together with the
* FI_PEER flag. The joined provider reads mc_fid from it to learn which
* multicast fid the calling (owner) provider exposes for this join.
*/
struct fi_peer_mc_context {
/* set by the caller to sizeof(struct fi_peer_mc_context) */
size_t size;
/* mc fid of the provider that issues the FI_PEER join */
struct fid_mc *mc_fid;
};

#ifdef __cplusplus
}
Expand Down
17 changes: 15 additions & 2 deletions prov/coll/src/coll_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ void coll_join_comp(struct util_coll_operation *coll_op)
struct fi_eq_entry entry;
struct coll_ep *ep;
struct coll_eq *eq;
uint64_t flags;

ep = container_of(coll_op->ep, struct coll_ep, util_ep.ep_fid);
eq = container_of(ep->util_ep.eq, struct coll_eq, util_eq.eq_fid);
Expand All @@ -709,8 +710,11 @@ void coll_join_comp(struct util_coll_operation *coll_op)
entry.fid = &coll_op->mc->mc_fid.fid;
entry.context = coll_op->context;

flags = FI_COLLECTIVE;
if (coll_op->mc->peer_mc)
flags |= FI_PEER;
if (fi_eq_write(eq->peer_eq, FI_JOIN_COMPLETE, &entry,
sizeof(struct fi_eq_entry), FI_COLLECTIVE) < 0)
sizeof(struct fi_eq_entry), flags) < 0)
FI_WARN(ep->util_ep.domain->fabric->prov, FI_LOG_DOMAIN,
"join collective - eq write failed\n");

Expand Down Expand Up @@ -911,6 +915,7 @@ static struct util_coll_mc *coll_create_mc(struct util_av_set *av_set,
int coll_join_collective(struct fid_ep *ep, const void *addr,
uint64_t flags, struct fid_mc **mc, void *context)
{
struct fi_peer_mc_context *peer_context;
struct util_coll_mc *new_coll_mc;
struct util_av_set *av_set;
struct util_coll_mc *coll_mc;
Expand All @@ -924,6 +929,11 @@ int coll_join_collective(struct fid_ep *ep, const void *addr,
if (!(flags & FI_COLLECTIVE))
return -FI_ENOSYS;

if (flags & FI_PEER) {
peer_context = context;
context = peer_context->mc_fid;
}

c_addr = (struct fi_collective_addr *)addr;
coll_addr = c_addr->coll_addr;
set = c_addr->set;
Expand All @@ -941,11 +951,14 @@ int coll_join_collective(struct fid_ep *ep, const void *addr,
if (!new_coll_mc)
return -FI_ENOMEM;

if (flags & FI_PEER)
new_coll_mc->peer_mc = context;

/* get the rank */
coll_find_local_rank(ep, new_coll_mc);
coll_find_local_rank(ep, coll_mc);

join_op = coll_create_op(ep, coll_mc, UTIL_COLL_JOIN_OP, flags,
join_op = coll_create_op(ep, new_coll_mc, UTIL_COLL_JOIN_OP, flags,
context, coll_join_comp);
if (!join_op) {
ret = -FI_ENOMEM;
Expand Down
10 changes: 10 additions & 0 deletions prov/rxm/src/rxm.h
Original file line number Diff line number Diff line change
Expand Up @@ -1001,4 +1001,14 @@ rxm_multi_recv_entry_get(struct rxm_ep *rxm_ep, const struct iovec *iov,
void **desc, size_t count, fi_addr_t src_addr,
uint64_t tag, uint64_t ignore, void *context,
uint64_t flags);

/*
* rxm multicast (collective group) object. It fronts the mc fids joined on
* the util collective endpoint and, when one is configured, on the offload
* collective endpoint with a single fid_mc handed back to the caller of
* the join.
*/
struct rxm_mc {
/* fid returned to the caller; its fi_addr holds the rxm_mc pointer */
struct fid_mc mc_fid;
/* reported as entry.context of FI_JOIN_COMPLETE by rxm_eq_write() */
void *context;
/* AV set of the group; a reference is held from creation until close */
struct util_av_set *av_set;
/* mc joined on the util collective endpoint */
struct fid_mc *util_coll_mc_fid;
/* 0: join pending, 1: join completed,
* -1: close this rxm_mc as soon as the util join completes */
int util_coll_join_completed;
/* mc joined on the offload collective endpoint; NULL if there is none */
struct fid_mc *offload_coll_mc_fid;
/* 0: join pending, 1: join completed */
int offload_coll_join_completed;
};
#endif
140 changes: 119 additions & 21 deletions prov/rxm/src/rxm_ep.c
Original file line number Diff line number Diff line change
Expand Up @@ -380,17 +380,100 @@ static int rxm_getname(fid_t fid, void *addr, size_t *addrlen)
return fi_getname(&rxm_ep->msg_pep->fid, addr, addrlen);
}

/*
 * Close handler of an rxm multicast object.
 *
 * Drops the AV set reference held by the object, closes the mc fids joined
 * on the util and offload collective endpoints (each only when set), and
 * releases the object itself. Always returns FI_SUCCESS.
 */
static int rxm_mc_close(struct fid *fid)
{
	struct rxm_mc *mc = container_of(fid, struct rxm_mc, mc_fid.fid);

	/* give back the reference held on the AV set */
	ofi_atomic_dec32(&mc->av_set->ref);

	if (mc->util_coll_mc_fid != NULL) {
		/* a sub-mc is only expected to be closed once its join is done */
		assert(mc->util_coll_join_completed == 1);
		fi_close(&mc->util_coll_mc_fid->fid);
	}

	if (mc->offload_coll_mc_fid != NULL) {
		assert(mc->offload_coll_join_completed == 1);
		fi_close(&mc->offload_coll_mc_fid->fid);
	}

	free(mc);
	return FI_SUCCESS;
}

/*
* fid operations installed on every rxm_mc: only close is implemented,
* bind/control/ops_open use the fi_no_* stubs.
*/
static struct fi_ops rxm_mc_fi_ops = {
.size = sizeof(struct fi_ops),
.close = rxm_mc_close,
.bind = fi_no_bind,
.control = fi_no_control,
.ops_open = fi_no_ops_open,
};

/*
 * Allocate and initialize an rxm multicast object for 'av_set'.
 *
 * The object is zero-initialized, so both sub-mc pointers start out NULL
 * and both join-completed states start at 0 (pending). A reference on
 * 'av_set' is taken here and released by rxm_mc_close().
 *
 * 'context' is the caller's join context. It is stored both in the fid and
 * in rxm_mc->context, because rxm_eq_write() reports rxm_mc->context in the
 * FI_JOIN_COMPLETE event; leaving it unset would hand NULL back to the
 * application.
 *
 * Returns the new object, or NULL if the allocation fails.
 */
static struct rxm_mc *rxm_create_mc(struct util_av_set *av_set,
				    void *context)
{
	struct rxm_mc *rxm_mc;

	rxm_mc = calloc(1, sizeof(*rxm_mc));
	if (!rxm_mc)
		return NULL;

	rxm_mc->mc_fid.fid.fclass = FI_CLASS_MC;
	rxm_mc->mc_fid.fid.context = context;
	rxm_mc->mc_fid.fid.ops = &rxm_mc_fi_ops;
	/* the collective address handed to the app is the object itself */
	rxm_mc->mc_fid.fi_addr = (uintptr_t) rxm_mc;
	rxm_mc->context = context;

	ofi_atomic_inc32(&av_set->ref);
	rxm_mc->av_set = av_set;

	return rxm_mc;
}

static int rxm_join_coll(struct fid_ep *ep, const void *addr, uint64_t flags,
struct fid_mc **mc, void *context)
struct fid_mc **mc, void *context)
{
struct fi_collective_addr *c_addr;
const struct fid_av_set *set;
struct util_av_set *av_set;
struct rxm_mc *rxm_mc;
struct rxm_ep *rxm_ep;
int ret;
struct fi_peer_mc_context peer_context = {
.size = sizeof(struct fi_peer_mc_context),
};

if (!(flags & FI_COLLECTIVE))
return -FI_ENOSYS;

rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid);
c_addr = (struct fi_collective_addr *)addr;
set = c_addr->set;
av_set = container_of(set, struct util_av_set, av_set_fid);

return fi_join(rxm_ep->util_coll_ep, addr, flags, mc, context);
rxm_mc = rxm_create_mc(av_set, context);
if (!rxm_mc)
return -FI_ENOMEM;

peer_context.mc_fid = &rxm_mc->mc_fid;
rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid);
ret = fi_join(rxm_ep->util_coll_ep, addr, flags | FI_PEER,
&rxm_mc->util_coll_mc_fid, &peer_context);
if (ret) {
fi_close(&rxm_mc->mc_fid.fid);
} else if (rxm_ep->offload_coll_ep) {
ret = fi_join(rxm_ep->offload_coll_ep, addr, flags | FI_PEER,
&rxm_mc->offload_coll_mc_fid, &peer_context);
if (ret) {
/*mark util_coll_mc to be removed as soon as
util_coll_ep:fi_join() complets */
rxm_mc->util_coll_join_completed = -1;
}
}
if (!ret) {
*mc = &rxm_mc->mc_fid;
}
return ret;
}

static struct fi_ops_cm rxm_ops_cm = {
Expand Down Expand Up @@ -972,11 +1055,14 @@ struct fid_ep *get_coll_ep(struct rxm_ep *rxm_ep, uint64_t flags, int coll_op)

static int rxm_ep_init_coll_req(struct rxm_ep *rxm_ep, int coll_op, uint64_t flags,
void *context, struct rxm_coll_buf **req,
struct fid_ep **coll_ep)
struct fid_ep **coll_ep, fi_addr_t *coll_addr)
{
ofi_ep_lock_acquire(&rxm_ep->util_ep);
struct rxm_mc *rxm_mc;
struct util_coll_mc *coll_mc;

ofi_ep_lock_acquire(&rxm_ep->util_ep);
(*req) = rxm_get_coll_buf(rxm_ep);
ofi_ep_lock_release(&rxm_ep->util_ep);
ofi_ep_lock_release(&rxm_ep->util_ep);

if (!(*req))
return -FI_EAGAIN;
Expand All @@ -985,13 +1071,25 @@ static int rxm_ep_init_coll_req(struct rxm_ep *rxm_ep, int coll_op, uint64_t fla
(*req)->flags = flags;
(*req)->app_context = context;

if (flags & FI_PEER_TRANSFER)
rxm_mc = (struct rxm_mc*) ((uintptr_t) *coll_addr);
if ( (flags & FI_PEER_TRANSFER) ||
!(rxm_ep->offload_coll_mask & BIT(coll_op))) {
if (rxm_mc->util_coll_join_completed != 1) {
return -FI_EAGAIN;
}
coll_mc = container_of(rxm_mc->util_coll_mc_fid,
struct util_coll_mc, mc_fid);
*coll_addr = fi_mc_addr(&coll_mc->mc_fid);
(*coll_ep) = rxm_ep->util_coll_ep;
else if (rxm_ep->offload_coll_mask & BIT(coll_op))
} else {
if (rxm_mc->offload_coll_join_completed != 1) {
return -FI_EAGAIN;
}
coll_mc = container_of(rxm_mc->offload_coll_mc_fid,
struct util_coll_mc, mc_fid);
*coll_addr = fi_mc_addr(&coll_mc->mc_fid);
(*coll_ep) = rxm_ep->offload_coll_ep;
else
(*coll_ep) = rxm_ep->util_coll_ep;

}
return 0;
}

Expand All @@ -1011,10 +1109,10 @@ ssize_t rxm_ep_barrier2(struct fid_ep *ep, fi_addr_t coll_addr,
struct rxm_coll_buf *req;
ssize_t ret;

rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid.fid);
rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid.fid);

ret = rxm_ep_init_coll_req(rxm_ep, FI_BARRIER, flags, context,
&req, &coll_ep);
&req, &coll_ep, &coll_addr);
if (ret)
return ret;

Expand Down Expand Up @@ -1042,10 +1140,10 @@ ssize_t rxm_ep_allreduce(struct fid_ep *ep, const void *buf, size_t count,
struct rxm_coll_buf *req;
ssize_t ret;

rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid.fid);
rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid.fid);

ret = rxm_ep_init_coll_req(rxm_ep, FI_ALLREDUCE, flags, context,
&req, &coll_ep);
&req, &coll_ep, &coll_addr);
if (ret)
return ret;

Expand All @@ -1069,10 +1167,10 @@ ssize_t rxm_ep_allgather(struct fid_ep *ep, const void *buf, size_t count,
struct rxm_coll_buf *req;
ssize_t ret;

rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid.fid);
rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid.fid);

ret = rxm_ep_init_coll_req(rxm_ep, FI_ALLGATHER, flags, context,
&req, &coll_ep);
&req, &coll_ep, &coll_addr);
if (ret)
return ret;

Expand All @@ -1097,10 +1195,10 @@ ssize_t rxm_ep_scatter(struct fid_ep *ep, const void *buf, size_t count,
struct rxm_coll_buf *req;
ssize_t ret;

rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid.fid);
rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid.fid);

ret = rxm_ep_init_coll_req(rxm_ep, FI_SCATTER, flags, context,
&req, &coll_ep);
&req, &coll_ep, &coll_addr);
if (ret)
return ret;

Expand All @@ -1124,10 +1222,10 @@ ssize_t rxm_ep_broadcast(struct fid_ep *ep, void *buf, size_t count,
struct rxm_coll_buf *req;
ssize_t ret;

rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid.fid);
rxm_ep = container_of(ep, struct rxm_ep, util_ep.ep_fid.fid);

ret = rxm_ep_init_coll_req(rxm_ep, FI_BROADCAST, flags, context,
&req, &coll_ep);
&req, &coll_ep, &coll_addr);
if (ret)
return ret;

Expand Down
55 changes: 55 additions & 0 deletions prov/rxm/src/rxm_eq.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,60 @@ static struct fi_ops rxm_eq_fi_ops = {
.ops_open = fi_no_ops_open,
};

/*
 * EQ write hook of rxm.
 *
 * Every event other than FI_JOIN_COMPLETE is forwarded unchanged to
 * ofi_eq_write(). FI_JOIN_COMPLETE events are matched against the util and
 * offload sub-mcs of the owning rxm_mc; a single FI_JOIN_COMPLETE carrying
 * the rxm mc fid and context is written once the util join and, if an
 * offload mc exists, the offload join have both completed.
 *
 * With FI_PEER set, the entry's context is the rxm mc fid; otherwise it is
 * used as the rxm_mc pointer directly.
 *
 * Returns the result of ofi_eq_write() when an event is written, and 'len'
 * when the event is consumed without producing one.
 */
ssize_t rxm_eq_write(struct fid_eq *eq_fid, uint32_t event,
		     const void *buf, size_t len, uint64_t flags)
{
	struct fi_eq_entry entry;
	const struct fi_eq_entry *in_entry = buf;
	struct rxm_mc *mc;

	if (event != FI_JOIN_COMPLETE)
		return ofi_eq_write(eq_fid, event, buf, len, flags);

	if (flags & FI_PEER)
		mc = container_of(in_entry->context, struct rxm_mc, mc_fid);
	else
		mc = in_entry->context;

	if (in_entry->fid == &mc->util_coll_mc_fid->fid) {
		if (mc->util_coll_join_completed == -1) {
			/* cleanup after partially executed fi_join() */
			mc->util_coll_join_completed = 1;
			fi_close(&mc->mc_fid.fid);
			/*
			 * The close released mc, so it must not be touched
			 * any more, and the failed join must not be
			 * reported as complete.
			 */
			return len;
		}
		assert(mc->util_coll_join_completed == 0);
		mc->util_coll_join_completed = 1;
	} else if (in_entry->fid == &mc->offload_coll_mc_fid->fid) {
		assert(mc->offload_coll_join_completed == 0);
		mc->offload_coll_join_completed = 1;
	} else {
		assert(0); /* we do not expect any other fid */
	}

	/* report the join only after every sub-join has finished */
	if (mc->util_coll_join_completed == 1 &&
	    (mc->offload_coll_join_completed == 1 ||
	     !mc->offload_coll_mc_fid)) {
		memset(&entry, 0, sizeof(entry));
		entry.context = mc->context;
		entry.fid = &mc->mc_fid.fid;
		return ofi_eq_write(eq_fid, event, &entry, len, flags);
	}

	return len;
}

/*
* EQ operations of rxm: the generic ofi_eq_* implementations, except for
* write, which goes through rxm_eq_write().
*/
static struct fi_ops_eq rxm_eq_ops = {
.size = sizeof(struct fi_ops_eq),
.read = ofi_eq_read,
.readerr = ofi_eq_readerr,
.sread = ofi_eq_sread,
.write = rxm_eq_write,
.strerror = ofi_eq_strerror,
};

int rxm_eq_open(struct fid_fabric *fabric_fid, struct fi_eq_attr *attr,
struct fid_eq **eq_fid, void *context)
{
Expand Down Expand Up @@ -105,6 +159,7 @@ int rxm_eq_open(struct fid_fabric *fabric_fid, struct fi_eq_attr *attr,
}

rxm_eq->util_eq.eq_fid.fid.ops = &rxm_eq_fi_ops;
rxm_eq->util_eq.eq_fid.ops = &rxm_eq_ops;
*eq_fid = &rxm_eq->util_eq.eq_fid;
return 0;

Expand Down

0 comments on commit 4f1e5ab

Please sign in to comment.