diff --git a/prov/psm2/src/psmx2.h b/prov/psm2/src/psmx2.h index 88beeaef45c..775dc5fb5cc 100644 --- a/prov/psm2/src/psmx2.h +++ b/prov/psm2/src/psmx2.h @@ -238,12 +238,16 @@ struct psmx2_am_request { size_t len_read; } read; struct { - uint8_t *buf; + union { + uint8_t *buf; /* for result_count == 1 */ + size_t iov_count; /* for result_count > 1 */ + }; size_t len; uint64_t addr; uint64_t key; void *context; uint8_t *result; + int datatype; } atomic; }; uint64_t cq_flags; @@ -252,7 +256,10 @@ struct psmx2_am_request { int no_event; int error; struct slist_entry list_entry; - struct iovec iov[0]; /* for readv, must be the last field */ + union { + struct iovec iov[0]; /* for readv, must be the last field */ + struct fi_ioc ioc[0]; /* for atomic read, must be the last field */ + }; }; #define PSMX2_IOV_PROTO_PACK 0 @@ -401,8 +408,11 @@ enum psmx2_triggered_op { PSMX2_TRIGGERED_READ, PSMX2_TRIGGERED_READV, PSMX2_TRIGGERED_ATOMIC_WRITE, + PSMX2_TRIGGERED_ATOMIC_WRITEV, PSMX2_TRIGGERED_ATOMIC_READWRITE, + PSMX2_TRIGGERED_ATOMIC_READWRITEV, PSMX2_TRIGGERED_ATOMIC_COMPWRITE, + PSMX2_TRIGGERED_ATOMIC_COMPWRITEV, }; struct psmx2_trigger { @@ -531,6 +541,19 @@ struct psmx2_trigger { void *context; uint64_t flags; } atomic_write; + struct { + struct fid_ep *ep; + const struct fi_ioc *iov; + size_t count; + void *desc; + fi_addr_t dest_addr; + uint64_t addr; + uint64_t key; + enum fi_datatype datatype; + enum fi_op atomic_op; + void *context; + uint64_t flags; + } atomic_writev; struct { struct fid_ep *ep; const void *buf; @@ -546,6 +569,22 @@ struct psmx2_trigger { void *context; uint64_t flags; } atomic_readwrite; + struct { + struct fid_ep *ep; + const struct fi_ioc *iov; + size_t count; + void **desc; + struct fi_ioc *resultv; + void **result_desc; + size_t result_count; + fi_addr_t dest_addr; + uint64_t addr; + uint64_t key; + enum fi_datatype datatype; + enum fi_op atomic_op; + void *context; + uint64_t flags; + } atomic_readwritev; struct { struct fid_ep *ep; const void *buf; @@ -563,6 +602,25 @@ struct psmx2_trigger { void *context; uint64_t flags; } atomic_compwrite; + struct { + struct fid_ep *ep; + const struct fi_ioc *iov; + size_t count; + void **desc; + const struct fi_ioc *comparev; + void **compare_desc; + size_t compare_count; + struct fi_ioc *resultv; + void **result_desc; + size_t result_count; + fi_addr_t dest_addr; + uint64_t addr; + uint64_t key; + enum fi_datatype datatype; + enum fi_op atomic_op; + void *context; + uint64_t flags; + } atomic_compwritev; }; struct psmx2_trigger *next; /* used for randomly accessed trigger list */ struct slist_entry list_entry; /* used for ready-to-fire trigger queue */ diff --git a/prov/psm2/src/psmx2_atomic.c b/prov/psm2/src/psmx2_atomic.c index 2921db7dce5..642472d2c08 100644 --- a/prov/psm2/src/psmx2_atomic.c +++ b/prov/psm2/src/psmx2_atomic.c @@ -61,6 +61,50 @@ void psmx2_atomic_fini(void) fastlock_destroy(&psmx2_atomic_lock); } +static inline void psmx2_ioc_read(const struct fi_ioc *ioc, size_t count, + int datatype, uint8_t *buf, size_t len) +{ + int i; + size_t copy_len; + + for (i=0; i len) + copy_len = len; + memcpy(buf, ioc[i].addr, copy_len); + buf += copy_len; + len -= copy_len; + } +} + +static inline void psmx2_ioc_write(struct fi_ioc *ioc, size_t count, + int datatype, const uint8_t *buf, size_t len) +{ + int i; + size_t copy_len; + + for (i=0; i len) + copy_len = len; + memcpy(ioc[i].addr, buf, copy_len); + buf += copy_len; + len -= copy_len; + } +} + +static inline size_t psmx2_ioc_size(const struct fi_ioc *ioc, size_t count, + int datatype) +{ + int i; + size_t len = 0; + + for (i=0; iatomic.len == len); - if (!op_error) - memcpy(req->atomic.result, src, len); + if (!op_error) { + if (req->atomic.result) + memcpy(req->atomic.result, src, len); + else + psmx2_ioc_write(req->ioc, req->atomic.iov_count, + req->atomic.datatype, src, len); + } if (req->ep->send_cq && !req->no_event) { event = psmx2_cq_create_event( @@ -851,6 +900,7 @@ ssize_t psmx2_atomic_write_generic(struct fid_ep *ep, req->atomic.addr = addr; req->atomic.key = key; req->atomic.context = context; + req->atomic.datatype = datatype; req->ep = ep_priv; req->cq_flags = FI_WRITE | FI_ATOMIC; @@ -869,6 +919,154 @@ ssize_t psmx2_atomic_write_generic(struct fid_ep *ep, return 0; } +ssize_t psmx2_atomic_writev_generic(struct fid_ep *ep, + const struct fi_ioc *iov, + void **desc, size_t count, + fi_addr_t dest_addr, + uint64_t addr, uint64_t key, + enum fi_datatype datatype, + enum fi_op op, void *context, + uint64_t flags) +{ + struct psmx2_fid_ep *ep_priv; + struct psmx2_fid_av *av; + struct psmx2_epaddr_context *epaddr_context; + struct psmx2_am_request *req; + psm2_amarg_t args[8]; + psm2_epaddr_t psm2_epaddr; + uint8_t vlane; + int am_flags = PSM2_AM_FLAG_ASYNC; + int chunk_size; + size_t idx; + size_t len; + uint8_t *buf; + int err; + + ep_priv = container_of(ep, struct psmx2_fid_ep, ep); + + if (flags & FI_TRIGGER) { + struct psmx2_trigger *trigger; + struct fi_triggered_context *ctxt = context; + + trigger = calloc(1, sizeof(*trigger)); + if (!trigger) + return -FI_ENOMEM; + + trigger->op = PSMX2_TRIGGERED_ATOMIC_WRITEV; + trigger->cntr = container_of(ctxt->trigger.threshold.cntr, + struct psmx2_fid_cntr, cntr); + trigger->threshold = ctxt->trigger.threshold.threshold; + trigger->atomic_writev.ep = ep; + trigger->atomic_writev.iov = iov; + trigger->atomic_writev.count = count; + trigger->atomic_writev.desc = desc; + trigger->atomic_writev.dest_addr = dest_addr; + trigger->atomic_writev.addr = addr; + trigger->atomic_writev.key = key; + trigger->atomic_writev.datatype = datatype; + trigger->atomic_writev.atomic_op = op; + trigger->atomic_writev.context = context; + trigger->atomic_writev.flags = flags & ~FI_TRIGGER; + + psmx2_cntr_add_trigger(trigger->cntr, trigger); + return 0; + } + + if (!iov || !count) + return -FI_EINVAL; + + while (count && !iov[count-1].count) + count--; + + if (datatype < 0 || datatype >= FI_DATATYPE_LAST) + return -FI_EINVAL; + + if (op < 0 || op >= FI_ATOMIC_OP_LAST) + return -FI_EINVAL; + + av = ep_priv->av; + if (av && av->type == FI_AV_TABLE) { + idx = dest_addr; + if (idx >= av->last) + return -FI_EINVAL; + + psm2_epaddr = av->epaddrs[idx]; + vlane = av->vlanes[idx]; + } else { + if (!dest_addr) + return -FI_EINVAL; + + psm2_epaddr = PSMX2_ADDR_TO_EP(dest_addr); + vlane = PSMX2_ADDR_TO_VL(dest_addr); + } + + len = psmx2_ioc_size(iov, count, datatype); + + epaddr_context = psm2_epaddr_getctxt((void *)psm2_epaddr); + if (epaddr_context->epid == ep_priv->domain->psm2_epid) { + buf = malloc(len); + if (!buf) + return -FI_ENOMEM; + + psmx2_ioc_read(iov, count, datatype, buf, len); + + err = psmx2_atomic_self(PSMX2_AM_REQ_ATOMIC_WRITE, ep_priv, + ep_priv->domain->eps[vlane], + buf, count, NULL, NULL, NULL, NULL, + NULL, addr, key, datatype, op, + context, flags); + + free(buf); + return err; + } + + chunk_size = psmx2_am_param.max_request_short; + if (len > chunk_size) + return -FI_EMSGSIZE; + + if (count > 1) { + req = malloc(sizeof(*req) + len); + if (!req) + return -FI_ENOMEM; + + buf = (uint8_t *)req + sizeof(*req); + memset(req, 0, sizeof(*req)); + psmx2_ioc_read(iov, count, datatype, buf, len); + } else { + req = calloc(1, sizeof(*req)); + if (!req) + return -FI_ENOMEM; + buf = iov[0].addr; + } + + req->no_event = (flags & PSMX2_NO_COMPLETION) || + (ep_priv->send_selective_completion && !(flags & FI_COMPLETION)); + + req->op = PSMX2_AM_REQ_ATOMIC_WRITE; + req->atomic.buf = (void *)buf; + req->atomic.len = len; + req->atomic.addr = addr; + req->atomic.key = key; + req->atomic.context = context; + req->atomic.datatype = datatype; + req->ep = ep_priv; + req->cq_flags = FI_WRITE | FI_ATOMIC; + + args[0].u32w0 = PSMX2_AM_REQ_ATOMIC_WRITE; + PSMX2_AM_SET_DST(args[0].u32w0, vlane); + args[0].u32w1 = len / fi_datatype_size(datatype); + args[1].u64 = (uint64_t)(uintptr_t)req; + args[2].u64 = addr; + args[3].u64 = key; + args[4].u32w0 = datatype; + args[4].u32w1 = op; + psm2_am_request_short(psm2_epaddr, + PSMX2_AM_ATOMIC_HANDLER, args, 5, + (void *)buf, len, am_flags, NULL, NULL); + + return 0; +} + static ssize_t psmx2_atomic_write(struct fid_ep *ep, const void *buf, size_t count, void *desc, @@ -889,9 +1087,18 @@ static ssize_t psmx2_atomic_writemsg(struct fid_ep *ep, const struct fi_msg_atomic *msg, uint64_t flags) { - if (!msg || msg->iov_count != 1 || !msg->msg_iov || !msg->rma_iov) + if (!msg || !msg->iov_count || !msg->msg_iov || !msg->rma_iov || + msg->rma_iov_count != 1) return -FI_EINVAL; + if (msg->iov_count > 1) + return psmx2_atomic_writev_generic(ep, msg->msg_iov, msg->desc, + msg->iov_count, msg->addr, + msg->rma_iov[0].addr, + msg->rma_iov[0].key, + msg->datatype, msg->op, + msg->context, flags); + return psmx2_atomic_write_generic(ep, msg->msg_iov[0].addr, msg->msg_iov[0].count, msg->desc ? msg->desc[0] : NULL, @@ -908,12 +1115,23 @@ static ssize_t psmx2_atomic_writev(struct fid_ep *ep, enum fi_datatype datatype, enum fi_op op, void *context) { - if (!iov || count != 1) + struct psmx2_fid_ep *ep_priv; + + ep_priv = container_of(ep, struct psmx2_fid_ep, ep); + + if (!iov || !count) return -FI_EINVAL; - return psmx2_atomic_write(ep, iov->addr, iov->count, - desc ? desc[0] : NULL, dest_addr, addr, - key, datatype, op, context); + if (count > 1) + return psmx2_atomic_writev_generic(ep, iov, desc, count, + dest_addr, addr, key, + datatype, op, context, + ep_priv->tx_flags); + + return psmx2_atomic_write_generic(ep, iov->addr, iov->count, + desc ? desc[0] : NULL, dest_addr, + addr, key, datatype, op, context, + ep_priv->tx_flags); } static ssize_t psmx2_atomic_inject(struct fid_ep *ep, @@ -1047,6 +1265,7 @@ ssize_t psmx2_atomic_readwrite_generic(struct fid_ep *ep, req->atomic.key = key; req->atomic.context = context; req->atomic.result = result; + req->atomic.datatype = datatype; req->ep = ep_priv; if (op == FI_ATOMIC_READ) req->cq_flags = FI_READ | FI_ATOMIC; @@ -1068,6 +1287,210 @@ ssize_t psmx2_atomic_readwrite_generic(struct fid_ep *ep, return 0; } +ssize_t psmx2_atomic_readwritev_generic(struct fid_ep *ep, + const struct fi_ioc *iov, + void **desc, size_t count, + struct fi_ioc *resultv, + void **result_desc, + size_t result_count, + fi_addr_t dest_addr, + uint64_t addr, uint64_t key, + enum fi_datatype datatype, + enum fi_op op, void *context, + uint64_t flags) +{ + struct psmx2_fid_ep *ep_priv; + struct psmx2_fid_av *av; + struct psmx2_epaddr_context *epaddr_context; + struct psmx2_am_request *req; + psm2_amarg_t args[8]; + psm2_epaddr_t psm2_epaddr; + uint8_t vlane; + int am_flags = PSM2_AM_FLAG_ASYNC; + int chunk_size; + size_t idx; + size_t len, result_len, iov_size; + uint8_t *buf, *result; + void *desc0, *result_desc0; + int err; + + ep_priv = container_of(ep, struct psmx2_fid_ep, ep); + + if (flags & FI_TRIGGER) { + struct psmx2_trigger *trigger; + struct fi_triggered_context *ctxt = context; + + trigger = calloc(1, sizeof(*trigger)); + if (!trigger) + return -FI_ENOMEM; + + trigger->op = PSMX2_TRIGGERED_ATOMIC_READWRITEV; + trigger->cntr = container_of(ctxt->trigger.threshold.cntr, + struct psmx2_fid_cntr, cntr); + trigger->threshold = ctxt->trigger.threshold.threshold; + trigger->atomic_readwritev.ep = ep; + trigger->atomic_readwritev.iov = iov; + trigger->atomic_readwritev.count = count; + trigger->atomic_readwritev.desc = desc; + trigger->atomic_readwritev.resultv = resultv; + trigger->atomic_readwritev.result_desc = result_desc; + trigger->atomic_readwritev.result_count = result_count; + trigger->atomic_readwritev.dest_addr = dest_addr; + trigger->atomic_readwritev.addr = addr; + trigger->atomic_readwritev.key = key; + trigger->atomic_readwritev.datatype = datatype; + trigger->atomic_readwritev.atomic_op = op; + trigger->atomic_readwritev.context = context; + trigger->atomic_readwritev.flags = flags & ~FI_TRIGGER; + + psmx2_cntr_add_trigger(trigger->cntr, trigger); + return 0; + } + + if (((!iov || !count) && op != FI_ATOMIC_READ) || !resultv || + !result_count) + return -FI_EINVAL; + + while (count && !iov[count-1].count) + count--; + + while (result_count && !resultv[result_count-1].count) + result_count--; + + if (datatype < 0 || datatype >= FI_DATATYPE_LAST) + return -FI_EINVAL; + + if (op < 0 || op >= FI_ATOMIC_OP_LAST) + return -FI_EINVAL; + + result_len = psmx2_ioc_size(resultv, result_count, datatype); + + if (op != FI_ATOMIC_READ) { + buf = iov[0].addr; /* as default for count == 1 */ + len = psmx2_ioc_size(iov, count, datatype); + desc0 = desc ? desc[0] : NULL; + } else { + buf = NULL; + len = result_len; + desc0 = NULL; + } + + if (result_len < len) + return -FI_EINVAL; + + av = ep_priv->av; + if (av && av->type == FI_AV_TABLE) { + idx = dest_addr; + if (idx >= av->last) + return -FI_EINVAL; + + psm2_epaddr = av->epaddrs[idx]; + vlane = av->vlanes[idx]; + } else { + if (!dest_addr) + return -FI_EINVAL; + + psm2_epaddr = PSMX2_ADDR_TO_EP(dest_addr); + vlane = PSMX2_ADDR_TO_VL(dest_addr); + } + + epaddr_context = psm2_epaddr_getctxt((void *)psm2_epaddr); + if (epaddr_context->epid == ep_priv->domain->psm2_epid) { + if (buf && count > 1) { + buf = malloc(len); + psmx2_ioc_read(iov, count, datatype, buf, len); + desc0 = NULL; + } + + if (result_count > 1) { + result = malloc(len); + if (!result) { + if (buf && count > 1) + free(buf); + return -FI_ENOMEM; + } + result_desc0 = result_desc ? result_desc[0] : NULL; + } else { + result = resultv[0].addr; + result_desc0 = NULL; + } + + err = psmx2_atomic_self(PSMX2_AM_REQ_ATOMIC_READWRITE, + ep_priv, ep_priv->domain->eps[vlane], + buf, count, desc0, NULL, NULL, result, + result_desc0, addr, key, datatype, op, + context, flags); + + if (result_count > 1) { + psmx2_ioc_write(resultv, result_count, datatype, result, len); + free(result); + } + + if (buf && count > 1) + free(buf); + + return err; + } + + chunk_size = psmx2_am_param.max_request_short; + if (len > chunk_size) + return -FI_EMSGSIZE; + + iov_size = result_count > 1 ? result_count * sizeof(struct fi_ioc) : 0; + + if (((flags & FI_INJECT) || count > 1) && op != FI_ATOMIC_READ) { + req = malloc(sizeof(*req) + iov_size + len); + if (!req) + return -FI_ENOMEM; + buf = (uint8_t *)req + sizeof(*req) + iov_size; + memset(req, 0, sizeof(*req)); + psmx2_ioc_read(iov, count, datatype, buf, len); + } else { + req = calloc(1, sizeof(*req) + iov_size); + if (!req) + return -FI_ENOMEM; + } + + if (iov_size) { + memcpy(req->ioc, resultv, iov_size); + req->atomic.iov_count = result_count; + req->atomic.result = NULL; + } else { + req->atomic.buf = buf; + req->atomic.result = resultv[0].addr; + } + + req->no_event = (flags & PSMX2_NO_COMPLETION) || + (ep_priv->send_selective_completion && !(flags & FI_COMPLETION)); + + req->op = PSMX2_AM_REQ_ATOMIC_READWRITE; + req->atomic.buf = (void *)buf; + req->atomic.len = len; + req->atomic.addr = addr; + req->atomic.key = key; + req->atomic.context = context; + req->atomic.datatype = datatype; + req->ep = ep_priv; + if (op == FI_ATOMIC_READ) + req->cq_flags = FI_READ | FI_ATOMIC; + else + req->cq_flags = FI_WRITE | FI_ATOMIC; + + args[0].u32w0 = PSMX2_AM_REQ_ATOMIC_READWRITE; + PSMX2_AM_SET_DST(args[0].u32w0, vlane); + args[0].u32w1 = len / fi_datatype_size(datatype); + args[1].u64 = (uint64_t)(uintptr_t)req; + args[2].u64 = addr; + args[3].u64 = key; + args[4].u32w0 = datatype; + args[4].u32w1 = op; + psm2_am_request_short(psm2_epaddr, + PSMX2_AM_ATOMIC_HANDLER, args, 5, + (void *)buf, (buf?len:0), am_flags, NULL, NULL); + + return 0; +} + static ssize_t psmx2_atomic_readwrite(struct fid_ep *ep, const void *buf, size_t count, void *desc, @@ -1095,27 +1518,37 @@ static ssize_t psmx2_atomic_readwritemsg(struct fid_ep *ep, { void *buf; size_t count; + void *desc; - if (!msg || !msg->rma_iov) + if (!msg || !msg->rma_iov || msg->rma_iov_count !=1 || !resultv || + !result_count) + return -FI_EINVAL; + + if ((msg->op != FI_ATOMIC_READ) && (!msg->msg_iov || !msg->iov_count)) return -FI_EINVAL; - if (msg->op == FI_ATOMIC_READ) { - if (result_count != 1 || !resultv) - return -FI_EINVAL; + if ((msg->op != FI_ATOMIC_READ && msg->iov_count > 1) || + result_count > 1) + return psmx2_atomic_readwritev_generic(ep, msg->msg_iov, msg->desc, + msg->iov_count, resultv, + result_desc, result_count, + msg->addr, + msg->rma_iov[0].addr, + msg->rma_iov[0].key, + msg->datatype, msg->op, + msg->context, flags); + if (msg->op == FI_ATOMIC_READ) { buf = NULL; count = resultv[0].count; + desc = result_desc ? result_desc[0] : NULL; } else { - if (msg->iov_count != 1 || !msg->msg_iov) - return -FI_EINVAL; - buf = msg->msg_iov[0].addr; count = msg->msg_iov[0].count; + desc = msg->desc ? msg->desc[0] : NULL; } - return psmx2_atomic_readwrite_generic(ep, buf, count, - msg->desc ? msg->desc[0] : NULL, - resultv[0].addr, + return psmx2_atomic_readwrite_generic(ep, buf, count, desc, resultv[0].addr, result_desc ? result_desc[0] : NULL, msg->addr, msg->rma_iov[0].addr, msg->rma_iov[0].key, msg->datatype, @@ -1132,14 +1565,38 @@ static ssize_t psmx2_atomic_readwritev(struct fid_ep *ep, enum fi_datatype datatype, enum fi_op op, void *context) { - if (!iov || count != 1 || !resultv) + struct psmx2_fid_ep *ep_priv; + void *buf; + void *src_desc; + + ep_priv = container_of(ep, struct psmx2_fid_ep, ep); + + if (!resultv || !result_count) return -FI_EINVAL; - return psmx2_atomic_readwrite(ep, iov->addr, iov->count, - desc ? desc[0] : NULL, - resultv[0].addr, - result_desc ? result_desc[0] : NULL, - dest_addr, addr, key, datatype, op, context); + if ((op != FI_ATOMIC_READ) && (!iov || !count)) + return -FI_EINVAL; + + if ((op != FI_ATOMIC_READ && count > 1) || result_count > 1) + return psmx2_atomic_readwritev_generic(ep, iov, desc, count, + resultv, result_desc, result_count, + dest_addr, addr, key, datatype, op, + context, ep_priv->tx_flags); + + if (FI_ATOMIC_READ) { + buf = NULL; + count = resultv[0].count; + src_desc = result_desc ? result_desc[0] : NULL; + } else { + buf = iov[0].addr; + count = iov[0].count; + src_desc = desc ? desc[0] : NULL; + } + + return psmx2_atomic_readwrite_generic(ep, buf, count, src_desc, resultv[0].addr, + result_desc ? result_desc[0] : NULL, + dest_addr, addr, key, datatype, op, + context, ep_priv->tx_flags); } ssize_t psmx2_atomic_compwrite_generic(struct fid_ep *ep, @@ -1274,6 +1731,7 @@ ssize_t psmx2_atomic_compwrite_generic(struct fid_ep *ep, req->atomic.key = key; req->atomic.context = context; req->atomic.result = result; + req->atomic.datatype = datatype; req->ep = ep_priv; req->cq_flags = FI_WRITE | FI_ATOMIC; @@ -1294,6 +1752,234 @@ ssize_t psmx2_atomic_compwrite_generic(struct fid_ep *ep, return 0; } +ssize_t psmx2_atomic_compwritev_generic(struct fid_ep *ep, + const struct fi_ioc *iov, + void **desc, size_t count, + const struct fi_ioc *comparev, + void **compare_desc, + size_t compare_count, + struct fi_ioc *resultv, + void **result_desc, + size_t result_count, + fi_addr_t dest_addr, + uint64_t addr, uint64_t key, + enum fi_datatype datatype, + enum fi_op op, void *context, + uint64_t flags) +{ + struct psmx2_fid_ep *ep_priv; + struct psmx2_fid_av *av; + struct psmx2_epaddr_context *epaddr_context; + struct psmx2_am_request *req; + psm2_amarg_t args[8]; + psm2_epaddr_t psm2_epaddr; + uint8_t vlane; + int am_flags = PSM2_AM_FLAG_ASYNC; + int chunk_size; + size_t idx; + size_t len, compare_len, result_len, iov_size; + uint8_t *buf, *compare, *result; + void *desc0, *compare_desc0, *result_desc0; + int err; + + ep_priv = container_of(ep, struct psmx2_fid_ep, ep); + + if (flags & FI_TRIGGER) { + struct psmx2_trigger *trigger; + struct fi_triggered_context *ctxt = context; + + trigger = calloc(1, sizeof(*trigger)); + if (!trigger) + return -FI_ENOMEM; + + trigger->op = PSMX2_TRIGGERED_ATOMIC_COMPWRITEV; + trigger->cntr = container_of(ctxt->trigger.threshold.cntr, + struct psmx2_fid_cntr, cntr); + trigger->threshold = ctxt->trigger.threshold.threshold; + trigger->atomic_compwritev.ep = ep; + trigger->atomic_compwritev.iov = iov; + trigger->atomic_compwritev.desc = desc; + trigger->atomic_compwritev.count = count; + trigger->atomic_compwritev.comparev = comparev; + trigger->atomic_compwritev.compare_desc = compare_desc; + trigger->atomic_compwritev.compare_count = compare_count; + trigger->atomic_compwritev.resultv = resultv; + trigger->atomic_compwritev.result_desc = result_desc; + trigger->atomic_compwritev.result_count = result_count; + trigger->atomic_compwritev.dest_addr = dest_addr; + trigger->atomic_compwritev.addr = addr; + trigger->atomic_compwritev.key = key; + trigger->atomic_compwritev.datatype = datatype; + trigger->atomic_compwritev.atomic_op = op; + trigger->atomic_compwritev.context = context; + trigger->atomic_compwritev.flags = flags & ~FI_TRIGGER; + + psmx2_cntr_add_trigger(trigger->cntr, trigger); + return 0; + } + + if (!iov || !count || !comparev || !compare_count || !resultv || + !result_count) + return -FI_EINVAL; + + while (count && !iov[count-1].count) + count--; + + while (compare_count && !comparev[compare_count-1].count) + compare_count--; + + while (result_count && !resultv[result_count-1].count) + result_count--; + + if (datatype < 0 || datatype >= FI_DATATYPE_LAST) + return -FI_EINVAL; + + if (op < 0 || op >= FI_ATOMIC_OP_LAST) + return -FI_EINVAL; + + len = psmx2_ioc_size(iov, count, datatype); + compare_len = psmx2_ioc_size(comparev, compare_count, datatype); + result_len = psmx2_ioc_size(resultv, result_count, datatype); + + if (compare_len < len || result_len < len) + return -FI_EINVAL; + + av = ep_priv->av; + if (av && av->type == FI_AV_TABLE) { + idx = dest_addr; + if (idx >= av->last) + return -FI_EINVAL; + + psm2_epaddr = av->epaddrs[idx]; + vlane = av->vlanes[idx]; + } else { + if (!dest_addr) + return -FI_EINVAL; + + psm2_epaddr = PSMX2_ADDR_TO_EP(dest_addr); + vlane = PSMX2_ADDR_TO_VL(dest_addr); + } + + epaddr_context = psm2_epaddr_getctxt((void *)psm2_epaddr); + if (epaddr_context->epid == ep_priv->domain->psm2_epid) { + if (count > 1) { + buf = malloc(len); + if (!buf) + return -FI_ENOMEM; + psmx2_ioc_read(iov, count, datatype, buf, len); + desc0 = NULL; + } else { + buf = iov[0].addr; + desc0 = desc ? desc[0] : NULL; + } + + if (compare_count > 1) { + compare = malloc(len); + if (!compare) { + if (count > 1) + free(buf); + return -FI_ENOMEM; + } + psmx2_ioc_read(comparev, compare_count, datatype, compare, len); + compare_desc0 = NULL; + } else { + compare = comparev[0].addr; + compare_desc0 = compare_desc ? compare_desc[0] : NULL; + } + + if (result_count > 1) { + result = malloc(len); + if (!result) { + if (compare_count > 1) + free(compare); + if (count > 1) + free(buf); + return -FI_ENOMEM; + } + result_desc0 = NULL; + } else { + result = resultv[0].addr; + result_desc0 = result_desc ? result_desc[0] : NULL; + } + + err = psmx2_atomic_self(PSMX2_AM_REQ_ATOMIC_COMPWRITE, + ep_priv, ep_priv->domain->eps[vlane], + buf, count, desc0, compare, compare_desc0, + result, result_desc0, addr, key, datatype, op, + context, flags); + + if (result_count > 1) { + psmx2_ioc_write(resultv, result_count, datatype, result, len); + free(result); + } + + if (compare_count > 1) + free(compare); + + if (count > 1) + free(buf); + + return err; + } + + chunk_size = psmx2_am_param.max_request_short; + if (len * 2 > chunk_size) + return -FI_EMSGSIZE; + + iov_size = result_count > 1 ? result_count * sizeof(struct fi_ioc) : 0; + + if ((flags & FI_INJECT) || count > 1 || compare_count > 1 || + (uintptr_t)comparev[0].addr != (uintptr_t)iov[0].addr + len) { + req = malloc(sizeof(*req) + iov_size + len + len); + if (!req) + return -FI_ENOMEM; + buf = (uint8_t *)req + sizeof(*req) + iov_size; + memset(req, 0, sizeof(*req)); + psmx2_ioc_read(iov, count, datatype, buf, len); + psmx2_ioc_read(comparev, compare_count, datatype, buf + len, len); + } else { + req = calloc(1, sizeof(*req) + iov_size); + if (!req) + return -FI_ENOMEM; + buf = iov[0].addr; + } + + if (iov_size) { + memcpy(req->ioc, resultv, iov_size); + req->atomic.iov_count = result_count; + req->atomic.result = NULL; + } else { + req->atomic.buf = buf; + req->atomic.result = resultv[0].addr; + } + + req->no_event = (flags & PSMX2_NO_COMPLETION) || + (ep_priv->send_selective_completion && !(flags & FI_COMPLETION)); + + req->op = PSMX2_AM_REQ_ATOMIC_COMPWRITE; + req->atomic.len = len; + req->atomic.addr = addr; + req->atomic.key = key; + req->atomic.context = context; + req->atomic.datatype = datatype; + req->ep = ep_priv; + req->cq_flags = FI_WRITE | FI_ATOMIC; + + args[0].u32w0 = PSMX2_AM_REQ_ATOMIC_COMPWRITE; + PSMX2_AM_SET_DST(args[0].u32w0, vlane); + args[0].u32w1 = len / fi_datatype_size(datatype); + args[1].u64 = (uint64_t)(uintptr_t)req; + args[2].u64 = addr; + args[3].u64 = key; + args[4].u32w0 = datatype; + args[4].u32w1 = op; + psm2_am_request_short(psm2_epaddr, + PSMX2_AM_ATOMIC_HANDLER, args, 5, + buf, len * 2, am_flags, NULL, NULL); + + return 0; +} + static ssize_t psmx2_atomic_compwrite(struct fid_ep *ep, const void *buf, size_t count, void *desc, @@ -1324,9 +2010,20 @@ static ssize_t psmx2_atomic_compwritemsg(struct fid_ep *ep, size_t result_count, uint64_t flags) { - if (!msg || msg->iov_count != 1 || !msg->msg_iov || !msg->rma_iov || !resultv) + if (!msg || !msg->msg_iov || msg->iov_count != 1 || + !msg->rma_iov || msg->rma_iov_count != 1 || + !comparev || !compare_count || !resultv || !result_count) return -FI_EINVAL; + if (msg->iov_count > 1 || compare_count > 1 || result_count > 1) + return psmx2_atomic_compwritev_generic(ep, msg->msg_iov, msg->desc, + msg->iov_count, comparev, + compare_desc, compare_count, + resultv, result_desc, result_count, + msg->addr, msg->rma_iov[0].addr, + msg->rma_iov[0].key, msg->datatype, + msg->op, msg->context, flags); + return psmx2_atomic_compwrite_generic(ep, msg->msg_iov[0].addr, msg->msg_iov[0].count, msg->desc ? msg->desc[0] : NULL, @@ -1353,16 +2050,30 @@ static ssize_t psmx2_atomic_compwritev(struct fid_ep *ep, enum fi_datatype datatype, enum fi_op op, void *context) { - if (!iov || count != 1 || !comparev || !resultv) + struct psmx2_fid_ep *ep_priv; + + ep_priv = container_of(ep, struct psmx2_fid_ep, ep); + + if (!iov || !count || !comparev || !compare_count || !resultv || !result_count) return -FI_EINVAL; - return psmx2_atomic_compwrite(ep, iov->addr, iov->count, - desc ? desc[0] : NULL, - comparev[0].addr, - compare_desc ? compare_desc[0] : NULL, - resultv[0].addr, - result_desc ? result_desc[0] : NULL, - dest_addr, addr, key, datatype, op, context); + if (count > 1 || compare_count > 1 || result_count > 1) + return psmx2_atomic_compwritev_generic(ep, iov, desc, count, + comparev, compare_desc, + compare_count, resultv, + result_desc, result_count, + dest_addr, addr, key, + datatype, op, context, + ep_priv->tx_flags); + + return psmx2_atomic_compwrite_generic(ep, iov->addr, iov->count, + desc ? desc[0] : NULL, + comparev[0].addr, + compare_desc ? compare_desc[0] : NULL, + resultv[0].addr, + result_desc ? result_desc[0] : NULL, + dest_addr, addr, key, datatype, op, + context, ep_priv->tx_flags); } static int psmx2_atomic_writevalid(struct fid_ep *ep,