diff --git a/ompi/mca/osc/ucx/osc_ucx_comm.c b/ompi/mca/osc/ucx/osc_ucx_comm.c index 46ce011599..c4e7bcaaa7 100644 --- a/ompi/mca/osc/ucx/osc_ucx_comm.c +++ b/ompi/mca/osc/ucx/osc_ucx_comm.c @@ -24,6 +24,11 @@ return OMPI_ERROR; \ } +/* macro to check whether UCX supports atomic operation on the size the operands */ +#define ATOMIC_SIZE_SUPPORTED(_remote_addr, _size) \ + ((sizeof(uint32_t) == (_size) && !((_remote_addr) & 0x3)) || \ + (sizeof(uint64_t) == (_size) && !((_remote_addr) & 0x7))) + typedef struct ucx_iovec { void *addr; size_t len; @@ -367,6 +372,7 @@ static inline bool use_atomic_op( ompi_osc_ucx_module_t *module, struct ompi_op_t *op, + uint64_t remote_addr, struct ompi_datatype_t *origin_dt, struct ompi_datatype_t *target_dt, int origin_count, @@ -384,9 +390,8 @@ bool use_atomic_op( ompi_datatype_type_size(origin_dt, &origin_dt_bytes); ompi_datatype_type_size(target_dt, &target_dt_bytes); /* UCX only supports 32 and 64-bit operands atm */ - if (sizeof(uint64_t) >= origin_dt_bytes && - sizeof(uint32_t) <= origin_dt_bytes && - origin_dt_bytes == target_dt_bytes && + if (ATOMIC_SIZE_SUPPORTED(remote_addr, origin_dt_bytes) && + origin_dt_bytes == target_dt_bytes && origin_count == target_count) { return true; } @@ -603,7 +608,7 @@ int accumulate_req(const void *origin_addr, int origin_count, } /* rely on UCX network atomics if the user told us that it safe */ - if (use_atomic_op(module, op, origin_dt, target_dt, origin_count, target_count)) { + if (use_atomic_op(module, op, target_disp, origin_dt, target_dt, origin_count, target_count)) { return do_atomic_op_intrinsic(module, op, target, origin_addr, origin_count, origin_dt, target_disp, NULL, ucx_req); @@ -775,7 +780,7 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a } ompi_datatype_type_size(dt, &dt_bytes); - if (4 == dt_bytes || 8 == dt_bytes) { + if (ATOMIC_SIZE_SUPPORTED(remote_addr, dt_bytes)) { // fast path using UCX atomic operations return do_atomic_compare_and_swap(origin_addr, compare_addr, result_addr, dt, target, @@ -818,6 +823,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, struct ompi_datatype_t *dt, int target, ptrdiff_t target_disp, struct ompi_op_t *op, struct ompi_win_t *win) { + size_t dt_bytes; ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; int ret = OMPI_SUCCESS; @@ -826,12 +832,15 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, return ret; } - if (op == &ompi_mpi_op_no_op.op || op == &ompi_mpi_op_replace.op || - op == &ompi_mpi_op_sum.op) { + uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); + ompi_datatype_type_size(dt, &dt_bytes); + + /* UCX atomics are only supported on 32 and 64 bit values */ + if (ATOMIC_SIZE_SUPPORTED(remote_addr, dt_bytes) && + (op == &ompi_mpi_op_no_op.op || op == &ompi_mpi_op_replace.op || + op == &ompi_mpi_op_sum.op)) { uint64_t value; - uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); ucp_atomic_fetch_op_t opcode; - size_t dt_bytes; bool lock_acquired = false; if (!module->acc_single_intrinsic) { @@ -894,7 +903,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count, } /* rely on UCX network atomics if the user told us that it safe */ - if (use_atomic_op(module, op, origin_dt, target_dt, origin_count, target_count)) { + if (use_atomic_op(module, op, target_disp, origin_dt, target_dt, origin_count, target_count)) { return do_atomic_op_intrinsic(module, op, target, origin_addr, origin_count, origin_dt, target_disp, result_addr, ucx_req);