diff --git a/ompi/mca/osc/ucx/osc_ucx_active_target.c b/ompi/mca/osc/ucx/osc_ucx_active_target.c index 348d1cf701..49c72b4a50 100644 --- a/ompi/mca/osc/ucx/osc_ucx_active_target.c +++ b/ompi/mca/osc/ucx/osc_ucx_active_target.c @@ -260,7 +260,9 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int assert, struct ompi_win_t uint64_t curr_idx = 0, result = 0; /* do fop first to get an post index */ - status = ucp_atomic_fadd64(ep, 1, remote_addr, rkey, &result); + status = opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_FADD, 1, + &result, sizeof(result), + remote_addr, rkey, mca_osc_ucx_component.ucp_worker); if (status != UCS_OK) { opal_output_verbose(1, ompi_osc_base_framework.framework_output, "%s:%d: ucp_atomic_fadd64 failed: %d\n", @@ -273,8 +275,9 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int assert, struct ompi_win_t /* do cas to send post message */ do { - status = ucp_atomic_cswap64(ep, 0, (uint64_t)myrank + 1, - remote_addr, rkey, &result); + status = opal_common_ucx_atomic_cswap(ep, 0, (uint64_t)myrank + 1, &result, + sizeof(result), remote_addr, rkey, + mca_osc_ucx_component.ucp_worker); if (status != UCS_OK) { opal_output_verbose(1, ompi_osc_base_framework.framework_output, "%s:%d: ucp_atomic_cswap64 failed: %d\n", diff --git a/ompi/mca/osc/ucx/osc_ucx_comm.c b/ompi/mca/osc/ucx/osc_ucx_comm.c index 35a6af55b8..99fb8af44d 100644 --- a/ompi/mca/osc/ucx/osc_ucx_comm.c +++ b/ompi/mca/osc/ucx/osc_ucx_comm.c @@ -290,9 +290,10 @@ static inline int start_atomicity(ompi_osc_ucx_module_t *module, ucp_ep_h ep, in ucs_status_t status; while (result_value != TARGET_LOCK_UNLOCKED) { - status = ucp_atomic_cswap64(ep, TARGET_LOCK_UNLOCKED, - TARGET_LOCK_EXCLUSIVE, - remote_addr, rkey, &result_value); + status = opal_common_ucx_atomic_cswap(ep, TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE, + &result_value, sizeof(result_value), + remote_addr, rkey, + mca_osc_ucx_component.ucp_worker); if (status != UCS_OK) { opal_output_verbose(1, ompi_osc_base_framework.framework_output, "%s:%d: ucp_atomic_cswap64 failed: %d\n", @@ -310,8 +311,9 @@ static inline int end_atomicity(ompi_osc_ucx_module_t *module, ucp_ep_h ep, int uint64_t remote_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_ACC_LOCK_OFFSET; ucs_status_t status; - status = ucp_atomic_swap64(ep, TARGET_LOCK_UNLOCKED, - remote_addr, rkey, &result_value); + status = opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED, + &result_value, sizeof(result_value), + remote_addr, rkey, mca_osc_ucx_component.ucp_worker); if (status != UCS_OK) { opal_output_verbose(1, ompi_osc_base_framework.framework_output, "%s:%d: ucp_atomic_swap64 failed: %d\n", diff --git a/ompi/mca/osc/ucx/osc_ucx_passive_target.c b/ompi/mca/osc/ucx/osc_ucx_passive_target.c index 432b2b7edf..e0781da81c 100644 --- a/ompi/mca/osc/ucx/osc_ucx_passive_target.c +++ b/ompi/mca/osc/ucx/osc_ucx_passive_target.c @@ -26,7 +26,9 @@ static inline int start_shared(ompi_osc_ucx_module_t *module, int target) { ucs_status_t status; while (true) { - status = ucp_atomic_fadd64(ep, 1, remote_addr, rkey, &result_value); + status = opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_FADD, 1, + &result_value, sizeof(result_value), + remote_addr, rkey, mca_osc_ucx_component.ucp_worker); if (status != UCS_OK) { opal_output_verbose(1, ompi_osc_base_framework.framework_output, "%s:%d: ucp_atomic_fadd64 failed: %d\n", @@ -35,7 +37,8 @@ static inline int start_shared(ompi_osc_ucx_module_t *module, int target) { } assert(result_value >= 0); if (result_value >= TARGET_LOCK_EXCLUSIVE) { - status = ucp_atomic_add64(ep, (-1), remote_addr, rkey); + status = ucp_atomic_post(ep, UCP_ATOMIC_POST_OP_ADD, (-1), sizeof(uint64_t), + remote_addr, rkey); if (status != UCS_OK) { opal_output_verbose(1, ompi_osc_base_framework.framework_output, "%s:%d: ucp_atomic_add64 failed: %d\n", @@ -56,7 +59,8 @@ static inline int end_shared(ompi_osc_ucx_module_t *module, int target) { uint64_t remote_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_LOCK_OFFSET; ucs_status_t status; - status = ucp_atomic_add64(ep, (-1), remote_addr, rkey); + status = ucp_atomic_post(ep, UCP_ATOMIC_POST_OP_ADD, (-1), sizeof(uint64_t), + remote_addr, rkey); if (status != UCS_OK) { opal_output_verbose(1, ompi_osc_base_framework.framework_output, "%s:%d: ucp_atomic_add64 failed: %d\n", @@ -75,9 +79,10 @@ static inline int start_exclusive(ompi_osc_ucx_module_t *module, int target) { ucs_status_t status; while (result_value != TARGET_LOCK_UNLOCKED) { - status = ucp_atomic_cswap64(ep, TARGET_LOCK_UNLOCKED, - TARGET_LOCK_EXCLUSIVE, - remote_addr, rkey, &result_value); + status = opal_common_ucx_atomic_cswap(ep, TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE, + &result_value, sizeof(result_value), + remote_addr, rkey, + mca_osc_ucx_component.ucp_worker); if (status != UCS_OK) { opal_output_verbose(1, ompi_osc_base_framework.framework_output, "%s:%d: ucp_atomic_cswap64 failed: %d\n", @@ -96,8 +101,9 @@ static inline int end_exclusive(ompi_osc_ucx_module_t *module, int target) { uint64_t remote_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_LOCK_OFFSET; ucs_status_t status; - status = ucp_atomic_swap64(ep, TARGET_LOCK_UNLOCKED, - remote_addr, rkey, &result_value); + status = opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED, + &result_value, sizeof(result_value), + remote_addr, rkey, mca_osc_ucx_component.ucp_worker); if (status != UCS_OK) { opal_output_verbose(1, ompi_osc_base_framework.framework_output, "%s:%d: ucp_atomic_swap64 failed: %d\n", diff --git a/opal/mca/common/ucx/common_ucx.c b/opal/mca/common/ucx/common_ucx.c index 5d30fbe47e..85b96a92cd 100644 --- a/opal/mca/common/ucx/common_ucx.c +++ b/opal/mca/common/ucx/common_ucx.c @@ -27,10 +27,10 @@ OPAL_DECLSPEC void opal_common_ucx_mca_register(void) registered = 1; mca_base_var_register("opal", "opal_common", "ucx", "progress_iterations", - "Set number of calls of internal UCX progress calls per opal_progress call", - MCA_BASE_VAR_TYPE_INT, NULL, 0, MCA_BASE_VAR_FLAG_SETTABLE, - OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_LOCAL, - &opal_common_ucx_progress_iterations); + "Set number of calls of internal UCX progress calls per opal_progress call", + MCA_BASE_VAR_TYPE_INT, NULL, 0, MCA_BASE_VAR_FLAG_SETTABLE, + OPAL_INFO_LVL_9, MCA_BASE_VAR_SCOPE_LOCAL, + &opal_common_ucx_progress_iterations); } void opal_common_ucx_empty_complete_cb(void *request, ucs_status_t status) diff --git a/opal/mca/common/ucx/common_ucx.h b/opal/mca/common/ucx/common_ucx.h index d1bbcbb8c9..3a2d208dc8 100644 --- a/opal/mca/common/ucx/common_ucx.h +++ b/opal/mca/common/ucx/common_ucx.h @@ -73,6 +73,44 @@ ucs_status_t opal_common_ucx_worker_flush(ucp_worker_h worker) status = ucp_worker_flush_nb(worker, 0, opal_common_ucx_empty_complete_cb); return opal_common_ucx_wait_request(status, worker); } + +static inline +ucs_status_t opal_common_ucx_atomic_fetch(ucp_ep_h ep, ucp_atomic_fetch_op_t opcode, + uint64_t value, void *result, size_t op_size, + uint64_t remote_addr, ucp_rkey_h rkey, + ucp_worker_h worker) +{ + ucs_status_ptr_t request; + + request = ucp_atomic_fetch_nb(ep, opcode, value, result, op_size, + remote_addr, rkey, opal_common_ucx_empty_complete_cb); + return opal_common_ucx_wait_request(request, worker); +} + +static inline +ucs_status_t opal_common_ucx_atomic_cswap(ucp_ep_h ep, uint64_t compare, + uint64_t value, void *result, size_t op_size, + uint64_t remote_addr, ucp_rkey_h rkey, + ucp_worker_h worker) +{ + uint64_t tmp = value; + ucs_status_t status; + + status = opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_CSWAP, compare, &tmp, + op_size, remote_addr, rkey, worker); + if (OPAL_LIKELY(UCS_OK == status)) { + /* in case if op_size is constant (like sizeof(type)) then this condition + * is evaluated in compile time */ + if (op_size == sizeof(uint64_t)) { + *(uint64_t*)result = tmp; + } else { + assert(op_size == sizeof(uint32_t)); + *(uint32_t*)result = tmp; + } + } + return status; +} + END_C_DECLS #endif