diff --git a/config/ompi_check_ucx.m4 b/config/ompi_check_ucx.m4 index 6ad07905b6..3da11a2af3 100644 --- a/config/ompi_check_ucx.m4 +++ b/config/ompi_check_ucx.m4 @@ -109,7 +109,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ [have ucp_tag_send_nbr()])], [], [#include ]) AC_CHECK_DECLS([ucp_ep_flush_nb, ucp_worker_flush_nb, - ucp_request_check_status, ucp_put_nb, ucp_get_nb], + ucp_request_check_status, ucp_put_nb, ucp_get_nb, + ucp_put_nbx, ucp_get_nbx, ucp_atomic_op_nbx], [], [], [#include ]) AC_CHECK_DECLS([ucm_test_events, diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c index b3cddcd6d2..45b0ce0069 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c @@ -31,6 +31,14 @@ int mca_atomic_ucx_cswap(shmem_ctx_t ctx, spml_ucx_mkey_t *ucx_mkey; uint64_t rva; mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + ucp_request_param_t param = { + .op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_REPLY_BUFFER, + .datatype = ucp_dt_make_contig(size), + .reply_buffer = prev + }; +#endif if ((8 != size) && (4 != size)) { ATOMIC_ERROR("[#%d] Type size must be 4 or 8 bytes.", my_pe); @@ -41,15 +49,25 @@ int mca_atomic_ucx_cswap(shmem_ctx_t ctx, *prev = value; ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self); +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + status_ptr = ucp_atomic_op_nbx(ucx_ctx->ucp_peers[pe].ucp_conn, + UCP_ATOMIC_OP_CSWAP, &cond, 1, rva, + ucx_mkey->rkey, ¶m); +#else status_ptr = ucp_atomic_fetch_nb(ucx_ctx->ucp_peers[pe].ucp_conn, UCP_ATOMIC_FETCH_OP_CSWAP, cond, prev, size, rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); +#endif if (OPAL_LIKELY(!UCS_PTR_IS_ERR(status_ptr))) { mca_spml_ucx_remote_op_posted(ucx_ctx, pe); } return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0], +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + "ucp_atomic_op_nbx"); +#else "ucp_atomic_fetch_nb"); +#endif } diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_module.c b/oshmem/mca/atomic/ucx/atomic_ucx_module.c index 34ed0b551b..8a9a4a0631 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_module.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_module.c @@ -18,6 +18,17 @@ #include "oshmem/proc/proc.h" #include "atomic_ucx.h" +#if HAVE_DECL_UCP_ATOMIC_OP_NBX +/* + * A static params array, for datatypes of size 4 and 8. "size >> 3" is used to + * access the corresponding offset. + */ +static ucp_request_param_t mca_spml_ucp_request_params[] = { + {.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE, .datatype = ucp_dt_make_contig(4)}, + {.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE, .datatype = ucp_dt_make_contig(8)} +}; +#endif + /* * Initial query function that is invoked during initialization, allowing * this module to indicate what level of thread support it provides. @@ -38,20 +49,37 @@ int mca_atomic_ucx_op(shmem_ctx_t ctx, uint64_t value, size_t size, int pe, +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + ucp_atomic_op_t op) +#else ucp_atomic_post_op_t op) +#endif { ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; uint64_t rva; mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + ucs_status_ptr_t status_ptr; +#endif assert((8 == size) || (4 == size)); ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self); + +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + status_ptr = ucp_atomic_op_nbx(ucx_ctx->ucp_peers[pe].ucp_conn, + op, &value, 1, rva, ucx_mkey->rkey, + &mca_spml_ucp_request_params[size >> 3]); + if (OPAL_LIKELY(!UCS_PTR_IS_ERR(status_ptr))) { + mca_spml_ucx_remote_op_posted(ucx_ctx, pe); + } + status = UCS_PTR_STATUS(status_ptr); +#else status = ucp_atomic_post(ucx_ctx->ucp_peers[pe].ucp_conn, op, value, size, rva, ucx_mkey->rkey); - +#endif if (OPAL_LIKELY(UCS_OK == status)) { mca_spml_ucx_remote_op_posted(ucx_ctx, pe); } @@ -66,22 +94,41 @@ int mca_atomic_ucx_fop(shmem_ctx_t ctx, uint64_t value, size_t size, int pe, +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + ucp_atomic_op_t op) +#else ucp_atomic_fetch_op_t op) +#endif { ucs_status_ptr_t status_ptr; spml_ucx_mkey_t *ucx_mkey; uint64_t rva; mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + ucp_request_param_t param = { + .op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_REPLY_BUFFER, + .datatype = ucp_dt_make_contig(size), + .reply_buffer = prev + }; +#endif assert((8 == size) || (4 == size)); ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self); +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + status_ptr = ucp_atomic_op_nbx(ucx_ctx->ucp_peers[pe].ucp_conn, op, &value, 1, + rva, ucx_mkey->rkey, ¶m); + return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0], + "ucp_atomic_op_nbx"); +#else status_ptr = ucp_atomic_fetch_nb(ucx_ctx->ucp_peers[pe].ucp_conn, op, value, prev, size, rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0], "ucp_atomic_fetch_nb"); +#endif } static int mca_atomic_ucx_add(shmem_ctx_t ctx, @@ -90,7 +137,11 @@ static int mca_atomic_ucx_add(shmem_ctx_t ctx, size_t size, int pe) { +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_OP_ADD); +#else return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_POST_OP_ADD); +#endif } static int mca_atomic_ucx_and(shmem_ctx_t ctx, @@ -99,7 +150,9 @@ static int mca_atomic_ucx_and(shmem_ctx_t ctx, size_t size, int pe) { -#if HAVE_DECL_UCP_ATOMIC_POST_OP_AND +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_OP_AND); +#elif HAVE_DECL_UCP_ATOMIC_POST_OP_AND return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_POST_OP_AND); #else return OSHMEM_ERR_NOT_IMPLEMENTED; @@ -112,7 +165,9 @@ static int mca_atomic_ucx_or(shmem_ctx_t ctx, size_t size, int pe) { -#if HAVE_DECL_UCP_ATOMIC_POST_OP_OR +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_OP_OR); +#elif HAVE_DECL_UCP_ATOMIC_POST_OP_OR return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_POST_OP_OR); #else return OSHMEM_ERR_NOT_IMPLEMENTED; @@ -125,7 +180,9 @@ static int mca_atomic_ucx_xor(shmem_ctx_t ctx, size_t size, int pe) { -#if HAVE_DECL_UCP_ATOMIC_POST_OP_XOR +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_OP_XOR); +#elif HAVE_DECL_UCP_ATOMIC_POST_OP_XOR return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_POST_OP_XOR); #else return OSHMEM_ERR_NOT_IMPLEMENTED; @@ -139,7 +196,11 @@ static int mca_atomic_ucx_fadd(shmem_ctx_t ctx, size_t size, int pe) { +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_OP_ADD); +#else return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_FETCH_OP_FADD); +#endif } static int mca_atomic_ucx_fand(shmem_ctx_t ctx, @@ -149,7 +210,9 @@ static int mca_atomic_ucx_fand(shmem_ctx_t ctx, size_t size, int pe) { -#if HAVE_DECL_UCP_ATOMIC_FETCH_OP_FAND +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_OP_AND); +#elif HAVE_DECL_UCP_ATOMIC_FETCH_OP_FAND return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_FETCH_OP_FAND); #else return OSHMEM_ERR_NOT_IMPLEMENTED; @@ -163,7 +226,9 @@ static int mca_atomic_ucx_for(shmem_ctx_t ctx, size_t size, int pe) { -#if HAVE_DECL_UCP_ATOMIC_FETCH_OP_FOR +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_OP_OR); +#elif HAVE_DECL_UCP_ATOMIC_FETCH_OP_FOR return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_FETCH_OP_FOR); #else return OSHMEM_ERR_NOT_IMPLEMENTED; @@ -177,7 +242,9 @@ static int mca_atomic_ucx_fxor(shmem_ctx_t ctx, size_t size, int pe) { -#if HAVE_DECL_UCP_ATOMIC_FETCH_OP_FXOR +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_OP_XOR); +#elif HAVE_DECL_UCP_ATOMIC_FETCH_OP_FXOR return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_FETCH_OP_FXOR); #else return OSHMEM_ERR_NOT_IMPLEMENTED; @@ -191,7 +258,11 @@ static int mca_atomic_ucx_swap(shmem_ctx_t ctx, size_t size, int pe) { +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_OP_SWAP); +#else return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_FETCH_OP_SWAP); +#endif } diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 60453e9243..ed4c1c6324 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -90,6 +90,10 @@ mca_spml_ucx_ctx_t mca_spml_ucx_ctx_default = { .options = 0 }; +#if HAVE_DECL_UCP_ATOMIC_OP_NBX +static ucp_request_param_t mca_spml_ucx_request_param = {0}; +#endif + int mca_spml_ucx_enable(bool enable) { SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****"); @@ -813,16 +817,19 @@ void mca_spml_ucx_ctx_destroy(shmem_ctx_t ctx) int mca_spml_ucx_get(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_addr, int src) { void *rva; - spml_ucx_mkey_t *ucx_mkey; + spml_ucx_mkey_t *ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx); mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; -#if HAVE_DECL_UCP_GET_NB +#if (HAVE_DECL_UCP_GET_NBX || HAVE_DECL_UCP_GET_NB) ucs_status_ptr_t request; #else ucs_status_t status; #endif - ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx); -#if HAVE_DECL_UCP_GET_NB +#if HAVE_DECL_UCP_GET_NBX + request = ucp_get_nbx(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size, + (uint64_t)rva, ucx_mkey->rkey, &mca_spml_ucx_request_param); + return opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker[0], "ucp_get_nbx"); +#elif HAVE_DECL_UCP_GET_NB request = ucp_get_nb(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); return opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker[0], "ucp_get_nb"); @@ -837,13 +844,25 @@ int mca_spml_ucx_get_nb(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_ { void *rva; ucs_status_t status; - spml_ucx_mkey_t *ucx_mkey; + spml_ucx_mkey_t *ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx); mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; +#if HAVE_DECL_UCP_GET_NBX + ucs_status_ptr_t status_ptr; +#endif - ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx); +#if HAVE_DECL_UCP_GET_NBX + status_ptr = ucp_get_nbx(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size, + (uint64_t)rva, ucx_mkey->rkey, &mca_spml_ucx_request_param); + if (UCS_PTR_IS_PTR(status_ptr)) { + ucp_request_free(status_ptr); + status = UCS_INPROGRESS; + } else { + status = UCS_PTR_STATUS(status_ptr); + } +#else status = ucp_get_nbi(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey); - +#endif return ucx_status_to_oshmem_nb(status); } @@ -852,12 +871,26 @@ int mca_spml_ucx_get_nb_wprogress(shmem_ctx_t ctx, void *src_addr, size_t size, unsigned int i; void *rva; ucs_status_t status; - spml_ucx_mkey_t *ucx_mkey; + spml_ucx_mkey_t *ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx); mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; +#if HAVE_DECL_UCP_GET_NBX + ucs_status_ptr_t status_ptr; +#endif - ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx); +#if HAVE_DECL_UCP_GET_NBX + status_ptr = ucp_get_nbx(ucx_ctx->ucp_peers[src].ucp_conn, + dst_addr, size, (uint64_t)rva, + ucx_mkey->rkey, &mca_spml_ucx_request_param); + if (UCS_PTR_IS_PTR(status_ptr)) { + ucp_request_free(status_ptr); + status = UCS_INPROGRESS; + } else { + status = UCS_PTR_STATUS(status_ptr); + } +#else status = ucp_get_nbi(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey); +#endif if (++ucx_ctx->nb_progress_cnt > mca_spml_ucx.nb_get_progress_thresh) { for (i = 0; i < mca_spml_ucx.nb_ucp_worker_progress; i++) { @@ -874,17 +907,20 @@ int mca_spml_ucx_get_nb_wprogress(shmem_ctx_t ctx, void *src_addr, size_t size, int mca_spml_ucx_put(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_addr, int dst) { void *rva; - spml_ucx_mkey_t *ucx_mkey; + spml_ucx_mkey_t *ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx); mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; int res; -#if HAVE_DECL_UCP_PUT_NB +#if (HAVE_DECL_UCP_PUT_NBX || HAVE_DECL_UCP_PUT_NB) ucs_status_ptr_t request; #else ucs_status_t status; #endif - ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx); -#if HAVE_DECL_UCP_PUT_NB +#if HAVE_DECL_UCP_PUT_NBX + request = ucp_put_nbx(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size, + (uint64_t)rva, ucx_mkey->rkey, &mca_spml_ucx_request_param); + res = opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker[0], "ucp_put_nbx"); +#elif HAVE_DECL_UCP_PUT_NB request = ucp_put_nb(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); res = opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker[0], "ucp_put_nb"); @@ -904,14 +940,27 @@ int mca_spml_ucx_put(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_add int mca_spml_ucx_put_nb(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_addr, int dst, void **handle) { void *rva; - ucs_status_t status; - spml_ucx_mkey_t *ucx_mkey; + spml_ucx_mkey_t *ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx); mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; + ucs_status_t status; +#if HAVE_DECL_UCP_PUT_NBX + ucs_status_ptr_t status_ptr; +#endif - ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx); +#if HAVE_DECL_UCP_PUT_NBX + status_ptr = ucp_put_nbx(ucx_ctx->ucp_peers[dst].ucp_conn, + src_addr, size, (uint64_t)rva, + ucx_mkey->rkey, &mca_spml_ucx_request_param); + if (UCS_PTR_IS_PTR(status_ptr)) { + ucp_request_free(status_ptr); + status = UCS_INPROGRESS; + } else { + status = UCS_PTR_STATUS(status_ptr); + } +#else status = ucp_put_nbi(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey); - +#endif if (OPAL_LIKELY(status >= 0)) { mca_spml_ucx_remote_op_posted(ucx_ctx, dst); } @@ -924,13 +973,26 @@ int mca_spml_ucx_put_nb_wprogress(shmem_ctx_t ctx, void* dst_addr, size_t size, unsigned int i; void *rva; ucs_status_t status; - spml_ucx_mkey_t *ucx_mkey; + spml_ucx_mkey_t *ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx); mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; +#if HAVE_DECL_UCP_PUT_NBX + ucs_status_ptr_t status_ptr; +#endif - ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx); +#if HAVE_DECL_UCP_PUT_NBX + status_ptr = ucp_put_nbx(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size, + (uint64_t)rva, ucx_mkey->rkey, + &mca_spml_ucx_request_param); + if (UCS_PTR_IS_PTR(status_ptr)) { + ucp_request_free(status_ptr); + status = UCS_INPROGRESS; + } else { + status = UCS_PTR_STATUS(status_ptr); + } +#else status = ucp_put_nbi(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey); - +#endif if (OPAL_LIKELY(status >= 0)) { mca_spml_ucx_remote_op_posted(ucx_ctx, dst); }