UCX osc: fall back to get-compare-put for unsupported datatypes
Signed-off-by: Joseph Schuchart <schuchart@hlrs.de>
Этот коммит содержится в:
родитель
7d5a6e3e8b
Коммит
434c9055ee
@ -722,6 +722,36 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
|
||||
target_disp, target_count, target_dt, op, win, NULL);
|
||||
}
|
||||
|
||||
static int
|
||||
do_atomic_compare_and_swap(const void *origin_addr, const void *compare_addr,
|
||||
void *result_addr, struct ompi_datatype_t *dt,
|
||||
int target, uint64_t remote_addr,
|
||||
ompi_osc_ucx_module_t *module)
|
||||
{
|
||||
int ret;
|
||||
bool lock_acquired = false;
|
||||
size_t dt_bytes;
|
||||
if (!module->acc_single_intrinsic) {
|
||||
ret = start_atomicity(module, target, &lock_acquired);
|
||||
if (ret != OMPI_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
ompi_datatype_type_size(dt, &dt_bytes);
|
||||
uint64_t compare_val = opal_common_ucx_load_uint64(compare_addr, dt_bytes);
|
||||
uint64_t value = opal_common_ucx_load_uint64(origin_addr, dt_bytes);
|
||||
ret = opal_common_ucx_wpmem_cmpswp_nb(module->mem, compare_val, value, target,
|
||||
result_addr, dt_bytes, remote_addr,
|
||||
NULL, NULL);
|
||||
|
||||
if (module->acc_single_intrinsic) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
return end_atomicity(module, target, lock_acquired, NULL);
|
||||
}
|
||||
|
||||
int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
|
||||
void *result_addr, struct ompi_datatype_t *dt,
|
||||
int target, ptrdiff_t target_disp,
|
||||
@ -732,23 +762,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
|
||||
int ret = OMPI_SUCCESS;
|
||||
bool lock_acquired = false;
|
||||
|
||||
ompi_datatype_type_size(dt, &dt_bytes);
|
||||
if (sizeof(uint64_t) < dt_bytes) {
|
||||
return OMPI_ERR_NOT_SUPPORTED;
|
||||
}
|
||||
|
||||
ret = check_sync_state(module, target, false);
|
||||
if (ret != OMPI_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (!module->acc_single_intrinsic) {
|
||||
ret = start_atomicity(module, target, &lock_acquired);
|
||||
if (ret != OMPI_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
|
||||
ret = get_dynamic_win_info(remote_addr, module, target);
|
||||
if (ret != OMPI_SUCCESS) {
|
||||
@ -756,16 +774,43 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t compare_val = opal_common_ucx_load_uint64(compare_addr, dt_bytes);
|
||||
uint64_t value = opal_common_ucx_load_uint64(origin_addr, dt_bytes);
|
||||
ret = opal_common_ucx_wpmem_cmpswp_nb(module->mem, compare_val, value, target,
|
||||
result_addr, dt_bytes, remote_addr,
|
||||
NULL, NULL);
|
||||
ompi_datatype_type_size(dt, &dt_bytes);
|
||||
if (4 == dt_bytes || 8 == dt_bytes) {
|
||||
// fast path using UCX atomic operations
|
||||
return do_atomic_compare_and_swap(origin_addr, compare_addr,
|
||||
result_addr, dt, target,
|
||||
remote_addr, module);
|
||||
}
|
||||
|
||||
if (module->acc_single_intrinsic) {
|
||||
/* fall back to get-compare-put */
|
||||
|
||||
ret = start_atomicity(module, target, &lock_acquired);
|
||||
if (ret != OMPI_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET, target,
|
||||
&result_addr, dt_bytes, remote_addr);
|
||||
if (OPAL_SUCCESS != ret) {
|
||||
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
|
||||
return OMPI_ERROR;
|
||||
}
|
||||
|
||||
ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
|
||||
if (ret != OPAL_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (0 == memcmp(result_addr, compare_addr, dt_bytes)) {
|
||||
// write the new value
|
||||
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_PUT, target,
|
||||
(void*)origin_addr, dt_bytes, remote_addr);
|
||||
if (OPAL_SUCCESS != ret) {
|
||||
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
|
||||
return OMPI_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
return end_atomicity(module, target, lock_acquired, NULL);
|
||||
}
|
||||
|
||||
|
@ -120,7 +120,7 @@ OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *
|
||||
* Load an integer value of \c size bytes from \c ptr and cast it to uint64_t.
|
||||
*/
|
||||
static inline
|
||||
uint64_t opal_common_ucx_load_uint64(void *ptr, size_t size)
|
||||
uint64_t opal_common_ucx_load_uint64(const void *ptr, size_t size)
|
||||
{
|
||||
if (sizeof(uint8_t) == size) {
|
||||
return *(uint8_t*)ptr;
|
||||
|
Загрузка…
x
Ссылка в новой задаче
Block a user