From 1d54071fc117502f1b80bdf244318b0fb76fac53 Mon Sep 17 00:00:00 2001
From: Valentin Petrov
Date: Thu, 11 Jun 2020 22:39:01 +0300
Subject: [PATCH] coll/hcoll: reduce_scatter(block) interface

Add hcoll backed reduce_scatter and reduce_scatter_block. They are compiled
only when HCOLL_API > HCOLL_VERSION(4,5) and are installed only when
hcoll_collectives provides the matching entry point. Unsupported datatypes
or ops, and hcoll failures, are forwarded to the previously saved function
together with the module that was saved for that same collective.

Signed-off-by: Valentin Petrov
---
 ompi/mca/coll/hcoll/coll_hcoll.h        |  14 +++
 ompi/mca/coll/hcoll/coll_hcoll_module.c |  11 +++
 ompi/mca/coll/hcoll/coll_hcoll_ops.c    | 105 ++++++++++++++++++++++++
 3 files changed, 130 insertions(+)

diff --git a/ompi/mca/coll/hcoll/coll_hcoll.h b/ompi/mca/coll/hcoll/coll_hcoll.h
index 141792d636..aadc0735fd 100644
--- a/ompi/mca/coll/hcoll/coll_hcoll.h
+++ b/ompi/mca/coll/hcoll/coll_hcoll.h
@@ -141,6 +141,8 @@ struct mca_coll_hcoll_module_t {
     mca_coll_base_module_t *previous_scatterv_module;
     mca_coll_base_module_reduce_scatter_fn_t previous_reduce_scatter;
     mca_coll_base_module_t *previous_reduce_scatter_module;
+    mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block;
+    mca_coll_base_module_t *previous_reduce_scatter_block_module;
     mca_coll_base_module_ibcast_fn_t previous_ibcast;
     mca_coll_base_module_t *previous_ibcast_module;
     mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
@@ -211,6 +213,18 @@ int mca_coll_hcoll_allreduce(const void *sbuf, void *rbuf, int count,
                              struct ompi_communicator_t *comm,
                              mca_coll_base_module_t *module);
 
+#if HCOLL_API > HCOLL_VERSION(4,5)
+int mca_coll_hcoll_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
+                                        struct ompi_datatype_t *dtype,
+                                        struct ompi_op_t *op,
+                                        struct ompi_communicator_t *comm,
+                                        mca_coll_base_module_t *module);
+int mca_coll_hcoll_reduce_scatter(const void *sbuf, void *rbuf, const int* rcounts,
+                                  struct ompi_datatype_t *dtype,
+                                  struct ompi_op_t *op,
+                                  struct ompi_communicator_t *comm,
+                                  mca_coll_base_module_t *module);
+#endif
 int mca_coll_hcoll_reduce(const void *sbuf, void *rbuf, int count,
                           struct ompi_datatype_t *dtype,
                           struct ompi_op_t *op,
diff --git a/ompi/mca/coll/hcoll/coll_hcoll_module.c b/ompi/mca/coll/hcoll/coll_hcoll_module.c
index 7e638bb309..d09607d8d0 100644
--- a/ompi/mca/coll/hcoll/coll_hcoll_module.c
+++ b/ompi/mca/coll/hcoll/coll_hcoll_module.c
@@ -51,6 +51,7 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module)
     hcoll_module->previous_alltoallw = NULL;
     hcoll_module->previous_reduce = NULL;
     hcoll_module->previous_reduce_scatter = NULL;
+    hcoll_module->previous_reduce_scatter_block = NULL;
     hcoll_module->previous_ibarrier = NULL;
     hcoll_module->previous_ibcast = NULL;
     hcoll_module->previous_iallreduce = NULL;
@@ -119,6 +120,8 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module
     OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_barrier_module);
     OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_bcast_module);
     OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allreduce_module);
+    OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_block_module);
+    OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_module);
     OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgather_module);
     OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgatherv_module);
     OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_gatherv_module);
@@ -173,6 +176,8 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu
     HCOL_SAVE_PREV_COLL_API(barrier);
     HCOL_SAVE_PREV_COLL_API(bcast);
     HCOL_SAVE_PREV_COLL_API(allreduce);
+    HCOL_SAVE_PREV_COLL_API(reduce_scatter_block);
+    HCOL_SAVE_PREV_COLL_API(reduce_scatter);
     HCOL_SAVE_PREV_COLL_API(reduce);
     HCOL_SAVE_PREV_COLL_API(allgather);
     HCOL_SAVE_PREV_COLL_API(allgatherv);
@@ -419,6 +424,12 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
     hcoll_module->super.coll_ialltoallv = hcoll_collectives.coll_ialltoallv ? mca_coll_hcoll_ialltoallv : NULL;
 #else
     hcoll_module->super.coll_ialltoallv = NULL;
+#endif
+#if HCOLL_API > HCOLL_VERSION(4,5)
+    hcoll_module->super.coll_reduce_scatter_block = hcoll_collectives.coll_reduce_scatter_block ?
+        mca_coll_hcoll_reduce_scatter_block : NULL;
+    hcoll_module->super.coll_reduce_scatter = hcoll_collectives.coll_reduce_scatter ?
+        mca_coll_hcoll_reduce_scatter : NULL;
 #endif
     *priority = cm->hcoll_priority;
     module = &hcoll_module->super;
diff --git a/ompi/mca/coll/hcoll/coll_hcoll_ops.c b/ompi/mca/coll/hcoll/coll_hcoll_ops.c
index 5791fe17db..d864ae0d55 100644
--- a/ompi/mca/coll/hcoll/coll_hcoll_ops.c
+++ b/ompi/mca/coll/hcoll/coll_hcoll_ops.c
@@ -760,3 +760,108 @@ int mca_coll_hcoll_ialltoallv(const void *sbuf, int *scounts, int *sdisps,
     return rc;
 }
 #endif
+
+#if HCOLL_API > HCOLL_VERSION(4,5)
+/*
+ * Reduce_scatter_block through hcoll.
+ *
+ * The datatype and the reduction op are mapped to their hcoll representations
+ * first. If either one has no mapping, or if hcoll does not return
+ * HCOLL_SUCCESS, the call is forwarded with its original arguments to
+ * previous_reduce_scatter_block together with the module stored next to it.
+ * Returns the result of whichever call ran last.
+ */
+int mca_coll_hcoll_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
+                                        struct ompi_datatype_t *dtype,
+                                        struct ompi_op_t *op,
+                                        struct ompi_communicator_t *comm,
+                                        mca_coll_base_module_t *module) {
+    dte_data_representation_t Dtype;
+    hcoll_dte_op_t *Op;
+    int rc;
+    HCOL_VERBOSE(20,"RUNNING HCOL REDUCE SCATTER BLOCK");
+    mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
+    Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
+    if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
+        /*If we are here then datatype is not simple predefined datatype */
+        /*In future we need to add more complex mapping to the dte_data_representation_t */
+        /* Now use fallback */
+        HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback reduce_scatter_block;",
+                     dtype->super.name);
+        goto fallback;
+    }
+
+    Op = ompi_op_2_hcolrte_op(op);
+    if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){
+        /* If we are here then the op has no hcoll equivalent */
+        /* Now use fallback */
+        HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback reduce_scatter_block;",
+                     op->o_name);
+        goto fallback;
+    }
+
+    rc = hcoll_collectives.coll_reduce_scatter_block((void *)sbuf,rbuf,rcount,Dtype,Op,hcoll_module->hcoll_context);
+    if (HCOLL_SUCCESS != rc){
+    fallback:
+        /* The previous function has to be called with the module that was */
+        /* stored together with it, not with the one of another collective. */
+        HCOL_VERBOSE(20,"RUNNING FALLBACK REDUCE SCATTER BLOCK");
+        rc = hcoll_module->previous_reduce_scatter_block(sbuf,rbuf,
+                                                         rcount,dtype,op,
+                                                         comm, hcoll_module->previous_reduce_scatter_block_module);
+    }
+    return rc;
+}
+
+/*
+ * Reduce_scatter through hcoll.
+ *
+ * The datatype and the reduction op are mapped to their hcoll representations
+ * first. If either one has no mapping, or if hcoll does not return
+ * HCOLL_SUCCESS, the call is forwarded with its original arguments to
+ * previous_reduce_scatter together with the module stored next to it.
+ * Returns the result of whichever call ran last.
+ */
+int mca_coll_hcoll_reduce_scatter(const void *sbuf, void *rbuf, const int* rcounts,
+                                  struct ompi_datatype_t *dtype,
+                                  struct ompi_op_t *op,
+                                  struct ompi_communicator_t *comm,
+                                  mca_coll_base_module_t *module) {
+    dte_data_representation_t Dtype;
+    hcoll_dte_op_t *Op;
+    int rc;
+    HCOL_VERBOSE(20,"RUNNING HCOL REDUCE SCATTER");
+    mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
+    Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
+    if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
+        /*If we are here then datatype is not simple predefined datatype */
+        /*In future we need to add more complex mapping to the dte_data_representation_t */
+        /* Now use fallback */
+        HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback reduce_scatter;",
+                     dtype->super.name);
+        goto fallback;
+    }
+
+    Op = ompi_op_2_hcolrte_op(op);
+    if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){
+        /* If we are here then the op has no hcoll equivalent */
+        /* Now use fallback */
+        HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback reduce_scatter;",
+                     op->o_name);
+        goto fallback;
+    }
+
+    rc = hcoll_collectives.coll_reduce_scatter((void*)sbuf, rbuf, (int*)rcounts,
+                                               Dtype, Op, hcoll_module->hcoll_context);
+    if (HCOLL_SUCCESS != rc){
+    fallback:
+        /* The previous function has to be called with the module that was */
+        /* stored together with it, not with the one of another collective. */
+        HCOL_VERBOSE(20,"RUNNING FALLBACK REDUCE SCATTER");
+        rc = hcoll_module->previous_reduce_scatter(sbuf,rbuf,
+                                                   rcounts,dtype,op,
+                                                   comm, hcoll_module->previous_reduce_scatter_module);
+    }
+    return rc;
+}
+#endif