diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c index 1b5ca4b18e..a455c674e5 100644 --- a/ompi/mca/pml/ucx/pml_ucx.c +++ b/ompi/mca/pml/ucx/pml_ucx.c @@ -482,7 +482,6 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src ucp_tag_recv_info_t info; ucs_status_t status; void *req; - int i; PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv"); @@ -493,16 +492,12 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src mca_pml_ucx_get_datatype(datatype), ucp_tag, ucp_tag_mask, req); - while (1) { - for (i = 0; i < opal_common_ucx.progress_iterations; i++) { - status = ucp_request_test(req, &info); - if (status != UCS_INPROGRESS) { - mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info); - return OMPI_SUCCESS; - } - ucp_worker_progress(ompi_pml_ucx.ucp_worker); + MCA_COMMON_UCX_PROGRESS_LOOP(ompi_pml_ucx.ucp_worker) { + status = ucp_request_test(req, &info); + if (status != UCS_INPROGRESS) { + mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info); + return OMPI_SUCCESS; } - opal_progress(); } } @@ -690,7 +685,6 @@ mca_pml_ucx_send_nb(ucp_ep_h ep, const void *buf, size_t count, req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype, mca_pml_ucx_get_datatype(datatype), tag, mode, cb); - if (OPAL_LIKELY(req == NULL)) { return OMPI_SUCCESS; } else if (!UCS_PTR_IS_ERR(req)) { @@ -754,6 +748,8 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm, int *matched, ompi_status_public_t* mpi_status) { + static int ucx_progress_cntr = 0; + ucp_tag_t ucp_tag, ucp_tag_mask; ucp_tag_recv_info_t info; ucp_tag_message_h ucp_msg; @@ -766,8 +762,10 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm, if (ucp_msg != NULL) { *matched = 1; mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info); - } else { - opal_progress(); + ucx_progress_cntr = 0; + } else { + (++ucx_progress_cntr % opal_common_ucx.progress_iterations) ? + (void)ucp_worker_progress(ompi_pml_ucx.ucp_worker) : opal_progress(); *matched = 0; } return OMPI_SUCCESS; @@ -779,29 +777,27 @@ int mca_pml_ucx_probe(int src, int tag, struct ompi_communicator_t* comm, ucp_tag_t ucp_tag, ucp_tag_mask; ucp_tag_recv_info_t info; ucp_tag_message_h ucp_msg; - int i; PML_UCX_TRACE_PROBE("probe", src, tag, comm); PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm); - while (1) { - for (i = 0; i < opal_common_ucx.progress_iterations; i++) { - ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask, - 0, &info); - if (ucp_msg != NULL) { - mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info); - return OMPI_SUCCESS; - } - ucp_worker_progress(ompi_pml_ucx.ucp_worker); + + MCA_COMMON_UCX_PROGRESS_LOOP(ompi_pml_ucx.ucp_worker) { + ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, + ucp_tag_mask, 0, &info); + if (ucp_msg != NULL) { + mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info); + return OMPI_SUCCESS; } - opal_progress(); } } int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm, - int *matched, struct ompi_message_t **message, - ompi_status_public_t* mpi_status) + int *matched, struct ompi_message_t **message, + ompi_status_public_t* mpi_status) { + static int ucx_progress_cntr = 0; + ucp_tag_t ucp_tag, ucp_tag_mask; ucp_tag_recv_info_t info; ucp_tag_message_h ucp_msg; @@ -816,8 +812,10 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm, PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg); *matched = 1; mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info); + ucx_progress_cntr = 0; } else { - opal_progress(); + (++ucx_progress_cntr % opal_common_ucx.progress_iterations) ? + (void)ucp_worker_progress(ompi_pml_ucx.ucp_worker) : opal_progress(); *matched = 0; } return OMPI_SUCCESS; @@ -834,7 +832,7 @@ int mca_pml_ucx_mprobe(int src, int tag, struct ompi_communicator_t* comm, PML_UCX_TRACE_PROBE("mprobe", src, tag, comm); PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm); - for (;;) { + MCA_COMMON_UCX_PROGRESS_LOOP(ompi_pml_ucx.ucp_worker) { ucp_msg = ucp_tag_probe_nb(ompi_pml_ucx.ucp_worker, ucp_tag, ucp_tag_mask, 1, &info); if (ucp_msg != NULL) { @@ -843,8 +841,6 @@ int mca_pml_ucx_mprobe(int src, int tag, struct ompi_communicator_t* comm, mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info); return OMPI_SUCCESS; } - - opal_progress(); } } diff --git a/opal/mca/common/ucx/common_ucx.h b/opal/mca/common/ucx/common_ucx.h index 24fd72b4df..98ef19a5fb 100644 --- a/opal/mca/common/ucx/common_ucx.h +++ b/opal/mca/common/ucx/common_ucx.h @@ -52,12 +52,17 @@ BEGIN_C_DECLS __VA_ARGS__); \ } +/* progress loop to allow call UCX/opal progress */ +/* used C99 for-statement variable initialization */ +#define MCA_COMMON_UCX_PROGRESS_LOOP(_worker) \ + for (int iter = 0;; (++iter % opal_common_ucx.progress_iterations) ? \ + (void)ucp_worker_progress(_worker) : opal_progress()) + #define MCA_COMMON_UCX_WAIT_LOOP(_request, _worker, _msg, _completed) \ - while (1) { \ + do { \ ucs_status_t status; \ - int i; \ /* call UCX progress */ \ - for (i = 0; i < opal_common_ucx.progress_iterations; i++) { \ + MCA_COMMON_UCX_PROGRESS_LOOP(_worker) { \ if (UCS_INPROGRESS != (status = opal_common_ucx_request_status(_request))) { \ _completed; \ if (OPAL_LIKELY(UCS_OK == status)) { \ @@ -70,12 +75,8 @@ BEGIN_C_DECLS return OPAL_ERROR; \ } \ } \ - ucp_worker_progress(_worker); \ } \ - /* call OPAL progress on every opal_common_ucx_progress_iterations */ \ - /* calls to UCX progress */ \ - opal_progress(); \ - } + } while (0) typedef struct opal_common_ucx_module { int output;