1
1

UCX osc: fall back to get-compare-put for unsupported datatypes

Signed-off-by: Joseph Schuchart <schuchart@hlrs.de>
Этот коммит содержится в:
Joseph Schuchart 2020-04-03 12:03:17 +02:00
родитель 7d5a6e3e8b
Коммит 434c9055ee
2 изменённых файлов: 64 добавлений и 19 удалений

Просмотреть файл

@ -722,6 +722,36 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
target_disp, target_count, target_dt, op, win, NULL);
}
/**
 * Compare-and-swap on \c remote_addr at rank \c target, issued through
 * opal_common_ucx_wpmem_cmpswp_nb().
 *
 * The compare and origin operands are read with opal_common_ucx_load_uint64()
 * using the size reported by ompi_datatype_type_size() for \c dt.
 * Unless module->acc_single_intrinsic is set, the operation is bracketed by
 * start_atomicity() / end_atomicity().
 *
 * @param origin_addr   value to store at the target if the comparison matches
 * @param compare_addr  value the target content is compared against
 * @param result_addr   buffer handed to the compare-and-swap call for the result
 * @param dt            datatype of the operands (only its size is used here)
 * @param target        target rank
 * @param remote_addr   address of the target location
 * @param module        OSC UCX module the operation is issued on
 *
 * @return the start_atomicity() status if it fails; otherwise the status of
 *         the compare-and-swap if that failed, else the status of
 *         end_atomicity() (or the compare-and-swap status directly when
 *         acc_single_intrinsic is set).
 */
static int
do_atomic_compare_and_swap(const void *origin_addr, const void *compare_addr,
                           void *result_addr, struct ompi_datatype_t *dt,
                           int target, uint64_t remote_addr,
                           ompi_osc_ucx_module_t *module)
{
    int ret;
    bool lock_acquired = false;
    size_t dt_bytes;

    if (!module->acc_single_intrinsic) {
        ret = start_atomicity(module, target, &lock_acquired);
        if (ret != OMPI_SUCCESS) {
            return ret;
        }
    }

    ompi_datatype_type_size(dt, &dt_bytes);
    uint64_t compare_val = opal_common_ucx_load_uint64(compare_addr, dt_bytes);
    uint64_t value = opal_common_ucx_load_uint64(origin_addr, dt_bytes);
    ret = opal_common_ucx_wpmem_cmpswp_nb(module->mem, compare_val, value, target,
                                          result_addr, dt_bytes, remote_addr,
                                          NULL, NULL);

    if (module->acc_single_intrinsic) {
        return ret;
    }

    /* Always run end_atomicity() so the start_atomicity() call above is
     * paired, but do not let its status mask a failed compare-and-swap. */
    int end_ret = end_atomicity(module, target, lock_acquired, NULL);
    return (OMPI_SUCCESS != ret) ? ret : end_ret;
}
int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
void *result_addr, struct ompi_datatype_t *dt,
int target, ptrdiff_t target_disp,
@ -732,23 +762,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
int ret = OMPI_SUCCESS;
bool lock_acquired = false;
ompi_datatype_type_size(dt, &dt_bytes);
if (sizeof(uint64_t) < dt_bytes) {
return OMPI_ERR_NOT_SUPPORTED;
}
ret = check_sync_state(module, target, false);
if (ret != OMPI_SUCCESS) {
return ret;
}
if (!module->acc_single_intrinsic) {
ret = start_atomicity(module, target, &lock_acquired);
if (ret != OMPI_SUCCESS) {
return ret;
}
}
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
ret = get_dynamic_win_info(remote_addr, module, target);
if (ret != OMPI_SUCCESS) {
@ -756,16 +774,43 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
}
}
uint64_t compare_val = opal_common_ucx_load_uint64(compare_addr, dt_bytes);
uint64_t value = opal_common_ucx_load_uint64(origin_addr, dt_bytes);
ret = opal_common_ucx_wpmem_cmpswp_nb(module->mem, compare_val, value, target,
result_addr, dt_bytes, remote_addr,
NULL, NULL);
ompi_datatype_type_size(dt, &dt_bytes);
if (4 == dt_bytes || 8 == dt_bytes) {
// fast path using UCX atomic operations
return do_atomic_compare_and_swap(origin_addr, compare_addr,
result_addr, dt, target,
remote_addr, module);
}
if (module->acc_single_intrinsic) {
/* fall back to get-compare-put */
ret = start_atomicity(module, target, &lock_acquired);
if (ret != OMPI_SUCCESS) {
return ret;
}
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET, target,
&result_addr, dt_bytes, remote_addr);
if (OPAL_SUCCESS != ret) {
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
return OMPI_ERROR;
}
ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
if (ret != OPAL_SUCCESS) {
return ret;
}
if (0 == memcmp(result_addr, compare_addr, dt_bytes)) {
// write the new value
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_PUT, target,
(void*)origin_addr, dt_bytes, remote_addr);
if (OPAL_SUCCESS != ret) {
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
return OMPI_ERROR;
}
}
return end_atomicity(module, target, lock_acquired, NULL);
}

Просмотреть файл

@ -120,7 +120,7 @@ OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *
* Load an integer value of \c size bytes from \c ptr and cast it to uint64_t.
*/
static inline
uint64_t opal_common_ucx_load_uint64(void *ptr, size_t size)
uint64_t opal_common_ucx_load_uint64(const void *ptr, size_t size)
{
if (sizeof(uint8_t) == size) {
return *(uint8_t*)ptr;