diff --git a/ompi/mca/osc/ucx/osc_ucx_component.c b/ompi/mca/osc/ucx/osc_ucx_component.c index 4ce208f2b7..4a36b3eefb 100644 --- a/ompi/mca/osc/ucx/osc_ucx_component.c +++ b/ompi/mca/osc/ucx/osc_ucx_component.c @@ -24,6 +24,21 @@ memcpy(((char*)(_dst)) + (_off), _src, _len); \ (_off) += (_len); +opal_mutex_t mca_osc_service_mutex = OPAL_MUTEX_STATIC_INIT; +static void _osc_ucx_init_lock(void) +{ + if(mca_osc_ucx_component.enable_mpi_threads) { + opal_mutex_lock(&mca_osc_service_mutex); + } +} +static void _osc_ucx_init_unlock(void) +{ + if(mca_osc_ucx_component.enable_mpi_threads) { + opal_mutex_unlock(&mca_osc_service_mutex); + } +} + + static int component_open(void); static int component_register(void); static int component_init(bool enable_progress_threads, bool enable_mpi_threads); @@ -192,6 +207,9 @@ static void ompi_osc_ucx_unregister_progress() { int ret; + /* May be called concurrently - protect */ + _osc_ucx_init_lock(); + mca_osc_ucx_component.num_modules--; OSC_UCX_ASSERT(mca_osc_ucx_component.num_modules >= 0); if (0 == mca_osc_ucx_component.num_modules) { @@ -200,6 +218,8 @@ static void ompi_osc_ucx_unregister_progress() OSC_UCX_VERBOSE(1, "opal_progress_unregister failed: %d", ret); } } + + _osc_ucx_init_unlock(); } static int component_select(struct ompi_win_t *win, void **base, size_t size, int disp_unit, @@ -226,7 +246,14 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in return OMPI_ERR_NOT_SUPPORTED; } + /* May be called concurrently - protect */ + _osc_ucx_init_lock(); + if (mca_osc_ucx_component.env_initialized == false) { + /* Lazy initialization of the global state. + * As not all of the MPI applications are using One-Sided functionality + * we don't want to initialize in the component_init() + */ OBJ_CONSTRUCT(&mca_osc_ucx_component.requests, opal_free_list_t); ret = opal_free_list_init (&mca_osc_ucx_component.requests, @@ -236,7 +263,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in 0, 0, 8, 0, 8, NULL, 0, NULL, NULL, NULL); if (OMPI_SUCCESS != ret) { OSC_UCX_VERBOSE(1, "opal_free_list_init failed: %d", ret); - goto error; + goto select_unlock; } ret = opal_common_ucx_wpool_init(mca_osc_ucx_component.wpool, @@ -244,13 +271,37 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in mca_osc_ucx_component.enable_mpi_threads); if (OMPI_SUCCESS != ret) { OSC_UCX_VERBOSE(1, "opal_common_ucx_wpool_init failed: %d", ret); - goto error; + goto select_unlock; } + /* Make sure that all memory updates performed above are globally + * observable before (mca_osc_ucx_component.env_initialized = true) + */ mca_osc_ucx_component.env_initialized = true; env_initialized = true; } + /* Account for the number of active "modules" = MPI windows */ + mca_osc_ucx_component.num_modules++; + + /* If this is the first window to be registered - register the progress + * callback + */ + OSC_UCX_ASSERT(mca_osc_ucx_component.num_modules > 0); + if (1 == mca_osc_ucx_component.num_modules) { + ret = opal_progress_register(progress_callback); + if (OMPI_SUCCESS != ret) { + OSC_UCX_VERBOSE(1, "opal_progress_register failed: %d", ret); + goto error; + } + } + +select_unlock: + _osc_ucx_init_unlock(); + if (ret) { + goto error; + } + /* create module structure */ module = (ompi_osc_ucx_module_t *)calloc(1, sizeof(ompi_osc_ucx_module_t)); if (module == NULL) { @@ -258,8 +309,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in goto error_nomem; } - mca_osc_ucx_component.num_modules++; - /* fill in the function pointer part */ memcpy(module, &ompi_osc_ucx_module_template, sizeof(ompi_osc_base_module_t)); @@ -413,19 +462,15 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in goto error; } - OSC_UCX_ASSERT(mca_osc_ucx_component.num_modules > 0); - if (1 == mca_osc_ucx_component.num_modules) { - ret = opal_progress_register(progress_callback); - if (OMPI_SUCCESS != ret) { - OSC_UCX_VERBOSE(1, "opal_progress_register failed: %d", ret); - goto error; - } - } return ret; - error: +error: if (module->disp_units) free(module->disp_units); if (module->comm) ompi_comm_free(&module->comm); + /* We update the modules count and (if need) registering a callback right + * prior to memory allocation for the module. + * So we use it as an indirect sign here + */ if (module) { free(module); ompi_osc_ucx_unregister_progress();