From 2547e24c55f5082bf124071181b434a0fb6c2be5 Mon Sep 17 00:00:00 2001 From: Devendar Bureddy Date: Tue, 30 Jun 2020 00:26:35 +0300 Subject: [PATCH] UCX: initialize cuda from ucx pml component Signed-off-by: Devendar Bureddy --- config/ompi_check_ucx.m4 | 4 ++++ ompi/mca/pml/ucx/pml_ucx.c | 20 +++++++++++++++++++- ompi/mca/pml/ucx/pml_ucx.h | 1 + 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/config/ompi_check_ucx.m4 b/config/ompi_check_ucx.m4 index db07020fdc..07bfc7ad48 100644 --- a/config/ompi_check_ucx.m4 +++ b/config/ompi_check_ucx.m4 @@ -129,6 +129,10 @@ AC_DEFUN([OMPI_CHECK_UCX],[ [AC_DEFINE([HAVE_UCP_WORKER_ADDRESS_FLAGS], [1], [have worker address attribute])], [], [#include ]) + AC_CHECK_DECLS([UCP_ATTR_FIELD_MEMORY_TYPES], + [AC_DEFINE([HAVE_UCP_ATTR_MEMORY_TYPES], [1], + [have memory types attribute])], [], + [#include ]) AC_CHECK_DECLS([ucp_tag_send_nbx, ucp_tag_send_sync_nbx, ucp_tag_recv_nbx], diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c index 3db38f694c..7e5b566ff9 100644 --- a/ompi/mca/pml/ucx/pml_ucx.c +++ b/ompi/mca/pml/ucx/pml_ucx.c @@ -22,6 +22,9 @@ #include "ompi/message/message.h" #include "ompi/mca/pml/base/pml_base_bsend.h" #include "opal/mca/common/ucx/common_ucx.h" +#if OPAL_CUDA_SUPPORT +#include "opal/mca/common/cuda/common_cuda.h" +#endif /* OPAL_CUDA_SUPPORT */ #include "pml_ucx_request.h" #include @@ -230,6 +233,9 @@ int mca_pml_ucx_open(void) /* Query UCX attributes */ attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE; +#if HAVE_UCP_ATTR_MEMORY_TYPES + attr.field_mask |= UCP_ATTR_FIELD_MEMORY_TYPES; +#endif status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr); if (UCS_OK != status) { ucp_cleanup(ompi_pml_ucx.ucp_context); @@ -237,8 +243,15 @@ int mca_pml_ucx_open(void) return OMPI_ERROR; } - ompi_pml_ucx.request_size = attr.request_size; + ompi_pml_ucx.request_size = attr.request_size; + ompi_pml_ucx.cuda_initialized = false; +#if HAVE_UCP_ATTR_MEMORY_TYPES && OPAL_CUDA_SUPPORT + if (attr.memory_types & UCS_BIT(UCS_MEMORY_TYPE_CUDA)) { + mca_common_cuda_stage_one_init(); + ompi_pml_ucx.cuda_initialized = true; + } +#endif return OMPI_SUCCESS; } @@ -246,6 +259,11 @@ int mca_pml_ucx_close(void) { PML_UCX_VERBOSE(1, "mca_pml_ucx_close"); +#if OPAL_CUDA_SUPPORT + if (ompi_pml_ucx.cuda_initialized) { + mca_common_cuda_fini(); + } +#endif if (ompi_pml_ucx.ucp_context != NULL) { ucp_cleanup(ompi_pml_ucx.ucp_context); ompi_pml_ucx.ucp_context = NULL; diff --git a/ompi/mca/pml/ucx/pml_ucx.h b/ompi/mca/pml/ucx/pml_ucx.h index f073b56a54..39ab15e9d1 100644 --- a/ompi/mca/pml/ucx/pml_ucx.h +++ b/ompi/mca/pml/ucx/pml_ucx.h @@ -57,6 +57,7 @@ struct mca_pml_ucx_module { mca_pml_ucx_freelist_t convs; int priority; + bool cuda_initialized; }; extern mca_pml_base_component_2_0_0_t mca_pml_ucx_component;