diff --git a/config/ompi_check_ucx.m4 b/config/ompi_check_ucx.m4
index 6994af8e51..62cb693bb9 100644
--- a/config/ompi_check_ucx.m4
+++ b/config/ompi_check_ucx.m4
@@ -100,6 +100,17 @@ AC_DEFUN([OMPI_CHECK_UCX],[
                             AC_MSG_RESULT([$ompi_check_ucx_happy])
                             CPPFLAGS=$old_CPPFLAGS])])
 
+              # Probe the UCX headers for ucp_tag_send_nbr. The declaration check
+              # also sets HAVE_DECL_UCP_TAG_SEND_NBR which pml_ucx.c tests.
+              old_CPPFLAGS="$CPPFLAGS"
+              AS_IF([test -n "$ompi_check_ucx_dir"],
+                    [CPPFLAGS="$CPPFLAGS -I$ompi_check_ucx_dir/include"])
+              AC_CHECK_DECLS([ucp_tag_send_nbr],
+                             [AC_DEFINE([HAVE_UCP_TAG_SEND_NBR],[1],
+                                        [have ucp_tag_send_nbr()])], [],
+                             [#include <ucp/api/ucp.h>])
+              CPPFLAGS=$old_CPPFLAGS
+
               OPAL_SUMMARY_ADD([[Transports]],[[Open UCX]],[$1],[$ompi_check_ucx_happy])])])
 
     AS_IF([test "$ompi_check_ucx_happy" = "yes"],
diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c
index d7cac1b382..cf453dd1c5 100644
--- a/ompi/mca/pml/ucx/pml_ucx.c
+++ b/ompi/mca/pml/ucx/pml_ucx.c
@@ -75,6 +75,13 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
     NULL /* ucp_worker */
 };
 
+/*
+ * Reserve request_size bytes on the caller's stack and yield the address just
+ * past them. Expands to an expression; the caller adds the semicolon.
+ */
+#define PML_UCX_REQ_ALLOCA() \
+    ((char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size)
+
 static int mca_pml_ucx_send_worker_address(void)
 {
     ucp_address_t *address;
@@ -525,7 +532,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
     PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
 
     PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
-    req = (char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
+    req = PML_UCX_REQ_ALLOCA();
     status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
                               mca_pml_ucx_get_datatype(datatype),
                               ucp_tag, ucp_tag_mask, req);
@@ -715,26 +722,21 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
     }
 }
 
-int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, int dst,
-                     int tag, mca_pml_base_send_mode_t mode,
-                     struct ompi_communicator_t* comm)
+/*
+ * Post a send through mca_pml_ucx_common_send() using the UCX datatype, tag
+ * and completion callback supplied by the caller. A NULL request is reported
+ * as OMPI_SUCCESS.
+ */
+static inline __opal_attribute_always_inline__ int
+mca_pml_ucx_send_nb(ucp_ep_h ep, const void *buf, size_t count,
+                    ompi_datatype_t *datatype, ucp_datatype_t ucx_datatype,
+                    ucp_tag_t tag, mca_pml_base_send_mode_t mode,
+                    ucp_send_callback_t cb)
 {
     ompi_request_t *req;
-    ucp_ep_h ep;
-
-    PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm,
-                       mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send");
-
-    ep = mca_pml_ucx_get_ep(comm, dst);
-    if (OPAL_UNLIKELY(NULL == ep)) {
-        PML_UCX_ERROR("Failed to get ep for rank %d", dst);
-        return OMPI_ERROR;
-    }
 
     req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
-                                                   mca_pml_ucx_get_datatype(datatype),
-                                                   PML_UCX_MAKE_SEND_TAG(tag, comm),
-                                                   mode, mca_pml_ucx_send_completion);
+                                                   ucx_datatype, tag, mode, cb);
 
     if (OPAL_LIKELY(req == NULL)) {
         return OMPI_SUCCESS;
@@ -749,6 +751,74 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
     }
 }
 
+#if HAVE_DECL_UCP_TAG_SEND_NBR
+/*
+ * Send with ucp_tag_send_nbr() using a request carved from the stack. Returns
+ * OMPI_SUCCESS when the send finishes with UCS_OK, OMPI_ERROR otherwise.
+ */
+static inline __opal_attribute_always_inline__ int
+mca_pml_ucx_send_nbr(ucp_ep_h ep, const void *buf, size_t count,
+                     ucp_datatype_t ucx_datatype, ucp_tag_t tag)
+{
+    void *req;
+    ucs_status_t status;
+
+    req = PML_UCX_REQ_ALLOCA();
+    status = ucp_tag_send_nbr(ep, buf, count, ucx_datatype, tag, req);
+    if (OPAL_LIKELY(status == UCS_OK)) {
+        return OMPI_SUCCESS;
+    }
+
+    /* Only UCS_INPROGRESS leaves a request to poll; anything else failed. */
+    if (OPAL_UNLIKELY(status != UCS_INPROGRESS)) {
+        PML_UCX_ERROR("ucx send nbr failed with status %d", (int)status);
+        return OMPI_ERROR;
+    }
+
+    ucp_worker_progress(ompi_pml_ucx.ucp_worker);
+    while ((status = ucp_request_check_status(req)) == UCS_INPROGRESS) {
+        opal_progress();
+    }
+
+    return OPAL_LIKELY(UCS_OK == status) ? OMPI_SUCCESS : OMPI_ERROR;
+}
+#endif
+
+/*
+ * Send entry point: look up the endpoint for dst, then take the
+ * ucp_tag_send_nbr() path for modes other than buffered and synchronous when
+ * that function is declared, and mca_pml_ucx_send_nb() otherwise.
+ */
+int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, int dst,
+                     int tag, mca_pml_base_send_mode_t mode,
+                     struct ompi_communicator_t* comm)
+{
+    ucp_ep_h ep;
+
+    PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm,
+                       mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send");
+
+    ep = mca_pml_ucx_get_ep(comm, dst);
+    if (OPAL_UNLIKELY(NULL == ep)) {
+        PML_UCX_ERROR("Failed to get ep for rank %d", dst);
+        return OMPI_ERROR;
+    }
+
+#if HAVE_DECL_UCP_TAG_SEND_NBR
+    if (OPAL_LIKELY((MCA_PML_BASE_SEND_BUFFERED != mode) &&
+                    (MCA_PML_BASE_SEND_SYNCHRONOUS != mode))) {
+        return mca_pml_ucx_send_nbr(ep, buf, count,
+                                    mca_pml_ucx_get_datatype(datatype),
+                                    PML_UCX_MAKE_SEND_TAG(tag, comm));
+    }
+#endif
+
+    return mca_pml_ucx_send_nb(ep, buf, count, datatype,
+                               mca_pml_ucx_get_datatype(datatype),
+                               PML_UCX_MAKE_SEND_TAG(tag, comm), mode,
+                               mca_pml_ucx_send_completion);
+}
+
 int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
                        int *matched, ompi_status_public_t* mpi_status)
 {