diff --git a/ompi/mca/osc/ucx/osc_ucx_comm.c b/ompi/mca/osc/ucx/osc_ucx_comm.c index ddab2c2d5b..e4a3ede8a4 100644 --- a/ompi/mca/osc/ucx/osc_ucx_comm.c +++ b/ompi/mca/osc/ucx/osc_ucx_comm.c @@ -567,6 +567,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a return ret; } + ret = start_atomicity(module, ep, target); + if (ret != OMPI_SUCCESS) { + return ret; + } + ompi_datatype_type_size(dt, &dt_bytes); memcpy(result_addr, origin_addr, dt_bytes); req = ucp_atomic_fetch_nb(ep, UCP_ATOMIC_FETCH_OP_CSWAP, *(uint64_t *)compare_addr, @@ -575,7 +580,12 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a ucp_request_release(req); } - return incr_and_check_ops_num(module, target, ep); + ret = incr_and_check_ops_num(module, target, ep); + if (ret != OMPI_SUCCESS) { + return ret; + } + + return end_atomicity(module, ep, target); } int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, @@ -600,6 +610,11 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, size_t dt_bytes; ompi_osc_ucx_internal_request_t *req = NULL; + ret = start_atomicity(module, ep, target); + if (ret != OMPI_SUCCESS) { + return ret; + } + ompi_datatype_type_size(dt, &dt_bytes); if (op == &ompi_mpi_op_replace.op) { @@ -617,7 +632,12 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, ucp_request_release(req); } - return incr_and_check_ops_num(module, target, ep); + ret = incr_and_check_ops_num(module, target, ep); + if (ret != OMPI_SUCCESS) { + return ret; + } + + return end_atomicity(module, ep, target); } else { return ompi_osc_ucx_get_accumulate(origin_addr, 1, dt, result_addr, 1, dt, target, target_disp, 1, dt, op, win);