diff --git a/ompi/mca/common/cuda/common_cuda.c b/ompi/mca/common/cuda/common_cuda.c index 4c71894624..460b635716 100644 --- a/ompi/mca/common/cuda/common_cuda.c +++ b/ompi/mca/common/cuda/common_cuda.c @@ -1475,6 +1475,7 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf) int res; CUmemorytype memType; CUdeviceptr dbuf = (CUdeviceptr)pUserBuf; + CUcontext ctx = NULL; res = cuFunc.cuPointerGetAttribute(&memType, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, dbuf); @@ -1489,6 +1490,42 @@ static int mca_common_cuda_is_gpu_buffer(const void *pUserBuf) /* Must be a device pointer */ assert(memType == CU_MEMORYTYPE_DEVICE); + /* This piece of code was added in to handle in a case involving + * OMP threads. The user had initialized CUDA and then spawned + * two threads. The first thread had the CUDA context, but the + * second thread did not. We therefore had no context to act upon + * and future CUDA driver calls would fail. Therefore, if we have + * GPU memory, but no context, get the context from the GPU memory + * and set the current context to that. It is rare that we will not + * have a context. */ + res = cuFunc.cuCtxGetCurrent(&ctx); + if (OPAL_UNLIKELY(NULL == ctx)) { + if (CUDA_SUCCESS == res) { + res = cuFunc.cuPointerGetAttribute(&ctx, + CU_POINTER_ATTRIBUTE_CONTEXT, dbuf); + if (res != CUDA_SUCCESS) { + opal_output(0, "CUDA: error calling cuPointerGetAttribute: " + "res=%d, ptr=%p aborting...", res, pUserBuf); + ompi_rte_abort(1, NULL); + } else { + res = cuCtxSetCurrent(ctx); + if (res != CUDA_SUCCESS) { + opal_output(0, "CUDA: error calling cuCtxSetCurrent: " + "res=%d, ptr=%p aborting...", res, pUserBuf); + ompi_rte_abort(1, NULL); + } else { + opal_output_verbose(10, mca_common_cuda_output, + "CUDA: cuCtxSetCurrent passed: ptr=%p", pUserBuf); + } + } + } else { + /* Print error and proceed */ + opal_output(0, "CUDA: error calling cuCtxGetCurrent: " + "res=%d, ptr=%p aborting...", res, pUserBuf); + ompi_rte_abort(1, NULL); + } + } + /* First access on a device pointer finalizes CUDA support initialization. * If initialization fails, disable support. */ if (!stage_three_init_complete) {