From 4d7a3856face0dafc4d6dd19fc272bb4c2314a24 Mon Sep 17 00:00:00 2001
From: Joseph Schuchart
Date: Mon, 4 Nov 2019 15:04:22 +0100
Subject: [PATCH] UCX osc: Use accumulate for operations/datatypes that are not covered by UCX

Signed-off-by: Joseph Schuchart
---
Review notes (text in this section is not part of the commit message):

 * use_ucx_op() returns true only when the op is MPI_REPLACE, MPI_SUM or
   MPI_NO_OP and the origin datatype is predefined and at most
   sizeof(uint64_t) bytes wide. accumulate_req() and get_accumulate_req()
   take the do_atomic_op_intrinsic() path only when it returns true and
   module->acc_single_intrinsic is set.
 * The operand-size guard in do_atomic_op_intrinsic() is written as
   `sizeof(uint64_t) < origin_dt_bytes`, i.e. it rejects only operands wider
   than 64 bits, in agreement with the `sizeof(uint64_t) >= dt_bytes` test in
   use_ucx_op(). Writing it as `sizeof(uint64_t) > origin_dt_bytes` would
   return OMPI_ERR_NOT_SUPPORTED for every operand narrower than 8 bytes
   that use_ucx_op() has just admitted.
 * use_ucx_op() looks only at the op and the origin datatype; a mismatch in
   target datatype size or count still makes do_atomic_op_intrinsic() return
   OMPI_ERR_NOT_SUPPORTED.

 ompi/mca/osc/ucx/osc_ucx_comm.c | 169 ++++++--------------------------
 1 file changed, 30 insertions(+), 139 deletions(-)

diff --git a/ompi/mca/osc/ucx/osc_ucx_comm.c b/ompi/mca/osc/ucx/osc_ucx_comm.c
index 3a7161ef1f..366a920374 100644
--- a/ompi/mca/osc/ucx/osc_ucx_comm.c
+++ b/ompi/mca/osc/ucx/osc_ucx_comm.c
@@ -323,7 +323,25 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
     return ret;
 }
 
-static int do_atomic_op_replace_sum(
+static inline
+bool use_ucx_op(struct ompi_op_t *op, struct ompi_datatype_t *origin_dt)
+{
+
+    if (op == &ompi_mpi_op_replace.op ||
+        op == &ompi_mpi_op_sum.op ||
+        op == &ompi_mpi_op_no_op.op) {
+        size_t dt_bytes;
+        ompi_datatype_type_size(origin_dt, &dt_bytes);
+        if (ompi_datatype_is_predefined(origin_dt) &&
+            sizeof(uint64_t) >= dt_bytes) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+static int do_atomic_op_intrinsic(
     ompi_osc_ucx_module_t *module,
     struct ompi_op_t *op,
     int target,
@@ -342,7 +360,7 @@ static int do_atomic_op_replace_sum(
     ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
     ompi_datatype_type_size(target_dt, &target_dt_bytes);
 
-    if (origin_dt_bytes > sizeof(uint64_t) ||
+    if (sizeof(uint64_t) < origin_dt_bytes ||
         origin_dt_bytes != target_dt_bytes ||
         target_count != origin_count) {
         return OMPI_ERR_NOT_SUPPORTED;
@@ -409,133 +427,6 @@ static int do_atomic_op_replace_sum(
     return ret;
 }
 
-static int do_atomic_op_cswap(
-    ompi_osc_ucx_module_t *module,
-    struct ompi_op_t *op,
-    int target,
-    const void *origin_addr,
-    int origin_count,
-    struct ompi_datatype_t *origin_dt,
-    ptrdiff_t target_disp,
-    int target_count,
-    struct ompi_datatype_t *target_dt,
-    void *result_addr,
-    ompi_osc_ucx_request_t *ucx_req)
-{
-    int ret = OMPI_SUCCESS;
-    size_t origin_dt_bytes;
-    size_t target_dt_bytes;
-    ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
-    ompi_datatype_type_size(target_dt, &target_dt_bytes);
-
-    if (origin_dt_bytes > sizeof(uint64_t) ||
-        origin_dt_bytes != target_dt_bytes ||
-        target_count != origin_count) {
-        return OMPI_ERR_NOT_SUPPORTED;
-    }
-
-    uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
-
-    if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
-        ret = get_dynamic_win_info(remote_addr, module, target);
-        if (ret != OMPI_SUCCESS) {
-            return ret;
-        }
-    }
-
-    for (int i = 0; i < origin_count; ++i) {
-
-        uint64_t tmp_val;
-        uint64_t target_val = 0;
-
-        // get the value from the origin
-        ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET,
-                                           target, &target_val, origin_dt_bytes,
-                                           remote_addr);
-        if (ret != OMPI_SUCCESS) {
-            return ret;
-        }
-
-        ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
-        if (ret != OMPI_SUCCESS) {
-            return ret;
-        }
-
-        /* JS: move this loop into the request to overlap multiple cas operations? */
-        do {
-
-            tmp_val = target_val;
-            // compute the result value
-            ompi_op_reduce(op, (void *)origin_addr, &tmp_val, 1, origin_dt);
-
-            // compare-and-swap the resulting value
-            ret = opal_common_ucx_wpmem_cmpswp(module->mem, target_val, tmp_val,
-                                               target, &tmp_val, origin_dt_bytes,
-                                               remote_addr);
-            if (ret != OMPI_SUCCESS) {
-                return ret;
-            }
-
-            // check whether the conditional swap was successful
-            if (tmp_val == target_val) {
-                break;
-            }
-
-            target_val = tmp_val;
-
-        } while (1);
-
-        // store the result if necessary
-        if (NULL != result_addr) {
-            memcpy(result_addr, &tmp_val, origin_dt_bytes);
-            result_addr = (void*)((intptr_t)result_addr + origin_dt_bytes);
-        }
-        // advance origin and remote address
-        origin_addr = (void*)((intptr_t)origin_addr + origin_dt_bytes);
-        remote_addr += origin_dt_bytes;
-    }
-
-    if (NULL != ucx_req) {
-        // nothing to wait for so mark the request as completed
-        ompi_request_complete(&ucx_req->super, true);
-    }
-
-    return ret;
-}
-
-static inline
-int do_atomic_op(
-    ompi_osc_ucx_module_t *module,
-    struct ompi_op_t *op,
-    int target,
-    const void *origin_addr,
-    int origin_count,
-    struct ompi_datatype_t *origin_dt,
-    ptrdiff_t target_disp,
-    int target_count,
-    struct ompi_datatype_t *target_dt,
-    void *result_addr,
-    ompi_osc_ucx_request_t *ucx_req)
-{
-    int ret;
-
-    if (op == &ompi_mpi_op_replace.op ||
-        op == &ompi_mpi_op_sum.op ||
-        op == &ompi_mpi_op_no_op.op) {
-        ret = do_atomic_op_replace_sum(module, op, target,
-                                       origin_addr, origin_count, origin_dt,
-                                       target_disp, target_count, target_dt,
-                                       result_addr, ucx_req);
-    } else {
-        ret = do_atomic_op_cswap(module, op, target,
-                                 origin_addr, origin_count, origin_dt,
-                                 target_disp, target_count, target_dt,
-                                 result_addr, ucx_req);
-    }
-    return ret;
-}
-
-
-
 int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt,
                      int target, ptrdiff_t target_disp, int target_count,
                      struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
@@ -665,11 +556,11 @@ int accumulate_req(const void *origin_addr, int origin_count,
         return ret;
     }
 
-    if (module->acc_single_intrinsic) {
-        return do_atomic_op(module, op, target,
-                            origin_addr, origin_count, origin_dt,
-                            target_disp, target_count, target_dt,
-                            NULL, ucx_req);
+    if (module->acc_single_intrinsic && use_ucx_op(op, origin_dt)) {
+        return do_atomic_op_intrinsic(module, op, target,
+                                      origin_addr, origin_count, origin_dt,
+                                      target_disp, target_count, target_dt,
+                                      NULL, ucx_req);
     }
 
 
@@ -923,11 +814,11 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
         return ret;
     }
 
-    if (module->acc_single_intrinsic) {
-        return do_atomic_op(module, op, target,
-                            origin_addr, origin_count, origin_dt,
-                            target_disp, target_count, target_dt,
-                            result_addr, ucx_req);
+    if (module->acc_single_intrinsic && use_ucx_op(op, origin_dt)) {
+        return do_atomic_op_intrinsic(module, op, target,
+                                      origin_addr, origin_count, origin_dt,
+                                      target_disp, target_count, target_dt,
+                                      result_addr, ucx_req);
     }
 
     ret = start_atomicity(module, target);