UCX osc: Use accumulate for operations/datatypes that are not covered by UCX
Signed-off-by: Joseph Schuchart <schuchart@hlrs.de>
Этот коммит содержится в:
родитель
899f58cef5
Коммит
4d7a3856fa
@ -323,7 +323,25 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
|
||||
return ret;
|
||||
}
|
||||
|
||||
static int do_atomic_op_replace_sum(
|
||||
static inline
|
||||
bool use_ucx_op(struct ompi_op_t *op, struct ompi_datatype_t *origin_dt)
|
||||
{
|
||||
|
||||
if (op == &ompi_mpi_op_replace.op ||
|
||||
op == &ompi_mpi_op_sum.op ||
|
||||
op == &ompi_mpi_op_no_op.op) {
|
||||
size_t dt_bytes;
|
||||
ompi_datatype_type_size(origin_dt, &dt_bytes);
|
||||
if (ompi_datatype_is_predefined(origin_dt) &&
|
||||
sizeof(uint64_t) >= dt_bytes) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static int do_atomic_op_intrinsic(
|
||||
ompi_osc_ucx_module_t *module,
|
||||
struct ompi_op_t *op,
|
||||
int target,
|
||||
@ -342,7 +360,7 @@ static int do_atomic_op_replace_sum(
|
||||
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
|
||||
ompi_datatype_type_size(target_dt, &target_dt_bytes);
|
||||
|
||||
if (origin_dt_bytes > sizeof(uint64_t) ||
|
||||
if (sizeof(uint64_t) > origin_dt_bytes ||
|
||||
origin_dt_bytes != target_dt_bytes ||
|
||||
target_count != origin_count) {
|
||||
return OMPI_ERR_NOT_SUPPORTED;
|
||||
@ -409,133 +427,6 @@ static int do_atomic_op_replace_sum(
|
||||
return ret;
|
||||
}
|
||||
|
||||
static int do_atomic_op_cswap(
|
||||
ompi_osc_ucx_module_t *module,
|
||||
struct ompi_op_t *op,
|
||||
int target,
|
||||
const void *origin_addr,
|
||||
int origin_count,
|
||||
struct ompi_datatype_t *origin_dt,
|
||||
ptrdiff_t target_disp,
|
||||
int target_count,
|
||||
struct ompi_datatype_t *target_dt,
|
||||
void *result_addr,
|
||||
ompi_osc_ucx_request_t *ucx_req)
|
||||
{
|
||||
int ret = OMPI_SUCCESS;
|
||||
size_t origin_dt_bytes;
|
||||
size_t target_dt_bytes;
|
||||
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
|
||||
ompi_datatype_type_size(target_dt, &target_dt_bytes);
|
||||
|
||||
if (origin_dt_bytes > sizeof(uint64_t) ||
|
||||
origin_dt_bytes != target_dt_bytes ||
|
||||
target_count != origin_count) {
|
||||
return OMPI_ERR_NOT_SUPPORTED;
|
||||
}
|
||||
|
||||
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
|
||||
|
||||
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
|
||||
ret = get_dynamic_win_info(remote_addr, module, target);
|
||||
if (ret != OMPI_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < origin_count; ++i) {
|
||||
|
||||
uint64_t tmp_val;
|
||||
uint64_t target_val = 0;
|
||||
|
||||
// get the value from the origin
|
||||
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET,
|
||||
target, &target_val, origin_dt_bytes,
|
||||
remote_addr);
|
||||
if (ret != OMPI_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
|
||||
if (ret != OMPI_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
/* JS: move this loop into the request to overlap multiple cas operations? */
|
||||
do {
|
||||
|
||||
tmp_val = target_val;
|
||||
// compute the result value
|
||||
ompi_op_reduce(op, (void *)origin_addr, &tmp_val, 1, origin_dt);
|
||||
|
||||
// compare-and-swap the resulting value
|
||||
ret = opal_common_ucx_wpmem_cmpswp(module->mem, target_val, tmp_val,
|
||||
target, &tmp_val, origin_dt_bytes,
|
||||
remote_addr);
|
||||
if (ret != OMPI_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// check whether the conditional swap was successful
|
||||
if (tmp_val == target_val) {
|
||||
break;
|
||||
}
|
||||
|
||||
target_val = tmp_val;
|
||||
|
||||
} while (1);
|
||||
|
||||
// store the result if necessary
|
||||
if (NULL != result_addr) {
|
||||
memcpy(result_addr, &tmp_val, origin_dt_bytes);
|
||||
result_addr = (void*)((intptr_t)result_addr + origin_dt_bytes);
|
||||
}
|
||||
// advance origin and remote address
|
||||
origin_addr = (void*)((intptr_t)origin_addr + origin_dt_bytes);
|
||||
remote_addr += origin_dt_bytes;
|
||||
}
|
||||
|
||||
if (NULL != ucx_req) {
|
||||
// nothing to wait for so mark the request as completed
|
||||
ompi_request_complete(&ucx_req->super, true);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static inline
|
||||
int do_atomic_op(
|
||||
ompi_osc_ucx_module_t *module,
|
||||
struct ompi_op_t *op,
|
||||
int target,
|
||||
const void *origin_addr,
|
||||
int origin_count,
|
||||
struct ompi_datatype_t *origin_dt,
|
||||
ptrdiff_t target_disp,
|
||||
int target_count,
|
||||
struct ompi_datatype_t *target_dt,
|
||||
void *result_addr,
|
||||
ompi_osc_ucx_request_t *ucx_req)
|
||||
{
|
||||
int ret;
|
||||
|
||||
if (op == &ompi_mpi_op_replace.op ||
|
||||
op == &ompi_mpi_op_sum.op ||
|
||||
op == &ompi_mpi_op_no_op.op) {
|
||||
ret = do_atomic_op_replace_sum(module, op, target,
|
||||
origin_addr, origin_count, origin_dt,
|
||||
target_disp, target_count, target_dt,
|
||||
result_addr, ucx_req);
|
||||
} else {
|
||||
ret = do_atomic_op_cswap(module, op, target,
|
||||
origin_addr, origin_count, origin_dt,
|
||||
target_disp, target_count, target_dt,
|
||||
result_addr, ucx_req);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt,
|
||||
int target, ptrdiff_t target_disp, int target_count,
|
||||
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
|
||||
@ -665,11 +556,11 @@ int accumulate_req(const void *origin_addr, int origin_count,
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (module->acc_single_intrinsic) {
|
||||
return do_atomic_op(module, op, target,
|
||||
origin_addr, origin_count, origin_dt,
|
||||
target_disp, target_count, target_dt,
|
||||
NULL, ucx_req);
|
||||
if (module->acc_single_intrinsic && use_ucx_op(op, origin_dt)) {
|
||||
return do_atomic_op_intrinsic(module, op, target,
|
||||
origin_addr, origin_count, origin_dt,
|
||||
target_disp, target_count, target_dt,
|
||||
NULL, ucx_req);
|
||||
}
|
||||
|
||||
|
||||
@ -923,11 +814,11 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (module->acc_single_intrinsic) {
|
||||
return do_atomic_op(module, op, target,
|
||||
origin_addr, origin_count, origin_dt,
|
||||
target_disp, target_count, target_dt,
|
||||
result_addr, ucx_req);
|
||||
if (module->acc_single_intrinsic && use_ucx_op(op, origin_dt)) {
|
||||
return do_atomic_op_intrinsic(module, op, target,
|
||||
origin_addr, origin_count, origin_dt,
|
||||
target_disp, target_count, target_dt,
|
||||
result_addr, ucx_req);
|
||||
}
|
||||
|
||||
ret = start_atomicity(module, target);
|
||||
|
Загрузка…
x
Ссылка в новой задаче
Block a user