From 243b75aa804fa9655a0cce3c5a938dc0c6322231 Mon Sep 17 00:00:00 2001 From: Devendar Bureddy Date: Fri, 2 Oct 2015 01:49:31 +0300 Subject: [PATCH] HCOLL: Add alltoallv interface --- ompi/mca/coll/hcoll/coll_hcoll.h | 39 ++++++++++++++++++++++--- ompi/mca/coll/hcoll/coll_hcoll_module.c | 15 +++++++--- ompi/mca/coll/hcoll/coll_hcoll_ops.c | 37 +++++++++++++++++++++++ 3 files changed, 83 insertions(+), 8 deletions(-) diff --git a/ompi/mca/coll/hcoll/coll_hcoll.h b/ompi/mca/coll/hcoll/coll_hcoll.h index 8211061a2f..25a60b3922 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll.h +++ b/ompi/mca/coll/hcoll/coll_hcoll.h @@ -141,6 +141,10 @@ struct mca_coll_hcoll_module_t { mca_coll_base_module_t *previous_iallreduce_module; mca_coll_base_module_igatherv_fn_t previous_igatherv; mca_coll_base_module_t *previous_igatherv_module; + mca_coll_base_module_ialltoall_fn_t previous_ialltoall; + mca_coll_base_module_t *previous_ialltoall_module; + mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv; + mca_coll_base_module_t *previous_ialltoallv_module; }; typedef struct mca_coll_hcoll_module_t mca_coll_hcoll_module_t; @@ -192,6 +196,15 @@ int mca_coll_hcoll_alltoall(void *sbuf, int scount, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); +int mca_coll_hcoll_alltoallv(void *sbuf, int *scounts, + int *sdisps, + struct ompi_datatype_t *sdtype, + void *rbuf, int *rcounts, + int *rdisps, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module); + int mca_coll_hcoll_gatherv(void* sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int *rcounts, int *displs, @@ -205,10 +218,10 @@ int mca_coll_hcoll_ibarrier(struct ompi_communicator_t *comm, mca_coll_base_module_t *module); int mca_coll_hcoll_ibcast(void *buff, int count, - struct ompi_datatype_t *datatype, int root, - struct ompi_communicator_t *comm, - ompi_request_t** request, - mca_coll_base_module_t *module); + struct ompi_datatype_t *datatype, int root, + struct 
ompi_communicator_t *comm, + ompi_request_t** request, + mca_coll_base_module_t *module); int mca_coll_hcoll_iallgather(void *sbuf, int scount, struct ompi_datatype_t *sdtype, @@ -225,6 +238,24 @@ int mca_coll_hcoll_iallreduce(void *sbuf, void *rbuf, int count, ompi_request_t** request, mca_coll_base_module_t *module); +int mca_coll_hcoll_ialltoall(void *sbuf, int scount, + struct ompi_datatype_t *sdtype, + void* rbuf, int rcount, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + ompi_request_t **req, + mca_coll_base_module_t *module); + +int mca_coll_hcoll_ialltoallv(void *sbuf, int *scounts, + int *sdisps, + struct ompi_datatype_t *sdtype, + void *rbuf, int *rcounts, + int *rdisps, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + ompi_request_t **req, + mca_coll_base_module_t *module); + int mca_coll_hcoll_igatherv(void* sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int *rcounts, int *displs, diff --git a/ompi/mca/coll/hcoll/coll_hcoll_module.c b/ompi/mca/coll/hcoll/coll_hcoll_module.c index d2623ad129..eb47c08183 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_module.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_module.c @@ -81,19 +81,21 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module OBJ_RELEASE(hcoll_module->previous_allreduce_module); OBJ_RELEASE(hcoll_module->previous_allgather_module); OBJ_RELEASE(hcoll_module->previous_gatherv_module); + OBJ_RELEASE(hcoll_module->previous_alltoall_module); + OBJ_RELEASE(hcoll_module->previous_alltoallv_module); OBJ_RELEASE(hcoll_module->previous_ibarrier_module); OBJ_RELEASE(hcoll_module->previous_ibcast_module); OBJ_RELEASE(hcoll_module->previous_iallreduce_module); OBJ_RELEASE(hcoll_module->previous_iallgather_module); OBJ_RELEASE(hcoll_module->previous_igatherv_module); + OBJ_RELEASE(hcoll_module->previous_ialltoall_module); + OBJ_RELEASE(hcoll_module->previous_ialltoallv_module); /* 
OBJ_RELEASE(hcoll_module->previous_allgatherv_module); OBJ_RELEASE(hcoll_module->previous_gather_module); OBJ_RELEASE(hcoll_module->previous_gatherv_module); - OBJ_RELEASE(hcoll_module->previous_alltoall_module); - OBJ_RELEASE(hcoll_module->previous_alltoallv_module); OBJ_RELEASE(hcoll_module->previous_alltoallw_module); OBJ_RELEASE(hcoll_module->previous_reduce_scatter_module); OBJ_RELEASE(hcoll_module->previous_reduce_module); @@ -127,12 +129,16 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu HCOL_SAVE_PREV_COLL_API(allreduce); HCOL_SAVE_PREV_COLL_API(allgather); HCOL_SAVE_PREV_COLL_API(gatherv); + HCOL_SAVE_PREV_COLL_API(alltoall); + HCOL_SAVE_PREV_COLL_API(alltoallv); HCOL_SAVE_PREV_COLL_API(ibarrier); HCOL_SAVE_PREV_COLL_API(ibcast); HCOL_SAVE_PREV_COLL_API(iallreduce); HCOL_SAVE_PREV_COLL_API(iallgather); HCOL_SAVE_PREV_COLL_API(igatherv); + HCOL_SAVE_PREV_COLL_API(ialltoall); + HCOL_SAVE_PREV_COLL_API(ialltoallv); /* These collectives are not yet part of hcoll, so @@ -141,8 +147,6 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu HCOL_SAVE_PREV_COLL_API(gather); HCOL_SAVE_PREV_COLL_API(reduce); HCOL_SAVE_PREV_COLL_API(allgatherv); - HCOL_SAVE_PREV_COLL_API(alltoall); - HCOL_SAVE_PREV_COLL_API(alltoallv); HCOL_SAVE_PREV_COLL_API(alltoallw); */ return OMPI_SUCCESS; @@ -310,6 +314,7 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority) hcoll_module->super.coll_allgather = hcoll_collectives.coll_allgather ? mca_coll_hcoll_allgather : NULL; hcoll_module->super.coll_allreduce = hcoll_collectives.coll_allreduce ? mca_coll_hcoll_allreduce : NULL; hcoll_module->super.coll_alltoall = /*hcoll_collectives.coll_alltoall ? mca_coll_hcoll_alltoall : */ NULL; + hcoll_module->super.coll_alltoallv = hcoll_collectives.coll_alltoallv ? mca_coll_hcoll_alltoallv : NULL; hcoll_module->super.coll_gatherv = hcoll_collectives.coll_gatherv ? 
mca_coll_hcoll_gatherv : NULL; hcoll_module->super.coll_ibarrier = hcoll_collectives.coll_ibarrier ? mca_coll_hcoll_ibarrier : NULL; hcoll_module->super.coll_ibcast = hcoll_collectives.coll_ibcast ? mca_coll_hcoll_ibcast : NULL; @@ -317,6 +322,8 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority) hcoll_module->super.coll_iallreduce = hcoll_collectives.coll_iallreduce ? mca_coll_hcoll_iallreduce : NULL; hcoll_module->super.coll_gather = /*hcoll_collectives.coll_gather ? mca_coll_hcoll_gather :*/ NULL; hcoll_module->super.coll_igatherv = hcoll_collectives.coll_igatherv ? mca_coll_hcoll_igatherv : NULL; + hcoll_module->super.coll_ialltoall = /*hcoll_collectives.coll_ialltoall ? mca_coll_hcoll_ialltoall : */ NULL; + hcoll_module->super.coll_ialltoallv = /*hcoll_collectives.coll_ialltoallv ? mca_coll_hcoll_ialltoallv : */ NULL; *priority = cm->hcoll_priority; module = &hcoll_module->super; diff --git a/ompi/mca/coll/hcoll/coll_hcoll_ops.c b/ompi/mca/coll/hcoll/coll_hcoll_ops.c index c27486ae73..32e7612d2f 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_ops.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_ops.c @@ -226,6 +226,43 @@ int mca_coll_hcoll_alltoall(void *sbuf, int scount, return rc; } +int mca_coll_hcoll_alltoallv(void *sbuf, int *scounts, int *sdisps, + struct ompi_datatype_t *sdtype, + void *rbuf, int *rcounts, int *rdisps, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ /* Alltoallv entry point: runs hcoll_collectives.coll_alltoallv on hcoll_module->hcoll_context, or hands the unchanged arguments to the saved previous_alltoallv handler; returns the rc of whichever call ran last. */ + dte_data_representation_t stype; + dte_data_representation_t rtype; + int rc; + HCOL_VERBOSE(20,"RUNNING HCOL ALLTOALLV"); + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; + stype = ompi_dtype_2_dte_dtype(sdtype); + rtype = ompi_dtype_2_dte_dtype(rdtype); /* send and receive OMPI datatypes are each converted to a dte_data_representation_t before use */ + if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) + || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) + && mca_coll_hcoll_component.hcoll_datatype_fallback){ /* either converted type is flagged ZERO or COMPLEX and hcoll_datatype_fallback is set: hcoll is not called at all (presumably these flags mean the type has no direct hcoll mapping - TODO confirm) */ + HCOL_VERBOSE(20,"Ompi_datatype is not supported: 
sdtype = %s, rdtype = %s; calling fallback alltoallv;", + sdtype->super.name, + rdtype->super.name); + rc = hcoll_module->previous_alltoallv(sbuf, scounts, sdisps, sdtype, + rbuf, rcounts, rdisps, rdtype, + comm, hcoll_module->previous_alltoallv_module); + return rc; + } + rc = hcoll_collectives.coll_alltoallv(sbuf, scounts, sdisps, stype, + rbuf, rcounts, rdisps, rtype, + hcoll_module->hcoll_context); + if (HCOLL_SUCCESS != rc){ /* hcoll did not report HCOLL_SUCCESS: rerun through the previous handler, passing the original OMPI datatypes rather than the converted ones */ + HCOL_VERBOSE(20,"RUNNING FALLBACK ALLTOALLV"); + rc = hcoll_module->previous_alltoallv(sbuf, scounts, sdisps, sdtype, + rbuf, rcounts, rdisps, rdtype, + comm, hcoll_module->previous_alltoallv_module); + } + return rc; +} + int mca_coll_hcoll_gatherv(void* sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int *rcounts, int *displs,