From 6ea920e225c7ed905949c6afec554b6bf2705f94 Mon Sep 17 00:00:00 2001 From: Tomislav Janjusic Date: Thu, 17 Jan 2019 06:20:58 +0200 Subject: [PATCH] Coll/hcoll: adding scatterv interface Signed-off-by: Valentin Petrov valentinp@mellanox.com --- ompi/mca/coll/hcoll/coll_hcoll.h | 11 ++++++++ ompi/mca/coll/hcoll/coll_hcoll_module.c | 5 ++++ ompi/mca/coll/hcoll/coll_hcoll_ops.c | 37 +++++++++++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/ompi/mca/coll/hcoll/coll_hcoll.h b/ompi/mca/coll/hcoll/coll_hcoll.h index aaecbc11fe..d7bb79658e 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll.h +++ b/ompi/mca/coll/hcoll/coll_hcoll.h @@ -138,6 +138,8 @@ struct mca_coll_hcoll_module_t { mca_coll_base_module_t *previous_gather_module; mca_coll_base_module_gatherv_fn_t previous_gatherv; mca_coll_base_module_t *previous_gatherv_module; + mca_coll_base_module_scatterv_fn_t previous_scatterv; + mca_coll_base_module_t *previous_scatterv_module; mca_coll_base_module_reduce_scatter_fn_t previous_reduce_scatter; mca_coll_base_module_t *previous_reduce_scatter_module; mca_coll_base_module_ibcast_fn_t previous_ibcast; @@ -241,6 +243,15 @@ int mca_coll_hcoll_gatherv(const void* sbuf, int scount, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); + +int mca_coll_hcoll_scatterv(const void* sbuf, const int *scounts, const int *displs, + struct ompi_datatype_t *sdtype, + void* rbuf, int rcount, + struct ompi_datatype_t *rdtype, + int root, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module); + int mca_coll_hcoll_ibarrier(struct ompi_communicator_t *comm, ompi_request_t** request, mca_coll_base_module_t *module); diff --git a/ompi/mca/coll/hcoll/coll_hcoll_module.c b/ompi/mca/coll/hcoll/coll_hcoll_module.c index aa262c9849..7e638bb309 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_module.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_module.c @@ -45,6 +45,7 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module) hcoll_module->previous_allgatherv = NULL; hcoll_module->previous_gather = NULL; hcoll_module->previous_gatherv = NULL; + hcoll_module->previous_scatterv = NULL; hcoll_module->previous_alltoall = NULL; hcoll_module->previous_alltoallv = NULL; hcoll_module->previous_alltoallw = NULL; @@ -68,6 +69,7 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module) hcoll_module->previous_allgatherv_module = NULL; hcoll_module->previous_gather_module = NULL; hcoll_module->previous_gatherv_module = NULL; + hcoll_module->previous_scatterv_module = NULL; hcoll_module->previous_alltoall_module = NULL; hcoll_module->previous_alltoallv_module = NULL; hcoll_module->previous_alltoallw_module = NULL; @@ -120,6 +122,7 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgather_module); OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgatherv_module); OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_gatherv_module); + OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_scatterv_module); OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_alltoall_module); OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_alltoallv_module); OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_module); @@ -174,6 +177,7 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu HCOL_SAVE_PREV_COLL_API(allgather); HCOL_SAVE_PREV_COLL_API(allgatherv); HCOL_SAVE_PREV_COLL_API(gatherv); + HCOL_SAVE_PREV_COLL_API(scatterv); HCOL_SAVE_PREV_COLL_API(alltoall); HCOL_SAVE_PREV_COLL_API(alltoallv); @@ -392,6 +396,7 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority) hcoll_module->super.coll_alltoall = hcoll_collectives.coll_alltoall ? mca_coll_hcoll_alltoall : NULL; hcoll_module->super.coll_alltoallv = hcoll_collectives.coll_alltoallv ? mca_coll_hcoll_alltoallv : NULL; hcoll_module->super.coll_gatherv = hcoll_collectives.coll_gatherv ? mca_coll_hcoll_gatherv : NULL; + hcoll_module->super.coll_scatterv = hcoll_collectives.coll_scatterv ? mca_coll_hcoll_scatterv : NULL; hcoll_module->super.coll_reduce = hcoll_collectives.coll_reduce ? mca_coll_hcoll_reduce : NULL; hcoll_module->super.coll_ibarrier = hcoll_collectives.coll_ibarrier ? mca_coll_hcoll_ibarrier : NULL; hcoll_module->super.coll_ibcast = hcoll_collectives.coll_ibcast ? mca_coll_hcoll_ibcast : NULL; diff --git a/ompi/mca/coll/hcoll/coll_hcoll_ops.c b/ompi/mca/coll/hcoll/coll_hcoll_ops.c index de563e455b..5791fe17db 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_ops.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_ops.c @@ -397,6 +397,43 @@ int mca_coll_hcoll_gatherv(const void* sbuf, int scount, } +int mca_coll_hcoll_scatterv(const void* sbuf, const int *scounts, const int *displs, + struct ompi_datatype_t *sdtype, + void* rbuf, int rcount, + struct ompi_datatype_t *rdtype, + int root, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + dte_data_representation_t stype; + dte_data_representation_t rtype; + int rc; + HCOL_VERBOSE(20,"RUNNING HCOL SCATTERV"); + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; + stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED); + rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) { + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"Ompi_datatype is not supported: sdtype = %s, rdtype = %s; calling fallback scatterv;", + sdtype->super.name, + rdtype->super.name); + rc = hcoll_module->previous_scatterv(sbuf, scounts, displs, sdtype, + rbuf, rcount, rdtype, root, + comm, hcoll_module->previous_scatterv_module); + return rc; + } + rc = hcoll_collectives.coll_scatterv((void *)sbuf, (int *)scounts, (int *)displs, stype, rbuf, rcount, rtype, root, hcoll_module->hcoll_context); + if (HCOLL_SUCCESS != rc){ + HCOL_VERBOSE(20,"RUNNING FALLBACK SCATTERV"); + rc = hcoll_module->previous_scatterv(sbuf, scounts, displs, sdtype, + rbuf, rcount, rdtype, root, + comm, hcoll_module->previous_scatterv_module); + } + return rc; +} + int mca_coll_hcoll_ibarrier(struct ompi_communicator_t *comm, ompi_request_t ** request, mca_coll_base_module_t *module)