From 434c9055ee9bc2d951148def8e9fd2b2dba3b6fc Mon Sep 17 00:00:00 2001
From: Joseph Schuchart
Date: Fri, 3 Apr 2020 12:03:17 +0200
Subject: [PATCH] UCX osc: fall back to get-compare-put for unsupported datatypes

Signed-off-by: Joseph Schuchart
---
 ompi/mca/osc/ucx/osc_ucx_comm.c  | 96 ++++++++++++++++++++++++++------
 opal/mca/common/ucx/common_ucx.h |  2 +-
 2 files changed, 79 insertions(+), 19 deletions(-)

diff --git a/ompi/mca/osc/ucx/osc_ucx_comm.c b/ompi/mca/osc/ucx/osc_ucx_comm.c
index 6bafc4ea0b..46ce011599 100644
--- a/ompi/mca/osc/ucx/osc_ucx_comm.c
+++ b/ompi/mca/osc/ucx/osc_ucx_comm.c
@@ -722,6 +722,45 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
                           target_disp, target_count, target_dt, op, win, NULL);
 }
 
+/*
+ * Compare-and-swap using a UCX atomic (opal_common_ucx_wpmem_cmpswp_nb).
+ *
+ * The compare and origin operands are loaded into uint64_t values, so the
+ * caller only takes this path for 4- and 8-byte datatypes. Unless
+ * module->acc_single_intrinsic is set, the operation is wrapped in
+ * start_atomicity()/end_atomicity(). Returns the cmpswp status if
+ * acc_single_intrinsic is set, the end_atomicity() status otherwise.
+ */
+static int
+do_atomic_compare_and_swap(const void *origin_addr, const void *compare_addr,
+                           void *result_addr, struct ompi_datatype_t *dt,
+                           int target, uint64_t remote_addr,
+                           ompi_osc_ucx_module_t *module)
+{
+    int ret;
+    bool lock_acquired = false;
+    size_t dt_bytes;
+    if (!module->acc_single_intrinsic) {
+        ret = start_atomicity(module, target, &lock_acquired);
+        if (ret != OMPI_SUCCESS) {
+            return ret;
+        }
+    }
+
+    ompi_datatype_type_size(dt, &dt_bytes);
+    uint64_t compare_val = opal_common_ucx_load_uint64(compare_addr, dt_bytes);
+    uint64_t value = opal_common_ucx_load_uint64(origin_addr, dt_bytes);
+    ret = opal_common_ucx_wpmem_cmpswp_nb(module->mem, compare_val, value, target,
+                                          result_addr, dt_bytes, remote_addr,
+                                          NULL, NULL);
+
+    if (module->acc_single_intrinsic) {
+        return ret;
+    }
+
+    return end_atomicity(module, target, lock_acquired, NULL);
+}
+
 int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
                                   void *result_addr, struct ompi_datatype_t *dt,
                                   int target, ptrdiff_t target_disp,
@@ -732,23 +771,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
     int ret = OMPI_SUCCESS;
     bool lock_acquired = false;
 
-    ompi_datatype_type_size(dt, &dt_bytes);
-    if (sizeof(uint64_t) < dt_bytes) {
-        return OMPI_ERR_NOT_SUPPORTED;
-    }
-
     ret = check_sync_state(module, target, false);
     if (ret != OMPI_SUCCESS) {
         return ret;
     }
 
-    if (!module->acc_single_intrinsic) {
-        ret = start_atomicity(module, target, &lock_acquired);
-        if (ret != OMPI_SUCCESS) {
-            return ret;
-        }
-    }
-
     if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
         ret = get_dynamic_win_info(remote_addr, module, target);
         if (ret != OMPI_SUCCESS) {
@@ -756,16 +783,49 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
         }
     }
 
-    uint64_t compare_val = opal_common_ucx_load_uint64(compare_addr, dt_bytes);
-    uint64_t value = opal_common_ucx_load_uint64(origin_addr, dt_bytes);
-    ret = opal_common_ucx_wpmem_cmpswp_nb(module->mem, compare_val, value, target,
-                                          result_addr, dt_bytes, remote_addr,
-                                          NULL, NULL);
+    ompi_datatype_type_size(dt, &dt_bytes);
+    if (4 == dt_bytes || 8 == dt_bytes) {
+        /* fast path using UCX atomic operations */
+        return do_atomic_compare_and_swap(origin_addr, compare_addr,
+                                          result_addr, dt, target,
+                                          remote_addr, module);
+    }
 
-    if (module->acc_single_intrinsic) {
+    /* fall back to get-compare-put inside a start/end_atomicity section */
+
+    ret = start_atomicity(module, target, &lock_acquired);
+    if (ret != OMPI_SUCCESS) {
         return ret;
     }
 
+    /* fetch the target value into the caller's result buffer itself */
+    ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET, target,
+                                       result_addr, dt_bytes, remote_addr);
+    if (OPAL_SUCCESS != ret) {
+        OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
+        /* close the atomicity section; the original error is reported */
+        (void)end_atomicity(module, target, lock_acquired, NULL);
+        return OMPI_ERROR;
+    }
+
+    /* flush before comparing so that result_addr holds the fetched value */
+    ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
+    if (ret != OPAL_SUCCESS) {
+        (void)end_atomicity(module, target, lock_acquired, NULL);
+        return ret;
+    }
+
+    if (0 == memcmp(result_addr, compare_addr, dt_bytes)) {
+        /* target matches the compare value: write the new value */
+        ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_PUT, target,
+                                           (void*)origin_addr, dt_bytes, remote_addr);
+        if (OPAL_SUCCESS != ret) {
+            OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
+            (void)end_atomicity(module, target, lock_acquired, NULL);
+            return OMPI_ERROR;
+        }
+    }
+
     return end_atomicity(module, target, lock_acquired, NULL);
 }
 
diff --git a/opal/mca/common/ucx/common_ucx.h b/opal/mca/common/ucx/common_ucx.h
index f877742a2d..2baf6ea946 100644
--- a/opal/mca/common/ucx/common_ucx.h
+++ b/opal/mca/common/ucx/common_ucx.h
@@ -120,7 +120,7 @@ OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *
  * Load an integer value of \c size bytes from \c ptr and cast it to uint64_t.
  */
 static inline
-uint64_t opal_common_ucx_load_uint64(void *ptr, size_t size)
+uint64_t opal_common_ucx_load_uint64(const void *ptr, size_t size)
 {
     if (sizeof(uint8_t) == size) {
         return *(uint8_t*)ptr;