diff --git a/ompi/mca/common/cuda/common_cuda.c b/ompi/mca/common/cuda/common_cuda.c index 385d299459..94f6a625ca 100644 --- a/ompi/mca/common/cuda/common_cuda.c +++ b/ompi/mca/common/cuda/common_cuda.c @@ -94,6 +94,7 @@ struct cudaFunctionTable { #if OPAL_CUDA_SUPPORT_60 int (*cuPointerSetAttribute)(const void *, CUpointer_attribute, CUdeviceptr); #endif /* OPAL_CUDA_SUPPORT_60 */ + int (*cuCtxSetCurrent)(CUcontext); } cudaFunctionTable; typedef struct cudaFunctionTable cudaFunctionTable_t; cudaFunctionTable_t cuFunc; @@ -454,6 +455,7 @@ int mca_common_cuda_stage_one_init(void) #if OPAL_CUDA_SUPPORT_60 OMPI_CUDA_DLSYM(libcuda_handle, cuPointerSetAttribute); #endif /* OPAL_CUDA_SUPPORT_60 */ + OMPI_CUDA_DLSYM(libcuda_handle, cuCtxSetCurrent); return 0; } @@ -1537,7 +1539,7 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf) "res=%d, ptr=%p aborting...", res, pUserBuf); ompi_rte_abort(1, NULL); } else { - res = cuCtxSetCurrent(ctx); + res = cuFunc.cuCtxSetCurrent(ctx); if (res != CUDA_SUCCESS) { opal_output(0, "CUDA: error calling cuCtxSetCurrent: " "res=%d, ptr=%p aborting...", res, pUserBuf);