diff --git a/ompi/mca/mtl/ofi/mtl_ofi.h b/ompi/mca/mtl/ofi/mtl_ofi.h index e62c5b8225..f47aa5fb86 100644 --- a/ompi/mca/mtl/ofi/mtl_ofi.h +++ b/ompi/mca/mtl/ofi/mtl_ofi.h @@ -238,34 +238,82 @@ ompi_mtl_ofi_isend_callback(struct fi_cq_tagged_entry *wc, } __opal_attribute_always_inline__ static inline int -ompi_mtl_ofi_send_start(struct mca_mtl_base_module_t *mtl, - struct ompi_communicator_t *comm, - int dest, - int tag, - struct opal_convertor_t *convertor, - mca_pml_base_send_mode_t mode, - ompi_mtl_ofi_request_t *ofi_req) +ompi_mtl_ofi_ssend_recv(ompi_mtl_ofi_request_t *ack_req, + struct ompi_communicator_t *comm, + fi_addr_t *src_addr, + ompi_mtl_ofi_request_t *ofi_req, + mca_mtl_ofi_endpoint_t *endpoint, + uint64_t *match_bits, + int tag) { + ssize_t ret = OMPI_SUCCESS; + ack_req = malloc(sizeof(ompi_mtl_ofi_request_t)); + + assert(ack_req); + + ack_req->parent = ofi_req; + ack_req->event_callback = ompi_mtl_ofi_send_ack_callback; + ack_req->error_callback = ompi_mtl_ofi_send_ack_error_callback; + + ofi_req->completion_count += 1; + + MTL_OFI_RETRY_UNTIL_DONE(fi_trecv(ompi_mtl_ofi.ep, + NULL, + 0, + NULL, + *src_addr, + *match_bits | ompi_mtl_ofi.sync_send_ack, + 0, /* Exact match, no ignore bits */ + (void *) &ack_req->ctx), ret); + if (OPAL_UNLIKELY(0 > ret)) { + opal_output_verbose(1, ompi_mtl_base_framework.framework_output, + "%s:%d: fi_trecv failed: %s(%zd)", + __FILE__, __LINE__, fi_strerror(-ret), ret); + free(ack_req); + return ompi_mtl_ofi_get_error(ret); + } + + /* The SYNC_SEND tag bit is set for the send operation only.*/ + MTL_OFI_SET_SYNC_SEND(*match_bits); + return OMPI_SUCCESS; +} + +__opal_attribute_always_inline__ static inline int +ompi_mtl_ofi_send(struct mca_mtl_base_module_t *mtl, + struct ompi_communicator_t *comm, + int dest, + int tag, + struct opal_convertor_t *convertor, + mca_pml_base_send_mode_t mode) +{ + ssize_t ret = OMPI_SUCCESS; + ompi_mtl_ofi_request_t ofi_req; int ompi_ret; void *start; - size_t length; - ssize_t ret; bool free_after; + size_t length; uint64_t match_bits; ompi_proc_t *ompi_proc = NULL; mca_mtl_ofi_endpoint_t *endpoint = NULL; ompi_mtl_ofi_request_t *ack_req = NULL; /* For synchronous send */ fi_addr_t src_addr = 0; + /** + * Create a send request, start it and wait until it completes. + */ + ofi_req.event_callback = ompi_mtl_ofi_send_callback; + ofi_req.error_callback = ompi_mtl_ofi_send_error_callback; + ompi_proc = ompi_comm_peer_lookup(comm, dest); endpoint = ompi_mtl_ofi_get_endpoint(mtl, ompi_proc); ompi_ret = ompi_mtl_datatype_pack(convertor, &start, &length, &free_after); if (OMPI_SUCCESS != ompi_ret) return ompi_ret; - ofi_req->buffer = (free_after) ? start : NULL; - ofi_req->length = length; - ofi_req->status.MPI_ERROR = OMPI_SUCCESS; + ofi_req.buffer = (free_after) ? start : NULL; + ofi_req.length = length; + ofi_req.status.MPI_ERROR = OMPI_SUCCESS; + ofi_req.completion_count = 0; if (ompi_mtl_ofi.fi_cq_data) { match_bits = mtl_ofi_create_send_tag_CQD(comm->c_contextid, tag); @@ -277,33 +325,11 @@ ompi_mtl_ofi_send_start(struct mca_mtl_base_module_t *mtl, } if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_SYNCHRONOUS == mode)) { - ack_req = malloc(sizeof(ompi_mtl_ofi_request_t)); - assert(ack_req); - ack_req->parent = ofi_req; - ack_req->event_callback = ompi_mtl_ofi_send_ack_callback; - ack_req->error_callback = ompi_mtl_ofi_send_ack_error_callback; - - ofi_req->completion_count = 2; - - MTL_OFI_RETRY_UNTIL_DONE(fi_trecv(ompi_mtl_ofi.ep, - NULL, - 0, - NULL, - src_addr, - match_bits | ompi_mtl_ofi.sync_send_ack, - 0, /* Exact match, no ignore bits */ - (void *) &ack_req->ctx), ret); - if (OPAL_UNLIKELY(0 > ret)) { - opal_output_verbose(1, ompi_mtl_base_framework.framework_output, - "%s:%d: fi_trecv failed: %s(%zd)", - __FILE__, __LINE__, fi_strerror(-ret), ret); - free(ack_req); - return ompi_mtl_ofi_get_error(ret); - } - /* The SYNC_SEND tag bit is set for the send operation only.*/ - MTL_OFI_SET_SYNC_SEND(match_bits); - } else { - ofi_req->completion_count = 1; + ofi_req.status.MPI_ERROR = ompi_mtl_ofi_ssend_recv(ack_req, comm, &src_addr, + &ofi_req, endpoint, + &match_bits, tag); + if (OPAL_UNLIKELY(ofi_req.status.MPI_ERROR != OMPI_SUCCESS)) + goto free_request_buffer; } if (ompi_mtl_ofi.max_inject_size >= length) { @@ -331,11 +357,12 @@ ompi_mtl_ofi_send_start(struct mca_mtl_base_module_t *mtl, fi_cancel((fid_t)ompi_mtl_ofi.ep, &ack_req->ctx); free(ack_req); } - return ompi_mtl_ofi_get_error(ret); - } - ofi_req->event_callback(NULL,ofi_req); + ofi_req.status.MPI_ERROR = ompi_mtl_ofi_get_error(ret); + goto free_request_buffer; + } } else { + ofi_req.completion_count += 1; if (ompi_mtl_ofi.fi_cq_data) { MTL_OFI_RETRY_UNTIL_DONE(fi_tsenddata(ompi_mtl_ofi.ep, start, @@ -344,7 +371,7 @@ ompi_mtl_ofi_send_start(struct mca_mtl_base_module_t *mtl, comm->c_my_rank, endpoint->peer_fiaddr, match_bits, - (void *) &ofi_req->ctx), ret); + (void *) &ofi_req.ctx), ret); } else { MTL_OFI_RETRY_UNTIL_DONE(fi_tsend(ompi_mtl_ofi.ep, start, @@ -352,46 +379,20 @@ ompi_mtl_ofi_send_start(struct mca_mtl_base_module_t *mtl, NULL, endpoint->peer_fiaddr, match_bits, - (void *) &ofi_req->ctx), ret); + (void *) &ofi_req.ctx), ret); } if (OPAL_UNLIKELY(0 > ret)) { char *fi_api = ompi_mtl_ofi.fi_cq_data ? "fi_tsendddata" : "fi_send"; opal_output_verbose(1, ompi_mtl_base_framework.framework_output, "%s:%d: %s failed: %s(%zd)", __FILE__, __LINE__,fi_api, fi_strerror(-ret), ret); - return ompi_mtl_ofi_get_error(ret); + free(fi_api); + + ofi_req.status.MPI_ERROR = ompi_mtl_ofi_get_error(ret); + goto free_request_buffer; } } - return OMPI_SUCCESS; -} - -__opal_attribute_always_inline__ static inline int -ompi_mtl_ofi_send(struct mca_mtl_base_module_t *mtl, - struct ompi_communicator_t *comm, - int dest, - int tag, - struct opal_convertor_t *convertor, - mca_pml_base_send_mode_t mode) -{ - int ret = OMPI_SUCCESS; - ompi_mtl_ofi_request_t ofi_req; - - /** - * Create a send request, start it and wait until it completes. - */ - ofi_req.event_callback = ompi_mtl_ofi_send_callback; - ofi_req.error_callback = ompi_mtl_ofi_send_error_callback; - - ret = ompi_mtl_ofi_send_start(mtl, comm, dest, tag, - convertor, mode, &ofi_req); - if (OPAL_UNLIKELY(OMPI_SUCCESS != ret)) { - if (NULL != ofi_req.buffer) { - free(ofi_req.buffer); - } - return ret; - } - /** * Wait until the request is completed. * ompi_mtl_ofi_send_callback() updates this variable. @@ -400,6 +401,7 @@ ompi_mtl_ofi_send(struct mca_mtl_base_module_t *mtl, ompi_mtl_ofi_progress(); } +free_request_buffer: if (OPAL_UNLIKELY(NULL != ofi_req.buffer)) { free(ofi_req.buffer); } @@ -417,20 +419,89 @@ ompi_mtl_ofi_isend(struct mca_mtl_base_module_t *mtl, bool blocking, mca_mtl_request_t *mtl_request) { - int ret = OMPI_SUCCESS; - ompi_mtl_ofi_request_t *ofi_req = (ompi_mtl_ofi_request_t*) mtl_request; + ssize_t ret = OMPI_SUCCESS; + ompi_mtl_ofi_request_t *ofi_req = (ompi_mtl_ofi_request_t *) mtl_request; + int ompi_ret; + void *start; + size_t length; + bool free_after; + uint64_t match_bits; + ompi_proc_t *ompi_proc = NULL; + mca_mtl_ofi_endpoint_t *endpoint = NULL; + ompi_mtl_ofi_request_t *ack_req = NULL; /* For synchronous send */ + fi_addr_t src_addr = 0; ofi_req->event_callback = ompi_mtl_ofi_isend_callback; ofi_req->error_callback = ompi_mtl_ofi_send_error_callback; - ret = ompi_mtl_ofi_send_start(mtl, comm, dest, tag, - convertor, mode, ofi_req); + ompi_proc = ompi_comm_peer_lookup(comm, dest); + endpoint = ompi_mtl_ofi_get_endpoint(mtl, ompi_proc); - if (OPAL_UNLIKELY(OMPI_SUCCESS != ret && NULL != ofi_req->buffer)) { + ompi_ret = ompi_mtl_datatype_pack(convertor, &start, &length, &free_after); + if (OMPI_SUCCESS != ompi_ret) return ompi_ret; + + ofi_req->buffer = (free_after) ? start : NULL; + ofi_req->length = length; + ofi_req->status.MPI_ERROR = OMPI_SUCCESS; + ofi_req->completion_count = 1; + + if (ompi_mtl_ofi.fi_cq_data) { + match_bits = mtl_ofi_create_send_tag_CQD(comm->c_contextid, tag); + src_addr = endpoint->peer_fiaddr; + } else { + match_bits = mtl_ofi_create_send_tag(comm->c_contextid, + comm->c_my_rank, tag); + /* src_addr is ignored when FI_DIRECTED_RECV is not supported */ + } + + if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_SYNCHRONOUS == mode)) { + ofi_req->status.MPI_ERROR = ompi_mtl_ofi_ssend_recv(ack_req, comm, &src_addr, + ofi_req, endpoint, + &match_bits, tag); + if (OPAL_UNLIKELY(ofi_req->status.MPI_ERROR != OMPI_SUCCESS)) + goto free_request_buffer; + } + + if (ompi_mtl_ofi.fi_cq_data) { + MTL_OFI_RETRY_UNTIL_DONE(fi_tsenddata(ompi_mtl_ofi.ep, + start, + length, + NULL, + comm->c_my_rank, + endpoint->peer_fiaddr, + match_bits, + (void *) &ofi_req->ctx), ret); + } else { + MTL_OFI_RETRY_UNTIL_DONE(fi_tsend(ompi_mtl_ofi.ep, + start, + length, + NULL, + endpoint->peer_fiaddr, + match_bits, + (void *) &ofi_req->ctx), ret); + } + if (OPAL_UNLIKELY(0 > ret)) { + char *fi_api; + if (ompi_mtl_ofi.fi_cq_data) { + asprintf( &fi_api, "fi_tsendddata") ; + } + else { + asprintf( &fi_api, "fi_send") ; + } + opal_output_verbose(1, ompi_mtl_base_framework.framework_output, + "%s:%d: %s failed: %s(%zd)", + __FILE__, __LINE__,fi_api, fi_strerror(-ret), ret); + free(fi_api); + ofi_req->status.MPI_ERROR = ompi_mtl_ofi_get_error(ret); + } + +free_request_buffer: + if (OPAL_UNLIKELY(OMPI_SUCCESS != ofi_req->status.MPI_ERROR + && NULL != ofi_req->buffer)) { free(ofi_req->buffer); } - return ret; + return ofi_req->status.MPI_ERROR; } /**