coll/hcoll: reduce_scatter(block) interface
Signed-off-by: Valentin Petrov <valentinp@mellanox.com>
Этот коммит содержится в:
родитель
868eee31c1
Коммит
1d54071fc1
@ -141,6 +141,8 @@ struct mca_coll_hcoll_module_t {
|
||||
mca_coll_base_module_t *previous_scatterv_module;
|
||||
mca_coll_base_module_reduce_scatter_fn_t previous_reduce_scatter;
|
||||
mca_coll_base_module_t *previous_reduce_scatter_module;
|
||||
mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block;
|
||||
mca_coll_base_module_t *previous_reduce_scatter_block_module;
|
||||
mca_coll_base_module_ibcast_fn_t previous_ibcast;
|
||||
mca_coll_base_module_t *previous_ibcast_module;
|
||||
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
|
||||
@ -211,6 +213,18 @@ int mca_coll_hcoll_allreduce(const void *sbuf, void *rbuf, int count,
|
||||
struct ompi_communicator_t *comm,
|
||||
mca_coll_base_module_t *module);
|
||||
|
||||
#if HCOLL_API > HCOLL_VERSION(4,5)
|
||||
/*
 * hcoll-backed reduce_scatter_block entry point.
 * Declared only when HCOLL_API > HCOLL_VERSION(4,5) (see the enclosing #if).
 * It is installed as the module's coll_reduce_scatter_block handler when the
 * hcoll library provides that collective.
 * NOTE(review): parameters presumably follow the coll framework's
 * mca_coll_base_module_reduce_scatter_block_fn_t contract -- confirm
 * against coll.h.
 */
int mca_coll_hcoll_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
|
||||
struct ompi_datatype_t *dtype,
|
||||
struct ompi_op_t *op,
|
||||
struct ompi_communicator_t *comm,
|
||||
mca_coll_base_module_t *module);
|
||||
/*
 * hcoll-backed reduce_scatter entry point (per-rank receive counts in
 * rcounts). Declared only when HCOLL_API > HCOLL_VERSION(4,5) (see the
 * enclosing #if). It is installed as the module's coll_reduce_scatter
 * handler when the hcoll library provides that collective.
 * NOTE(review): parameters presumably follow the coll framework's
 * mca_coll_base_module_reduce_scatter_fn_t contract -- confirm against coll.h.
 */
int mca_coll_hcoll_reduce_scatter(const void *sbuf, void *rbuf, const int* rcounts,
|
||||
struct ompi_datatype_t *dtype,
|
||||
struct ompi_op_t *op,
|
||||
struct ompi_communicator_t *comm,
|
||||
mca_coll_base_module_t *module);
|
||||
#endif
|
||||
int mca_coll_hcoll_reduce(const void *sbuf, void *rbuf, int count,
|
||||
struct ompi_datatype_t *dtype,
|
||||
struct ompi_op_t *op,
|
||||
|
@ -51,6 +51,7 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module)
|
||||
hcoll_module->previous_alltoallw = NULL;
|
||||
hcoll_module->previous_reduce = NULL;
|
||||
hcoll_module->previous_reduce_scatter = NULL;
|
||||
hcoll_module->previous_reduce_scatter_block = NULL;
|
||||
hcoll_module->previous_ibarrier = NULL;
|
||||
hcoll_module->previous_ibcast = NULL;
|
||||
hcoll_module->previous_iallreduce = NULL;
|
||||
@ -119,6 +120,8 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module
|
||||
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_barrier_module);
|
||||
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_bcast_module);
|
||||
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allreduce_module);
|
||||
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_block_module);
|
||||
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_module);
|
||||
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgather_module);
|
||||
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgatherv_module);
|
||||
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_gatherv_module);
|
||||
@ -173,6 +176,8 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu
|
||||
HCOL_SAVE_PREV_COLL_API(barrier);
|
||||
HCOL_SAVE_PREV_COLL_API(bcast);
|
||||
HCOL_SAVE_PREV_COLL_API(allreduce);
|
||||
HCOL_SAVE_PREV_COLL_API(reduce_scatter_block);
|
||||
HCOL_SAVE_PREV_COLL_API(reduce_scatter);
|
||||
HCOL_SAVE_PREV_COLL_API(reduce);
|
||||
HCOL_SAVE_PREV_COLL_API(allgather);
|
||||
HCOL_SAVE_PREV_COLL_API(allgatherv);
|
||||
@ -419,6 +424,12 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
|
||||
hcoll_module->super.coll_ialltoallv = hcoll_collectives.coll_ialltoallv ? mca_coll_hcoll_ialltoallv : NULL;
|
||||
#else
|
||||
hcoll_module->super.coll_ialltoallv = NULL;
|
||||
#endif
|
||||
#if HCOLL_API > HCOLL_VERSION(4,5)
|
||||
hcoll_module->super.coll_reduce_scatter_block = hcoll_collectives.coll_reduce_scatter_block ?
|
||||
mca_coll_hcoll_reduce_scatter_block : NULL;
|
||||
hcoll_module->super.coll_reduce_scatter = hcoll_collectives.coll_reduce_scatter ?
|
||||
mca_coll_hcoll_reduce_scatter : NULL;
|
||||
#endif
|
||||
*priority = cm->hcoll_priority;
|
||||
module = &hcoll_module->super;
|
||||
|
@ -760,3 +760,88 @@ int mca_coll_hcoll_ialltoallv(const void *sbuf, int *scounts, int *sdisps,
|
||||
return rc;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if HCOLL_API > HCOLL_VERSION(4,5)
|
||||
int mca_coll_hcoll_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
|
||||
struct ompi_datatype_t *dtype,
|
||||
struct ompi_op_t *op,
|
||||
struct ompi_communicator_t *comm,
|
||||
mca_coll_base_module_t *module) {
|
||||
dte_data_representation_t Dtype;
|
||||
hcoll_dte_op_t *Op;
|
||||
int rc;
|
||||
HCOL_VERBOSE(20,"RUNNING HCOL REDUCE SCATTER BLOCK");
|
||||
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
|
||||
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
|
||||
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
|
||||
/*If we are here then datatype is not simple predefined datatype */
|
||||
/*In future we need to add more complex mapping to the dte_data_representation_t */
|
||||
/* Now use fallback */
|
||||
HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;",
|
||||
dtype->super.name);
|
||||
goto fallback;
|
||||
}
|
||||
|
||||
Op = ompi_op_2_hcolrte_op(op);
|
||||
if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){
|
||||
/*If we are here then datatype is not simple predefined datatype */
|
||||
/*In future we need to add more complex mapping to the dte_data_representation_t */
|
||||
/* Now use fallback */
|
||||
HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback allreduce;",
|
||||
op->o_name);
|
||||
goto fallback;
|
||||
}
|
||||
|
||||
rc = hcoll_collectives.coll_reduce_scatter_block((void *)sbuf,rbuf,rcount,Dtype,Op,hcoll_module->hcoll_context);
|
||||
if (HCOLL_SUCCESS != rc){
|
||||
fallback:
|
||||
HCOL_VERBOSE(20,"RUNNING FALLBACK ALLREDUCE");
|
||||
rc = hcoll_module->previous_reduce_scatter_block(sbuf,rbuf,
|
||||
rcount,dtype,op,
|
||||
comm, hcoll_module->previous_allreduce_module);
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
int mca_coll_hcoll_reduce_scatter(const void *sbuf, void *rbuf, const int* rcounts,
|
||||
struct ompi_datatype_t *dtype,
|
||||
struct ompi_op_t *op,
|
||||
struct ompi_communicator_t *comm,
|
||||
mca_coll_base_module_t *module) {
|
||||
dte_data_representation_t Dtype;
|
||||
hcoll_dte_op_t *Op;
|
||||
int rc;
|
||||
HCOL_VERBOSE(20,"RUNNING HCOL REDUCE SCATTER");
|
||||
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
|
||||
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
|
||||
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
|
||||
/*If we are here then datatype is not simple predefined datatype */
|
||||
/*In future we need to add more complex mapping to the dte_data_representation_t */
|
||||
/* Now use fallback */
|
||||
HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;",
|
||||
dtype->super.name);
|
||||
goto fallback;
|
||||
}
|
||||
|
||||
Op = ompi_op_2_hcolrte_op(op);
|
||||
if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){
|
||||
/*If we are here then datatype is not simple predefined datatype */
|
||||
/*In future we need to add more complex mapping to the dte_data_representation_t */
|
||||
/* Now use fallback */
|
||||
HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback allreduce;",
|
||||
op->o_name);
|
||||
goto fallback;
|
||||
}
|
||||
|
||||
rc = hcoll_collectives.coll_reduce_scatter((void*)sbuf, rbuf, (int*)rcounts,
|
||||
Dtype, Op, hcoll_module->hcoll_context);
|
||||
if (HCOLL_SUCCESS != rc){
|
||||
fallback:
|
||||
HCOL_VERBOSE(20,"RUNNING FALLBACK ALLREDUCE");
|
||||
rc = hcoll_module->previous_reduce_scatter(sbuf,rbuf,
|
||||
rcounts,dtype,op,
|
||||
comm, hcoll_module->previous_allreduce_module);
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
#endif
|
||||
|
Загрузка…
x
Ссылка в новой задаче
Block a user