1
1

coll/hcoll: reduce_scatter(block) interface

Signed-off-by: Valentin Petrov <valentinp@mellanox.com>
Этот коммит содержится в:
Valentin Petrov 2020-06-11 22:39:01 +03:00
родитель 868eee31c1
Коммит 1d54071fc1
3 изменённых файлов: 110 добавлений и 0 удалений

Просмотреть файл

@ -141,6 +141,8 @@ struct mca_coll_hcoll_module_t {
mca_coll_base_module_t *previous_scatterv_module;
mca_coll_base_module_reduce_scatter_fn_t previous_reduce_scatter;
mca_coll_base_module_t *previous_reduce_scatter_module;
mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block;
mca_coll_base_module_t *previous_reduce_scatter_block_module;
mca_coll_base_module_ibcast_fn_t previous_ibcast;
mca_coll_base_module_t *previous_ibcast_module;
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
@ -211,6 +213,18 @@ int mca_coll_hcoll_allreduce(const void *sbuf, void *rbuf, int count,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);
#if HCOLL_API > HCOLL_VERSION(4,5)
int mca_coll_hcoll_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);
int mca_coll_hcoll_reduce_scatter(const void *sbuf, void *rbuf, const int* rcounts,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);
#endif
int mca_coll_hcoll_reduce(const void *sbuf, void *rbuf, int count,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,

Просмотреть файл

@ -51,6 +51,7 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module)
hcoll_module->previous_alltoallw = NULL;
hcoll_module->previous_reduce = NULL;
hcoll_module->previous_reduce_scatter = NULL;
hcoll_module->previous_reduce_scatter_block = NULL;
hcoll_module->previous_ibarrier = NULL;
hcoll_module->previous_ibcast = NULL;
hcoll_module->previous_iallreduce = NULL;
@ -119,6 +120,8 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_barrier_module);
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_bcast_module);
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allreduce_module);
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_block_module);
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_module);
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgather_module);
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgatherv_module);
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_gatherv_module);
@ -173,6 +176,8 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu
HCOL_SAVE_PREV_COLL_API(barrier);
HCOL_SAVE_PREV_COLL_API(bcast);
HCOL_SAVE_PREV_COLL_API(allreduce);
HCOL_SAVE_PREV_COLL_API(reduce_scatter_block);
HCOL_SAVE_PREV_COLL_API(reduce_scatter);
HCOL_SAVE_PREV_COLL_API(reduce);
HCOL_SAVE_PREV_COLL_API(allgather);
HCOL_SAVE_PREV_COLL_API(allgatherv);
@ -419,6 +424,12 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
hcoll_module->super.coll_ialltoallv = hcoll_collectives.coll_ialltoallv ? mca_coll_hcoll_ialltoallv : NULL;
#else
hcoll_module->super.coll_ialltoallv = NULL;
#endif
#if HCOLL_API > HCOLL_VERSION(4,5)
hcoll_module->super.coll_reduce_scatter_block = hcoll_collectives.coll_reduce_scatter_block ?
mca_coll_hcoll_reduce_scatter_block : NULL;
hcoll_module->super.coll_reduce_scatter = hcoll_collectives.coll_reduce_scatter ?
mca_coll_hcoll_reduce_scatter : NULL;
#endif
*priority = cm->hcoll_priority;
module = &hcoll_module->super;

Просмотреть файл

@ -760,3 +760,88 @@ int mca_coll_hcoll_ialltoallv(const void *sbuf, int *scounts, int *sdisps,
return rc;
}
#endif
#if HCOLL_API > HCOLL_VERSION(4,5)
int mca_coll_hcoll_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module) {
dte_data_representation_t Dtype;
hcoll_dte_op_t *Op;
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL REDUCE SCATTER BLOCK");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;",
dtype->super.name);
goto fallback;
}
Op = ompi_op_2_hcolrte_op(op);
if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback allreduce;",
op->o_name);
goto fallback;
}
rc = hcoll_collectives.coll_reduce_scatter_block((void *)sbuf,rbuf,rcount,Dtype,Op,hcoll_module->hcoll_context);
if (HCOLL_SUCCESS != rc){
fallback:
HCOL_VERBOSE(20,"RUNNING FALLBACK ALLREDUCE");
rc = hcoll_module->previous_reduce_scatter_block(sbuf,rbuf,
rcount,dtype,op,
comm, hcoll_module->previous_allreduce_module);
}
return rc;
}
int mca_coll_hcoll_reduce_scatter(const void *sbuf, void *rbuf, const int* rcounts,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module) {
dte_data_representation_t Dtype;
hcoll_dte_op_t *Op;
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL REDUCE SCATTER");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;",
dtype->super.name);
goto fallback;
}
Op = ompi_op_2_hcolrte_op(op);
if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback allreduce;",
op->o_name);
goto fallback;
}
rc = hcoll_collectives.coll_reduce_scatter((void*)sbuf, rbuf, (int*)rcounts,
Dtype, Op, hcoll_module->hcoll_context);
if (HCOLL_SUCCESS != rc){
fallback:
HCOL_VERBOSE(20,"RUNNING FALLBACK ALLREDUCE");
rc = hcoll_module->previous_reduce_scatter(sbuf,rbuf,
rcounts,dtype,op,
comm, hcoll_module->previous_allreduce_module);
}
return rc;
}
#endif