Merge pull request #7893 from bureddy/cuda-ucx
UCX: initialize cuda from ucx pml component
Этот коммит содержится в:
Коммит
aa8f7f4ede
@ -129,6 +129,10 @@ AC_DEFUN([OMPI_CHECK_UCX],[
|
||||
[AC_DEFINE([HAVE_UCP_WORKER_ADDRESS_FLAGS], [1],
|
||||
[have worker address attribute])], [],
|
||||
[#include <ucp/api/ucp.h>])
|
||||
AC_CHECK_DECLS([UCP_ATTR_FIELD_MEMORY_TYPES],
|
||||
[AC_DEFINE([HAVE_UCP_ATTR_MEMORY_TYPES], [1],
|
||||
[have memory types attribute])], [],
|
||||
[#include <ucp/api/ucp.h>])
|
||||
AC_CHECK_DECLS([ucp_tag_send_nbx,
|
||||
ucp_tag_send_sync_nbx,
|
||||
ucp_tag_recv_nbx],
|
||||
|
@ -22,6 +22,9 @@
|
||||
#include "ompi/message/message.h"
|
||||
#include "ompi/mca/pml/base/pml_base_bsend.h"
|
||||
#include "opal/mca/common/ucx/common_ucx.h"
|
||||
#if OPAL_CUDA_SUPPORT
|
||||
#include "opal/mca/common/cuda/common_cuda.h"
|
||||
#endif /* OPAL_CUDA_SUPPORT */
|
||||
#include "pml_ucx_request.h"
|
||||
|
||||
#include <inttypes.h>
|
||||
@ -230,6 +233,9 @@ int mca_pml_ucx_open(void)
|
||||
|
||||
/* Query UCX attributes */
|
||||
attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
|
||||
#if HAVE_UCP_ATTR_MEMORY_TYPES
|
||||
attr.field_mask |= UCP_ATTR_FIELD_MEMORY_TYPES;
|
||||
#endif
|
||||
status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
|
||||
if (UCS_OK != status) {
|
||||
ucp_cleanup(ompi_pml_ucx.ucp_context);
|
||||
@ -237,8 +243,15 @@ int mca_pml_ucx_open(void)
|
||||
return OMPI_ERROR;
|
||||
}
|
||||
|
||||
ompi_pml_ucx.request_size = attr.request_size;
|
||||
ompi_pml_ucx.request_size = attr.request_size;
|
||||
ompi_pml_ucx.cuda_initialized = false;
|
||||
|
||||
#if HAVE_UCP_ATTR_MEMORY_TYPES && OPAL_CUDA_SUPPORT
|
||||
if (attr.memory_types & UCS_BIT(UCS_MEMORY_TYPE_CUDA)) {
|
||||
mca_common_cuda_stage_one_init();
|
||||
ompi_pml_ucx.cuda_initialized = true;
|
||||
}
|
||||
#endif
|
||||
return OMPI_SUCCESS;
|
||||
}
|
||||
|
||||
@ -246,6 +259,11 @@ int mca_pml_ucx_close(void)
|
||||
{
|
||||
PML_UCX_VERBOSE(1, "mca_pml_ucx_close");
|
||||
|
||||
#if OPAL_CUDA_SUPPORT
|
||||
if (ompi_pml_ucx.cuda_initialized) {
|
||||
mca_common_cuda_fini();
|
||||
}
|
||||
#endif
|
||||
if (ompi_pml_ucx.ucp_context != NULL) {
|
||||
ucp_cleanup(ompi_pml_ucx.ucp_context);
|
||||
ompi_pml_ucx.ucp_context = NULL;
|
||||
|
@ -57,6 +57,7 @@ struct mca_pml_ucx_module {
|
||||
mca_pml_ucx_freelist_t convs;
|
||||
|
||||
int priority;
|
||||
bool cuda_initialized;
|
||||
};
|
||||
|
||||
extern mca_pml_base_component_2_0_0_t mca_pml_ucx_component;
|
||||
|
Загрузка…
x
Ссылка в новой задаче
Block a user