diff --git a/config/ompi_check_ucx.m4 b/config/ompi_check_ucx.m4 index 350514d272..e5a5ccf047 100644 --- a/config/ompi_check_ucx.m4 +++ b/config/ompi_check_ucx.m4 @@ -108,7 +108,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ [AC_DEFINE([HAVE_UCP_TAG_SEND_NBR],[1], [have ucp_tag_send_nbr()])], [], [#include ]) - AC_CHECK_DECLS([ucp_ep_flush_nb, ucp_worker_flush_nb, ucp_request_check_status], + AC_CHECK_DECLS([ucp_ep_flush_nb, ucp_worker_flush_nb, + ucp_request_check_status, ucp_put_nb, ucp_get_nb], [], [], [#include ]) CPPFLAGS=$old_CPPFLAGS diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c index d03a1ad4d2..b6c4fbd8a6 100644 --- a/ompi/mca/pml/ucx/pml_ucx.c +++ b/ompi/mca/pml/ucx/pml_ucx.c @@ -18,6 +18,7 @@ #include "opal/mca/pmix/pmix.h" #include "ompi/message/message.h" #include "ompi/mca/pml/base/pml_base_bsend.h" +#include "opal/mca/common/ucx/common_ucx.h" #include "pml_ucx_request.h" #include @@ -374,29 +375,19 @@ static void mca_pml_ucx_waitall(void **reqs, int *count_p) PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p); for (i = 0; i < *count_p; ++i) { - do { - opal_progress(); - status = ucp_request_test(reqs[i], NULL); - } while (status == UCS_INPROGRESS); + status = opal_common_ucx_wait_request(reqs[i], ompi_pml_ucx.ucp_worker); if (status != UCS_OK) { PML_UCX_ERROR("disconnect request failed: %s", ucs_status_string(status)); } - ucp_request_free(reqs[i]); reqs[i] = NULL; } *count_p = 0; } -static void mca_pml_fence_complete_cb(int status, void *fenced) -{ - *(int*)fenced = 1; -} - int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs) { - volatile int fenced = 0; ompi_proc_t *proc; int num_reqs; size_t max_reqs; @@ -447,10 +438,7 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs) mca_pml_ucx_waitall(dreqs, &num_reqs); free(dreqs); - opal_pmix.fence_nb(NULL, 0, mca_pml_fence_complete_cb, (void*)&fenced); - while (!fenced) { - ucp_worker_progress(ompi_pml_ucx.ucp_worker); - } + opal_common_ucx_mca_pmix_fence(ompi_pml_ucx.ucp_worker); return OMPI_SUCCESS; } diff --git a/opal/mca/common/ucx/common_ucx.c b/opal/mca/common/ucx/common_ucx.c index 85b96a92cd..b8b1e51b9e 100644 --- a/opal/mca/common/ucx/common_ucx.c +++ b/opal/mca/common/ucx/common_ucx.c @@ -11,6 +11,7 @@ #include "common_ucx.h" #include "opal/mca/base/mca_base_var.h" +#include "opal/mca/pmix/pmix.h" /***********************************************************************/ @@ -36,3 +37,19 @@ OPAL_DECLSPEC void opal_common_ucx_mca_register(void) void opal_common_ucx_empty_complete_cb(void *request, ucs_status_t status) { } + +static void opal_common_ucx_mca_fence_complete_cb(int status, void *fenced) +{ + *(int*)fenced = 1; +} + +OPAL_DECLSPEC void opal_common_ucx_mca_pmix_fence(ucp_worker_h worker) +{ + volatile int fenced = 0; + + opal_pmix.fence_nb(NULL, 0, opal_common_ucx_mca_fence_complete_cb, (void*)&fenced); + while (!fenced) { + ucp_worker_progress(worker); + } +} + diff --git a/opal/mca/common/ucx/common_ucx.h b/opal/mca/common/ucx/common_ucx.h index 7ca296304d..eae31172de 100644 --- a/opal/mca/common/ucx/common_ucx.h +++ b/opal/mca/common/ucx/common_ucx.h @@ -27,6 +27,7 @@ extern int opal_common_ucx_progress_iterations; OPAL_DECLSPEC void opal_common_ucx_mca_register(void); OPAL_DECLSPEC void opal_common_ucx_empty_complete_cb(void *request, ucs_status_t status); +OPAL_DECLSPEC void opal_common_ucx_mca_pmix_fence(ucp_worker_h worker); static inline ucs_status_t opal_common_ucx_wait_request(ucs_status_ptr_t request, ucp_worker_h worker) diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index c9068fafad..555d36eade 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -103,15 +103,11 @@ static void mca_spml_ucx_waitall(void **reqs, int *count_p) SPML_VERBOSE(10, "waiting for %d disconnect requests", *count_p); for (i = 0; i < *count_p; ++i) { - do { - opal_progress(); - status = ucp_request_test(reqs[i], NULL); - } while (status == UCS_INPROGRESS); + status = opal_common_ucx_wait_request(reqs[i], mca_spml_ucx.ucp_worker); if (status != UCS_OK) { SPML_ERROR("disconnect request failed: %s", ucs_status_string(status)); } - ucp_request_release(reqs[i]); reqs[i] = NULL; } @@ -175,8 +171,9 @@ int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs) mca_spml_ucx_waitall(dreqs, &num_reqs); free(dreqs); - opal_pmix.fence(NULL, 0); + opal_common_ucx_mca_pmix_fence(mca_spml_ucx.ucp_worker); free(mca_spml_ucx.ucp_peers); + mca_spml_ucx.ucp_peers = NULL; return OSHMEM_SUCCESS; } @@ -560,10 +557,20 @@ int mca_spml_ucx_get(void *src_addr, size_t size, void *dst_addr, int src) void *rva; ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; +#if HAVE_DECL_UCP_GET_NB + ucs_status_ptr_t request; +#endif ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva, &mca_spml_ucx); +#if HAVE_DECL_UCP_GET_NB + request = ucp_get_nb(mca_spml_ucx.ucp_peers[src].ucp_conn, dst_addr, size, + (uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); + /* TODO: replace wait_request by opal_common_ucx_wait_request_opal_status */ + status = opal_common_ucx_wait_request(request, mca_spml_ucx.ucp_worker); +#else status = ucp_get(mca_spml_ucx.ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey); +#endif return ucx_status_to_oshmem(status); } @@ -586,11 +593,20 @@ int mca_spml_ucx_put(void* dst_addr, size_t size, void* src_addr, int dst) void *rva; ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; +#if HAVE_DECL_UCP_PUT_NB + ucs_status_ptr_t request; +#endif ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva, &mca_spml_ucx); +#if HAVE_DECL_UCP_PUT_NB + request = ucp_put_nb(mca_spml_ucx.ucp_peers[dst].ucp_conn, src_addr, size, + (uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); + /* TODO: replace wait_request by opal_common_ucx_wait_request_opal_status */ + status = opal_common_ucx_wait_request(request, mca_spml_ucx.ucp_worker); +#else status = ucp_put(mca_spml_ucx.ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey); - +#endif return ucx_status_to_oshmem(status); }