diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c index 1f37212c58..d92ad5a58c 100644 --- a/ompi/mca/pml/ucx/pml_ucx.c +++ b/ompi/mca/pml/ucx/pml_ucx.c @@ -16,6 +16,7 @@ #include "opal/runtime/opal.h" #include "opal/mca/pmix/pmix.h" +#include "ompi/attribute/attribute.h" #include "ompi/message/message.h" #include "ompi/mca/pml/base/pml_base_bsend.h" #include "opal/mca/common/ucx/common_ucx.h" @@ -190,9 +191,9 @@ int mca_pml_ucx_close(void) int mca_pml_ucx_init(void) { ucp_worker_params_t params; - ucs_status_t status; ucp_worker_attr_t attr; - int rc; + ucs_status_t status; + int i, rc; PML_UCX_VERBOSE(1, "mca_pml_ucx_init"); @@ -209,30 +210,34 @@ int mca_pml_ucx_init(void) &ompi_pml_ucx.ucp_worker); if (UCS_OK != status) { PML_UCX_ERROR("Failed to create UCP worker"); - return OMPI_ERROR; + rc = OMPI_ERROR; + goto err; } attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; status = ucp_worker_query(ompi_pml_ucx.ucp_worker, &attr); if (UCS_OK != status) { - ucp_worker_destroy(ompi_pml_ucx.ucp_worker); - ompi_pml_ucx.ucp_worker = NULL; PML_UCX_ERROR("Failed to query UCP worker thread level"); - return OMPI_ERROR; + rc = OMPI_ERROR; + goto err_destroy_worker; } - if (ompi_mpi_thread_multiple && attr.thread_mode != UCS_THREAD_MODE_MULTI) { + if (ompi_mpi_thread_multiple && (attr.thread_mode != UCS_THREAD_MODE_MULTI)) { /* UCX does not support multithreading, disqualify current PML for now */ /* TODO: we should let OMPI to fallback to THREAD_SINGLE mode */ - ucp_worker_destroy(ompi_pml_ucx.ucp_worker); - ompi_pml_ucx.ucp_worker = NULL; PML_UCX_ERROR("UCP worker does not support MPI_THREAD_MULTIPLE"); - return OMPI_ERROR; + rc = OMPI_ERR_NOT_SUPPORTED; + goto err_destroy_worker; } rc = mca_pml_ucx_send_worker_address(); if (rc < 0) { - return rc; + goto err_destroy_worker; + } + + ompi_pml_ucx.datatype_attr_keyval = MPI_KEYVAL_INVALID; + for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) { + ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID; } /* Initialize the free lists */ @@ -249,14 +254,33 @@ int mca_pml_ucx_init(void) (void *)ompi_pml_ucx.ucp_context, (void *)ompi_pml_ucx.ucp_worker); return OMPI_SUCCESS; + +err_destroy_worker: + ucp_worker_destroy(ompi_pml_ucx.ucp_worker); + ompi_pml_ucx.ucp_worker = NULL; +err: + return OMPI_ERROR; } int mca_pml_ucx_cleanup(void) { + int i; + PML_UCX_VERBOSE(1, "mca_pml_ucx_cleanup"); opal_progress_unregister(mca_pml_ucx_progress); + if (ompi_pml_ucx.datatype_attr_keyval != MPI_KEYVAL_INVALID) { + ompi_attr_free_keyval(TYPE_ATTR, &ompi_pml_ucx.datatype_attr_keyval, false); + } + + for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) { + if (ompi_pml_ucx.predefined_types[i] != PML_UCX_DATATYPE_INVALID) { + ucp_dt_destroy(ompi_pml_ucx.predefined_types[i]); + ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID; + } + } + ompi_pml_ucx.completed_send_req.req_state = OMPI_REQUEST_INVALID; OMPI_REQUEST_FINI(&ompi_pml_ucx.completed_send_req); OBJ_DESTRUCT(&ompi_pml_ucx.completed_send_req); @@ -398,6 +422,22 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs) int mca_pml_ucx_enable(bool enable) { + ompi_attribute_fn_ptr_union_t copy_fn; + ompi_attribute_fn_ptr_union_t del_fn; + int ret; + + /* Create a key for adding custom attributes to datatypes */ + copy_fn.attr_datatype_copy_fn = + (MPI_Type_internal_copy_attr_function*)MPI_TYPE_NULL_COPY_FN; + del_fn.attr_datatype_delete_fn = mca_pml_ucx_datatype_attr_del_fn; + ret = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn, + &ompi_pml_ucx.datatype_attr_keyval, NULL, 0, + NULL); + if (ret != OMPI_SUCCESS) { + PML_UCX_ERROR("Failed to create keyval for UCX datatypes: %d", ret); + return ret; + } + PML_UCX_FREELIST_INIT(&ompi_pml_ucx.persistent_reqs, mca_pml_ucx_persistent_request_t, 128, -1, 128); diff --git a/ompi/mca/pml/ucx/pml_ucx.h b/ompi/mca/pml/ucx/pml_ucx.h index da1b3ef0c5..484ad5ebe1 100644 --- a/ompi/mca/pml/ucx/pml_ucx.h +++ b/ompi/mca/pml/ucx/pml_ucx.h @@ -15,6 +15,7 @@ #include "ompi/mca/pml/pml.h" #include "ompi/mca/pml/base/base.h" #include "ompi/datatype/ompi_datatype.h" +#include "ompi/datatype/ompi_datatype_internal.h" #include "ompi/communicator/communicator.h" #include "ompi/request/request.h" #include "opal/mca/common/ucx/common_ucx.h" @@ -42,6 +43,10 @@ struct mca_pml_ucx_module { ucp_context_h ucp_context; ucp_worker_h ucp_worker; + /* Datatypes */ + int datatype_attr_keyval; + ucp_datatype_t predefined_types[OMPI_DATATYPE_MPI_MAX_PREDEFINED]; + /* Requests */ mca_pml_ucx_freelist_t persistent_reqs; ompi_request_t completed_send_req; diff --git a/ompi/mca/pml/ucx/pml_ucx_datatype.c b/ompi/mca/pml/ucx/pml_ucx_datatype.c index 98b7b190df..74b5fbe19c 100644 --- a/ompi/mca/pml/ucx/pml_ucx_datatype.c +++ b/ompi/mca/pml/ucx/pml_ucx_datatype.c @@ -10,6 +10,7 @@ #include "pml_ucx_datatype.h" #include "ompi/runtime/mpiruntime.h" +#include "ompi/attribute/attribute.h" #include @@ -127,12 +128,25 @@ static ucp_generic_dt_ops_t pml_ucx_generic_datatype_ops = { .finish = pml_ucx_generic_datatype_finish }; +int mca_pml_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval, + void *attr_val, void *extra) +{ + ucp_datatype_t ucp_datatype = (ucp_datatype_t)attr_val; + + PML_UCX_ASSERT((void*)ucp_datatype == datatype->pml_data); + + ucp_dt_destroy(ucp_datatype); + datatype->pml_data = PML_UCX_DATATYPE_INVALID; + return OMPI_SUCCESS; +} + ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype) { ucp_datatype_t ucp_datatype; ucs_status_t status; ptrdiff_t lb; size_t size; + int ret; ompi_datatype_type_lb(datatype, &lb); @@ -147,16 +161,33 @@ ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype) } status = ucp_dt_create_generic(&pml_ucx_generic_datatype_ops, - datatype, &ucp_datatype); + datatype, &ucp_datatype); if (status != UCS_OK) { PML_UCX_ERROR("Failed to create UCX datatype for %s", datatype->name); ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1); } - PML_UCX_VERBOSE(7, "created generic UCX datatype 0x%"PRIx64, ucp_datatype) - // TODO put this on a list to be destroyed later - datatype->pml_data = ucp_datatype; + + /* Add custom attribute, to clean up UCX resources when OMPI datatype is + * released. + */ + if (ompi_datatype_is_predefined(datatype)) { + PML_UCX_ASSERT(datatype->id < OMPI_DATATYPE_MAX_PREDEFINED); + ompi_pml_ucx.predefined_types[datatype->id] = ucp_datatype; + } else { + ret = ompi_attr_set_c(TYPE_ATTR, datatype, &datatype->d_keyhash, + ompi_pml_ucx.datatype_attr_keyval, + (void*)ucp_datatype, false); + if (ret != OMPI_SUCCESS) { + PML_UCX_ERROR("Failed to add UCX datatype attribute for %s: %d", + datatype->name, ret); + ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1); + } + } + + PML_UCX_VERBOSE(7, "created generic UCX datatype 0x%"PRIx64, ucp_datatype) + return ucp_datatype; } diff --git a/ompi/mca/pml/ucx/pml_ucx_datatype.h b/ompi/mca/pml/ucx/pml_ucx_datatype.h index 26b1835a15..f5207cecc7 100644 --- a/ompi/mca/pml/ucx/pml_ucx_datatype.h +++ b/ompi/mca/pml/ucx/pml_ucx_datatype.h @@ -13,6 +13,8 @@ #include "pml_ucx.h" +#define PML_UCX_DATATYPE_INVALID 0 + struct pml_ucx_convertor { opal_free_list_item_t super; ompi_datatype_t *datatype; @@ -23,6 +25,9 @@ struct pml_ucx_convertor { ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype); +int mca_pml_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval, + void *attr_val, void *extra); + OBJ_CLASS_DECLARATION(mca_pml_ucx_convertor_t); @@ -30,7 +35,7 @@ static inline ucp_datatype_t mca_pml_ucx_get_datatype(ompi_datatype_t *datatype) { ucp_datatype_t ucp_type = datatype->pml_data; - if (OPAL_LIKELY(ucp_type != 0)) { + if (OPAL_LIKELY(ucp_type != PML_UCX_DATATYPE_INVALID)) { return ucp_type; }