diff --git a/oshmem/mca/atomic/ucx/atomic_ucx.h b/oshmem/mca/atomic/ucx/atomic_ucx.h index 4db6008c1f..3b5ba03b73 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx.h +++ b/oshmem/mca/atomic/ucx/atomic_ucx.h @@ -60,6 +60,35 @@ struct mca_atomic_ucx_module_t { typedef struct mca_atomic_ucx_module_t mca_atomic_ucx_module_t; OBJ_CLASS_DECLARATION(mca_atomic_ucx_module_t); + +void mca_atomic_ucx_complete_cb(void *request, ucs_status_t status); + +static inline +ucs_status_t mca_atomic_ucx_wait_request(ucs_status_ptr_t request) +{ + ucs_status_t status; + int i; + + /* check for request completed or failed */ + if (UCS_OK == request) { + return UCS_OK; + } else if (UCS_PTR_IS_ERR(request)) { + return UCS_PTR_STATUS(request); + } + + while (1) { + /* call UCX progress */ + for (i = 0; i < 100; i++) { + if (UCS_INPROGRESS != (status = ucp_request_check_status(request))) { + ucp_request_free(request); + return status; + } + ucp_worker_progress(mca_spml_self->ucp_worker); + } + /* call OPAL progress on every 100 call to UCX progress */ + opal_progress(); + } +} END_C_DECLS -#endif /* MCA_ATOMIC_MXM_H */ +#endif /* MCA_ATOMIC_UCX_H */ diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c index 7d84f9e3dc..57723cf0ae 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c @@ -19,6 +19,50 @@ #include "atomic_ucx.h" +/* nlong argument should be constant to hint compiler + * to calculate nlong relative branches in compile time */ +static inline +int mca_atomic_ucx_cswap_inner(void *target, + void *prev, + const void *cond, + const void *value, + size_t nlong, + int pe) +{ + ucs_status_t status; + ucs_status_ptr_t status_ptr; + spml_ucx_mkey_t *ucx_mkey; + uint64_t rva; + uint64_t val; + uint64_t cmp; + + val = (4 == nlong) ? *(uint32_t*)value : *(uint64_t*)value; + ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva); + if (NULL == cond) { + status_ptr = ucp_atomic_fetch_nb(mca_spml_self->ucp_peers[pe].ucp_conn, + UCP_ATOMIC_FETCH_OP_SWAP, val, prev, nlong, + rva, ucx_mkey->rkey, mca_atomic_ucx_complete_cb); + status = mca_atomic_ucx_wait_request(status_ptr); + } + else { + cmp = (4 == nlong) ? *(uint32_t*)cond : *(uint64_t*)cond; + status_ptr = ucp_atomic_fetch_nb(mca_spml_self->ucp_peers[pe].ucp_conn, + UCP_ATOMIC_FETCH_OP_CSWAP, cmp, &val, nlong, + rva, ucx_mkey->rkey, mca_atomic_ucx_complete_cb); + status = mca_atomic_ucx_wait_request(status_ptr); + if (UCS_OK == status) { + assert(NULL != prev); + memcpy(prev, &val, nlong); + if (4 == nlong) { + *(uint32_t*)prev = val; + } else { + *(uint64_t*)prev = val; + } + } + } + return ucx_status_to_oshmem(status); +} + int mca_atomic_ucx_cswap(void *target, void *prev, const void *cond, @@ -26,45 +70,12 @@ int mca_atomic_ucx_cswap(void *target, size_t nlong, int pe) { - ucs_status_t status; - spml_ucx_mkey_t *ucx_mkey; - uint64_t rva; - - ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva); - if (NULL == cond) { - switch (nlong) { - case 4: - status = ucp_atomic_swap32(mca_spml_self->ucp_peers[pe].ucp_conn, - *(uint32_t *)value, rva, ucx_mkey->rkey, prev); - break; - case 8: - status = ucp_atomic_swap64(mca_spml_self->ucp_peers[pe].ucp_conn, - *(uint64_t *)value, rva, ucx_mkey->rkey, prev); - break; - default: - goto err_size; - } + if (8 == nlong) { + return mca_atomic_ucx_cswap_inner(target, prev, cond, value, 8, pe); + } else if (4 == nlong) { + return mca_atomic_ucx_cswap_inner(target, prev, cond, value, 4, pe); + } else { + ATOMIC_ERROR("[#%d] Type size must be 4 or 8 bytes.", my_pe); + return OSHMEM_ERROR; } - else { - switch (nlong) { - case 4: - status = ucp_atomic_cswap32(mca_spml_self->ucp_peers[pe].ucp_conn, - *(uint32_t *)cond, *(uint32_t *)value, rva, ucx_mkey->rkey, prev); - break; - case 8: - status = ucp_atomic_cswap64(mca_spml_self->ucp_peers[pe].ucp_conn, - *(uint64_t *)cond, *(uint64_t *)value, rva, ucx_mkey->rkey, prev); - break; - default: - goto err_size; - } - } - - return ucx_status_to_oshmem(status); - -err_size: - ATOMIC_ERROR("[#%d] Type size must be 4 or 8 bytes.", my_pe); - return OSHMEM_ERROR; } - - diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c b/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c index b9ce9dee0d..b8639d2d00 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c @@ -26,43 +26,32 @@ int mca_atomic_ucx_fadd(void *target, struct oshmem_op_t *op) { ucs_status_t status; + ucs_status_ptr_t status_ptr; spml_ucx_mkey_t *ucx_mkey; uint64_t rva; + uint64_t val; + + if (8 == nlong) { + val = *(uint64_t*)value; + } else if (4 == nlong) { + val = *(uint32_t*)value; + } else { + ATOMIC_ERROR("[#%d] Type size must be 4 or 8 bytes.", my_pe); + return OSHMEM_ERROR; + } ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva); - if (NULL == prev) { - switch (nlong) { - case 4: - status = ucp_atomic_add32(mca_spml_self->ucp_peers[pe].ucp_conn, - *(uint32_t *)value, rva, ucx_mkey->rkey); - break; - case 8: - status = ucp_atomic_add64(mca_spml_self->ucp_peers[pe].ucp_conn, - *(uint64_t *)value, rva, ucx_mkey->rkey); - break; - default: - goto err_size; - } + status = ucp_atomic_post(mca_spml_self->ucp_peers[pe].ucp_conn, + UCP_ATOMIC_POST_OP_ADD, val, nlong, rva, + ucx_mkey->rkey); } else { - switch (nlong) { - case 4: - status = ucp_atomic_fadd32(mca_spml_self->ucp_peers[pe].ucp_conn, - *(uint32_t *)value, rva, ucx_mkey->rkey, prev); - break; - case 8: - status = ucp_atomic_fadd64(mca_spml_self->ucp_peers[pe].ucp_conn, - *(uint64_t *)value, rva, ucx_mkey->rkey, prev); - break; - default: - goto err_size; - } + status_ptr = ucp_atomic_fetch_nb(mca_spml_self->ucp_peers[pe].ucp_conn, + UCP_ATOMIC_FETCH_OP_FADD, val, prev, nlong, + rva, ucx_mkey->rkey, mca_atomic_ucx_complete_cb); + status = mca_atomic_ucx_wait_request(status_ptr); } return ucx_status_to_oshmem(status); - -err_size: - ATOMIC_ERROR("[#%d] Type size must be 4 or 8 bytes.", my_pe); - return OSHMEM_ERROR; } diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_module.c b/oshmem/mca/atomic/ucx/atomic_ucx_module.c index 0b570043a6..a59783d186 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_module.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_module.c @@ -49,3 +49,7 @@ mca_atomic_ucx_query(int *priority) return NULL ; } +void mca_atomic_ucx_complete_cb(void *request, ucs_status_t status) +{ +} +