diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c
index b5d79a8ec1..f39b68474e 100644
--- a/ompi/mca/pml/ucx/pml_ucx.c
+++ b/ompi/mca/pml/ucx/pml_ucx.c
@@ -15,6 +15,7 @@
 #include "opal/runtime/opal.h"
 #include "opal/mca/pmix/pmix.h"
 #include "ompi/message/message.h"
+#include "ompi/mca/pml/base/pml_base_bsend.h"
 #include "pml_ucx_request.h"
 
 #include <inttypes.h>
@@ -333,7 +334,7 @@ static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
     ucs_status_t status;
     size_t i;
 
-    PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p);
+    PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", (int)*count_p);
     for (i = 0; i < *count_p; ++i) {
         do {
             opal_progress();
@@ -343,7 +344,7 @@ static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
             PML_UCX_ERROR("disconnect request failed: %s",
                           ucs_status_string(status));
         }
-        ucp_request_release(reqs[i]);
+        ucp_request_free(reqs[i]);
         reqs[i] = NULL;
     }
 
@@ -391,7 +392,7 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
 
         proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
 
-        if (num_reqs >= ompi_pml_ucx.num_disconnect) {
+        if ((int)num_reqs >= ompi_pml_ucx.num_disconnect) {
             mca_pml_ucx_waitall(dreqs, &num_reqs);
         }
     }
@@ -494,7 +495,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
     PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
 
     PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
-    req    = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
+    req    = (char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
     status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
                               mca_pml_ucx_get_datatype(datatype),
                               ucp_tag, ucp_tag_mask, req);
@@ -556,15 +557,80 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
     req->flags             = MCA_PML_UCX_REQUEST_FLAG_SEND;
     req->buffer            = (void *)buf;
     req->count             = count;
-    req->datatype          = mca_pml_ucx_get_datatype(datatype);
     req->tag               = PML_UCX_MAKE_SEND_TAG(tag, comm);
     req->send.mode         = mode;
     req->send.ep           = ep;
+    if (MCA_PML_BASE_SEND_BUFFERED == mode) {
+        req->ompi_datatype = datatype;
+        OBJ_RETAIN(datatype);
+    } else {
+        req->datatype      = mca_pml_ucx_get_datatype(datatype);
+    }
 
     *request = &req->ompi;
     return OMPI_SUCCESS;
 }
 
+static int
+mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
+                  ompi_datatype_t *datatype, uint64_t pml_tag)
+{
+    ompi_request_t *req;
+    void *packed_data;
+    size_t packed_length;
+    size_t offset;
+    uint32_t iov_count;
+    struct iovec iov;
+    opal_convertor_t opal_conv;
+
+    OBJ_CONSTRUCT(&opal_conv, opal_convertor_t);
+    opal_convertor_copy_and_prepare_for_recv(ompi_proc_local_proc->super.proc_convertor,
+                                             &datatype->super, count, buf, 0,
+                                             &opal_conv);
+    opal_convertor_get_packed_size(&opal_conv, &packed_length);
+
+    packed_data = mca_pml_base_bsend_request_alloc_buf(packed_length);
+    if (OPAL_UNLIKELY(NULL == packed_data)) {
+        OBJ_DESTRUCT(&opal_conv);
+        PML_UCX_ERROR("bsend: failed to allocate buffer");
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    iov_count    = 1;
+    iov.iov_base = packed_data;
+    iov.iov_len  = packed_length;
+
+    PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %zu", packed_data, packed_length);
+    offset = 0;
+    opal_convertor_set_position(&opal_conv, &offset);
+    if (0 > opal_convertor_pack(&opal_conv, &iov, &iov_count, &packed_length)) {
+        mca_pml_base_bsend_request_free(packed_data);
+        OBJ_DESTRUCT(&opal_conv);
+        PML_UCX_ERROR("bsend: failed to pack user datatype");
+        return OMPI_ERROR;
+    }
+
+    OBJ_DESTRUCT(&opal_conv);
+
+    req = (ompi_request_t*)ucp_tag_send_nb(ep, packed_data, packed_length,
+                                           ucp_dt_make_contig(1), pml_tag,
+                                           mca_pml_ucx_bsend_completion);
+    if (NULL == req) {
+        /* request was completed in place */
+        mca_pml_base_bsend_request_free(packed_data);
+        return OMPI_SUCCESS;
+    }
+
+    if (OPAL_UNLIKELY(UCS_PTR_IS_ERR(req))) {
+        mca_pml_base_bsend_request_free(packed_data);
+        PML_UCX_ERROR("ucx bsend failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
+        return OMPI_ERROR;
+    }
+
+    req->req_complete_cb_data = packed_data;
+    return OMPI_SUCCESS;
+}
+
 int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
                       int dst, int tag, mca_pml_base_send_mode_t mode,
                       struct ompi_communicator_t* comm,
@@ -573,8 +639,10 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
     ompi_request_t *req;
     ucp_ep_h ep;
 
-    PML_UCX_TRACE_SEND("isend request *%p", buf, count, datatype, dst, tag, mode,
-                       comm, (void*)request)
+    PML_UCX_TRACE_SEND("i%ssend request *%p",
+                       buf, count, datatype, dst, tag, mode, comm,
+                       mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "",
+                       (void*)request)
 
     /* TODO special care to sync/buffered send */
 
@@ -584,6 +652,13 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
         return OMPI_ERROR;
     }
 
+    /* Special care to sync/buffered send */
+    if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
+        *request = &ompi_pml_ucx.completed_send_req;
+        return mca_pml_ucx_bsend(ep, buf, count, datatype,
+                                 PML_UCX_MAKE_SEND_TAG(tag, comm));
+    }
+
     req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
                                            mca_pml_ucx_get_datatype(datatype),
                                            PML_UCX_MAKE_SEND_TAG(tag, comm),
@@ -609,9 +684,8 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
     ompi_request_t *req;
     ucp_ep_h ep;
 
-    PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm, "send");
-
-    /* TODO special care to sync/buffered send */
+    PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm,
+                       mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send");
 
     ep = mca_pml_ucx_get_ep(comm, dst);
     if (OPAL_UNLIKELY(NULL == ep)) {
@@ -619,6 +693,12 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
         return OMPI_ERROR;
     }
 
+    /* Special care to sync/buffered send */
+    if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
+        return mca_pml_ucx_bsend(ep, buf, count, datatype,
+                                 PML_UCX_MAKE_SEND_TAG(tag, comm));
+    }
+
     req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
                                            mca_pml_ucx_get_datatype(datatype),
                                            PML_UCX_MAKE_SEND_TAG(tag, comm),
@@ -781,6 +861,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
     mca_pml_ucx_persistent_request_t *preq;
     ompi_request_t *tmp_req;
     size_t i;
+    int rc;
 
     for (i = 0; i < count; ++i) {
         preq = (mca_pml_ucx_persistent_request_t *)requests[i];
@@ -795,12 +876,22 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
         mca_pml_ucx_request_reset(&preq->ompi);
 
         if (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) {
-            /* TODO special care to sync/buffered send */
-            PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
-            tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
-                                                       preq->count, preq->datatype,
-                                                       preq->tag,
-                                                       mca_pml_ucx_psend_completion);
+            if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == preq->send.mode)) {
+                PML_UCX_VERBOSE(8, "start bsend request %p", (void*)preq);
+                rc = mca_pml_ucx_bsend(preq->send.ep, preq->buffer, preq->count,
+                                       preq->ompi_datatype, preq->tag);
+                if (OMPI_SUCCESS != rc) {
+                    return rc;
+                }
+                /* pretend that we got immediate completion */
+                tmp_req = NULL;
+            } else {
+                PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
+                tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
+                                                           preq->count, preq->datatype,
+                                                           preq->tag,
+                                                           mca_pml_ucx_psend_completion);
+            }
         } else {
             PML_UCX_VERBOSE(8, "start recv request %p", (void*)preq);
             tmp_req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker,
diff --git a/ompi/mca/pml/ucx/pml_ucx_request.c b/ompi/mca/pml/ucx/pml_ucx_request.c
index 49ec04b2fb..3b86c7ab70 100644
--- a/ompi/mca/pml/ucx/pml_ucx_request.c
+++ b/ompi/mca/pml/ucx/pml_ucx_request.c
@@ -24,7 +24,7 @@ static int mca_pml_ucx_request_free(ompi_request_t **rptr)
 
     *rptr = MPI_REQUEST_NULL;
     mca_pml_ucx_request_reset(req);
-    ucp_request_release(req);
+    ucp_request_free(req);
     return OMPI_SUCCESS;
 }
 
@@ -46,6 +46,18 @@ void mca_pml_ucx_send_completion(void *request, ucs_status_t status)
     ompi_request_complete(req, true);
 }
 
+void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status)
+{
+    ompi_request_t *req = request;
+
+    PML_UCX_VERBOSE(8, "bsend request %p buffer %p completed with status %s", (void*)req,
+                    req->req_complete_cb_data, ucs_status_string(status));
+    mca_pml_base_bsend_request_free(req->req_complete_cb_data);
+    mca_pml_ucx_set_send_status(&req->req_status, status);
+    PML_UCX_ASSERT(!(REQUEST_COMPLETE(req)));
+    mca_pml_ucx_request_free(&req);
+}
+
 void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
                                  ucp_tag_recv_info_t *info)
 {
@@ -75,7 +87,7 @@ mca_pml_ucx_persistent_request_complete(mca_pml_ucx_persistent_request_t *preq,
     ompi_request_complete(&preq->ompi, true);
     mca_pml_ucx_persistent_request_detach(preq, tmp_req);
     mca_pml_ucx_request_reset(tmp_req);
-    ucp_request_release(tmp_req);
+    ucp_request_free(tmp_req);
 }
 
 static inline void mca_pml_ucx_preq_completion(ompi_request_t *tmp_req)
@@ -152,7 +164,10 @@ static int mca_pml_ucx_persistent_request_free(ompi_request_t **rptr)
     preq->ompi.req_state = OMPI_REQUEST_INVALID;
     if (tmp_req != NULL) {
         mca_pml_ucx_persistent_request_detach(preq, tmp_req);
-        ucp_request_release(tmp_req);
+        ucp_request_free(tmp_req);
+    }
+    if (MCA_PML_BASE_SEND_BUFFERED == preq->send.mode) {
+        OBJ_RELEASE(preq->ompi_datatype);
     }
     PML_UCX_FREELIST_RETURN(&ompi_pml_ucx.persistent_reqs, &preq->ompi.super);
     *rptr = MPI_REQUEST_NULL;
diff --git a/ompi/mca/pml/ucx/pml_ucx_request.h b/ompi/mca/pml/ucx/pml_ucx_request.h
index 2aed32ac0e..5aa657eccb 100644
--- a/ompi/mca/pml/ucx/pml_ucx_request.h
+++ b/ompi/mca/pml/ucx/pml_ucx_request.h
@@ -99,7 +99,10 @@ struct pml_ucx_persistent_request {
     unsigned flags;
     void *buffer;
     size_t count;
-    ucp_datatype_t datatype;
+    union {
+        ucp_datatype_t datatype;
+        ompi_datatype_t *ompi_datatype;
+    };
     ucp_tag_t tag;
     struct {
         mca_pml_base_send_mode_t mode;
@@ -118,6 +121,8 @@ void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
 
 void mca_pml_ucx_psend_completion(void *request, ucs_status_t status);
 
+void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status);
+
 void mca_pml_ucx_precv_completion(void *request, ucs_status_t status,
                                   ucp_tag_recv_info_t *info);
 
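For context on what the patch changes: previously pml/ucx pushed buffered-mode sends down the standard send path (the removed TODO comments); with this patch, `mca_pml_ucx_bsend()` packs the payload into a buffer obtained from `mca_pml_base_bsend_request_alloc_buf()`, posts the send on the copy, and completes the user-visible request immediately (`completed_send_req`). A minimal MPI program that exercises the new path; the file name and buffer sizing below are mine, not part of the patch:

```c
/* bsend_demo.c - exercises the buffered-send path over pml/ucx.
 * Build and run (assumed): mpicc bsend_demo.c -o bsend_demo && mpirun -n 2 ./bsend_demo */
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char **argv)
{
    int rank, data = 42;
    int bufsize = MPI_BSEND_OVERHEAD + (int)sizeof(int);
    void *attached = malloc(bufsize);

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Buffer_attach(attached, bufsize);

    if (rank == 0) {
        /* Returns as soon as the message is copied out of `data`; with this
         * patch the underlying ucp_tag_send_nb() completes asynchronously. */
        MPI_Bsend(&data, 1, MPI_INT, 1, 0, MPI_COMM_WORLD);
    } else if (rank == 1) {
        MPI_Recv(&data, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        printf("received %d\n", data);
    }

    /* Blocks until all buffered sends complete, i.e. until
     * mca_pml_ucx_bsend_completion() has released the packed buffer. */
    MPI_Buffer_detach(&attached, &bufsize);
    free(attached);
    MPI_Finalize();
    return 0;
}
```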
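The error handling in `mca_pml_ucx_bsend()` follows the `ucp_tag_send_nb()` return convention: `NULL` means the send completed in place (the callback will never run, so the packed buffer is freed on the spot), an error pointer means synchronous failure, and anything else is an in-flight request whose completion callback fires later. A minimal sketch of that convention, assuming an already-connected `ucp_ep_h`; `demo_send` and `demo_completion` are hypothetical names, not part of the patch:

```c
#include <ucp/api/ucp.h>

static void demo_completion(void *request, ucs_status_t status)
{
    /* This is the point where a staging buffer may safely be released
     * (cf. mca_pml_ucx_bsend_completion in the patch). */
    ucp_request_free(request);
}

static ucs_status_t demo_send(ucp_ep_h ep, const void *buf, size_t len,
                              ucp_tag_t tag)
{
    void *req = ucp_tag_send_nb(ep, buf, len, ucp_dt_make_contig(1),
                                tag, demo_completion);
    if (NULL == req) {
        /* Completed immediately: no request object exists and the
         * callback will not be invoked, so clean up here. */
        return UCS_OK;
    }
    if (UCS_PTR_IS_ERR(req)) {
        /* Synchronous failure; no callback will be invoked either. */
        return UCS_PTR_STATUS(req);
    }
    /* In flight: demo_completion runs once the send finishes. */
    return UCS_INPROGRESS;
}
```

This is also why the patch swaps `ucp_request_release()` for `ucp_request_free()` throughout: the request handle returned here must be returned to UCX exactly once, after completion.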