1
1

Merge pull request #1913 from vspetrov/hcoll_derived_datatypes

coll/hcoll mpi datatypes support
Этот коммит содержится в:
Joshua Ladd 2016-07-29 10:08:23 -04:00 коммит произвёл GitHub
родитель b748afceb1 3582bba6b7
Коммит 4a03a657c6
6 изменённых файлов: 371 добавлений и 190 удалений

Просмотреть файл

@ -49,6 +49,11 @@ typedef struct mca_coll_hcoll_ops_t {
int (*hcoll_barrier)(void *);
} mca_coll_hcoll_ops_t;
/* Free-list item pairing an opal_free_list_item_t with one hcoll datatype
 * representation. Instances are handed out from and returned to
 * mca_coll_hcoll_component.dtypes. */
typedef struct {
/* Free-list base item. */
opal_free_list_item_t super;
/* hcoll representation carried by this item. */
dte_data_representation_t type;
} mca_coll_hcoll_dtype_t;
OBJ_CLASS_DECLARATION(mca_coll_hcoll_dtype_t);
struct mca_coll_hcoll_component_t {
/** Base coll component */
@ -89,6 +94,8 @@ struct mca_coll_hcoll_component_t {
/* FCA global stuff */
mca_coll_hcoll_ops_t hcoll_ops;
opal_free_list_t requests;
opal_free_list_t dtypes;
int derived_types_support_enabled;
};
typedef struct mca_coll_hcoll_component_t mca_coll_hcoll_component_t;

Просмотреть файл

@ -17,6 +17,7 @@
#include "coll_hcoll.h"
#include "opal/mca/installdirs/installdirs.h"
#include "coll_hcoll_dtypes.h"
/*
* Public string showing the coll ompi_hcol component version number
@ -173,7 +174,15 @@ static int hcoll_register(void)
1,
&mca_coll_hcoll_component.hcoll_datatype_fallback,
0));
#if HCOLL_API >= HCOLL_VERSION(3,6)
CHECK(reg_int("dts",NULL,
"[1|0|] Enable/Disable derived types support",
1,
&mca_coll_hcoll_component.derived_types_support_enabled,
0));
#else
mca_coll_hcoll_component.derived_types_support_enabled = 0;
#endif
mca_coll_hcoll_component.compiletime_version = HCOLL_VERNO_STRING;
mca_base_component_var_register(&mca_coll_hcoll_component.super.collm_version,
MCA_COMPILETIME_VER,
@ -244,7 +253,7 @@ static int hcoll_close(void)
HCOL_VERBOSE(5,"HCOLL FINALIZE");
rc = hcoll_finalize();
OBJ_DESTRUCT(&cm->dtypes);
opal_progress_unregister(mca_coll_hcoll_progress);
if (HCOLL_SUCCESS != rc){
HCOL_VERBOSE(1,"Hcol library finalize failed");

Просмотреть файл

@ -6,8 +6,10 @@
It is used to extract allreduce bcol functions where the arithmetic has to be done*/
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/datatype/ompi_datatype_internal.h"
#include "ompi/mca/op/op.h"
#include "hcoll/api/hcoll_dte.h"
extern int hcoll_type_attr_keyval;
/*to keep this at hand: Ids of the basic opal_datatypes:
#define OPAL_DATATYPE_INT1 4
@ -31,9 +33,7 @@
total 15 types
*/
static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX_PREDEFINED] = {
static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OMPI_DATATYPE_MAX_PREDEFINED] = {
&DTE_ZERO, /*OPAL_DATATYPE_LOOP 0 */
&DTE_ZERO, /*OPAL_DATATYPE_END_LOOP 1 */
&DTE_ZERO, /*OPAL_DATATYPE_LB 2 */
@ -53,34 +53,113 @@ static dte_data_representation_t* ompi_datatype_2_dte_data_rep[OPAL_DATATYPE_MAX
&DTE_FLOAT64, /*OPAL_DATATYPE_FLOAT8 16 */
&DTE_FLOAT96, /*OPAL_DATATYPE_FLOAT12 17 */
&DTE_FLOAT128, /*OPAL_DATATYPE_FLOAT16 18 */
#if defined(DTE_FLOAT32_COMPLEX) && defined(DTE_FLOAT64_COMPLEX)
#if defined(DTE_FLOAT32_COMPLEX)
&DTE_FLOAT32_COMPLEX, /*OPAL_DATATYPE_COMPLEX8 19 */
&DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX16 20 */
#else
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX8 19 */
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX16 20 */
&DTE_ZERO,
#endif
#if defined(DTE_FLOAT64_COMPLEX)
&DTE_FLOAT64_COMPLEX, /*OPAL_DATATYPE_COMPLEX32 20 */
#else
&DTE_ZERO,
#endif
#if defined(DTE_FLOAT128_COMPLEX)
&DTE_FLOAT128_COMPLEX, /*OPAL_DATATYPE_COMPLEX64 21 */
#else
&DTE_ZERO,
#endif
&DTE_ZERO, /*OPAL_DATATYPE_COMPLEX32 21 */
&DTE_ZERO, /*OPAL_DATATYPE_BOOL 22 */
&DTE_ZERO, /*OPAL_DATATYPE_WCHAR 23 */
&DTE_ZERO /*OPAL_DATATYPE_UNAVAILABLE 24 */
};
static dte_data_representation_t ompi_dtype_2_dte_dtype(ompi_datatype_t *dtype){
enum {
TRY_FIND_DERIVED,
NO_DERIVED
};
#if HCOLL_API >= HCOLL_VERSION(3,6)
/*
 * Ask hcoll to create its representation of a derived datatype.
 *
 * dtype   - datatype to map; when dtype->args is NULL (a predefined type)
 *           nothing is created and *new_dte is left untouched.
 * new_dte - receives the representation produced by hcoll_create_mpi_type().
 *
 * Returns OMPI_SUCCESS when no mapping was needed or hcoll reported
 * HCOLL_SUCCESS, OMPI_ERROR otherwise.
 */
static inline
int hcoll_map_derived_type(ompi_datatype_t *dtype, dte_data_representation_t *new_dte)
{
    /* Only types that carry constructor arguments are handed to hcoll. */
    if (NULL != dtype->args) {
        if (HCOLL_SUCCESS != hcoll_create_mpi_type((void*)dtype, new_dte)) {
            return OMPI_ERROR;
        }
    }
    return OMPI_SUCCESS;
}
/*
 * Look up (or create) the hcoll representation of a derived datatype.
 *
 * Returns DTE_ZERO when derived-type support is disabled. Otherwise the
 * datatype's attribute stored under hcoll_type_attr_keyval is consulted:
 * if present, its cached representation is returned; if absent,
 * hcoll_map_derived_type() is asked to fill the result (its return code is
 * not inspected, so the result stays DTE_ZERO unless that call writes it).
 */
static dte_data_representation_t find_derived_mapping(ompi_datatype_t *dtype){
    dte_data_representation_t result = DTE_ZERO;
    mca_coll_hcoll_dtype_t *cached;
    int have_cached = 0;

    if (!mca_coll_hcoll_component.derived_types_support_enabled) {
        return result;
    }

    ompi_attr_get_c(dtype->d_keyhash, hcoll_type_attr_keyval,
                    (void**)&cached, &have_cached);
    if (have_cached) {
        result = cached->type;
    } else {
        hcoll_map_derived_type(dtype, &result);
    }
    return result;
}
/*
 * Translate the id of a predefined pair datatype (FLOAT_INT, DOUBLE_INT,
 * LONG_INT, SHORT_INT, LONG_DOUBLE_INT, 2INT) into the matching hcoll
 * representation. Any other id yields DTE_ZERO.
 */
static inline dte_data_representation_t
ompi_predefined_derived_2_hcoll(int ompi_id) {
    dte_data_representation_t mapped = DTE_ZERO;

    switch (ompi_id) {
    case OMPI_DATATYPE_MPI_FLOAT_INT:
        mapped = DTE_FLOAT_INT;
        break;
    case OMPI_DATATYPE_MPI_DOUBLE_INT:
        mapped = DTE_DOUBLE_INT;
        break;
    case OMPI_DATATYPE_MPI_LONG_INT:
        mapped = DTE_LONG_INT;
        break;
    case OMPI_DATATYPE_MPI_SHORT_INT:
        mapped = DTE_SHORT_INT;
        break;
    case OMPI_DATATYPE_MPI_LONG_DOUBLE_INT:
        mapped = DTE_LONG_DOUBLE_INT;
        break;
    case OMPI_DATATYPE_MPI_2INT:
        mapped = DTE_2INT;
        break;
    default:
        /* no hcoll counterpart: keep DTE_ZERO */
        break;
    }
    return mapped;
}
#endif
static dte_data_representation_t
ompi_dtype_2_hcoll_dtype( ompi_datatype_t *dtype,
const int mode)
{
int ompi_type_id = dtype->id;
int opal_type_id = dtype->super.id;
dte_data_representation_t dte_data_rep;
if (!(dtype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS)) {
ompi_type_id = -1;
dte_data_representation_t dte_data_rep = DTE_ZERO;
if (ompi_type_id < OMPI_DATATYPE_MPI_MAX_PREDEFINED) {
if (opal_type_id > 0 && opal_type_id < OPAL_DATATYPE_MAX_PREDEFINED) {
dte_data_rep = *ompi_datatype_2_dte_data_rep[opal_type_id];
}
#if HCOLL_API >= HCOLL_VERSION(3,6)
else if (TRY_FIND_DERIVED == mode){
dte_data_rep = ompi_predefined_derived_2_hcoll(ompi_type_id);
}
} else {
if (TRY_FIND_DERIVED == mode)
dte_data_rep = find_derived_mapping(dtype);
#endif
}
if (OPAL_UNLIKELY( ompi_type_id < 0 ||
ompi_type_id >= OPAL_DATATYPE_MAX_PREDEFINED)){
if (HCOL_DTE_IS_ZERO(dte_data_rep) && TRY_FIND_DERIVED == mode &&
!mca_coll_hcoll_component.hcoll_datatype_fallback) {
dte_data_rep = DTE_ZERO;
dte_data_rep.rep.in_line_rep.data_handle.in_line.in_line = 0;
dte_data_rep.rep.in_line_rep.data_handle.pointer_to_handle = (uint64_t ) &dtype->super;
return dte_data_rep;
}
return *ompi_datatype_2_dte_data_rep[opal_type_id];
return dte_data_rep;
}
static hcoll_dte_op_t* ompi_op_2_hcoll_op[OMPI_OP_BASE_FORTRAN_OP_MAX + 1] = {
@ -108,4 +187,27 @@ static hcoll_dte_op_t* ompi_op_2_hcolrte_op(ompi_op_t *op) {
return ompi_op_2_hcoll_op[op->o_f_to_c_index];
}
#if HCOLL_API >= HCOLL_VERSION(3,6)
/*
 * Datatype attribute delete callback (hcoll API >= 3.6).
 *
 * attr_val is the mca_coll_hcoll_dtype_t item stored on the datatype. Its
 * hcoll representation is destroyed with hcoll_dt_destroy() and the item is
 * handed back to mca_coll_hcoll_component.dtypes. The type, keyval and extra
 * arguments are not used.
 *
 * Returns OMPI_SUCCESS, or OMPI_ERROR when hcoll_dt_destroy() fails (in which
 * case the item is not returned to the free list).
 */
static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) {
    mca_coll_hcoll_dtype_t *item = (mca_coll_hcoll_dtype_t*) attr_val;
    int hrc;

    assert(item);

    hrc = hcoll_dt_destroy(item->type);
    if (HCOLL_SUCCESS != hrc) {
        HCOL_ERROR("failed to delete type attr: hcoll_dte_destroy returned %d",hrc);
        return OMPI_ERROR;
    }

    opal_free_list_return(&mca_coll_hcoll_component.dtypes, &item->super);
    return OMPI_SUCCESS;
}
#else
/* Datatype attribute delete callback compiled when HCOLL_API is older than
 * HCOLL_VERSION(3,6): all arguments are ignored and OMPI_SUCCESS is returned. */
static int hcoll_type_attr_del_fn(MPI_Datatype type, int keyval, void *attr_val, void *extra) {
/*Do nothing - it's an old version of hcoll w/o dtypes support */
return OMPI_SUCCESS;
}
#endif
#endif /* COLL_HCOLL_DTYPES_H */

Просмотреть файл

@ -10,8 +10,10 @@
#include "ompi_config.h"
#include "coll_hcoll.h"
#include "coll_hcoll_dtypes.h"
int hcoll_comm_attr_keyval;
int hcoll_type_attr_keyval;
/*
* Initial query function that is invoked during MPI_INIT, allowing
@ -240,6 +242,10 @@ int mca_coll_hcoll_progress(void)
}
OBJ_CLASS_INSTANCE(mca_coll_hcoll_dtype_t,
opal_free_list_item_t,
NULL,NULL);
/*
* Invoked when there's a new communicator that has been created.
* Look at the communicator and decide which set of functions and
@ -317,6 +323,24 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
HCOL_ERROR("Hcol comm keyval create failed");
return NULL;
}
if (mca_coll_hcoll_component.derived_types_support_enabled) {
copy_fn.attr_datatype_copy_fn = (MPI_Type_internal_copy_attr_function *) MPI_TYPE_NULL_COPY_FN;
del_fn.attr_datatype_delete_fn = hcoll_type_attr_del_fn;
err = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn, &hcoll_type_attr_keyval, NULL ,0, NULL);
if (OMPI_SUCCESS != err) {
cm->hcoll_enable = 0;
hcoll_finalize();
opal_progress_unregister(mca_coll_hcoll_progress);
HCOL_ERROR("Hcol type keyval create failed");
return NULL;
}
}
OBJ_CONSTRUCT(&cm->dtypes, opal_free_list_t);
opal_free_list_init(&cm->dtypes, sizeof(mca_coll_hcoll_dtype_t),
8, OBJ_CLASS(mca_coll_hcoll_dtype_t), 0, 0,
32, -1, 32, NULL, 0, NULL, NULL, NULL);
}
hcoll_module = OBJ_NEW(mca_coll_hcoll_module_t);

Просмотреть файл

@ -44,9 +44,9 @@ int mca_coll_hcoll_bcast(void *buff, int count,
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL BCAST");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
dtype = ompi_dtype_2_dte_dtype(datatype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(dtype) || HCOL_DTE_IS_COMPLEX(dtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
dtype = ompi_dtype_2_hcoll_dtype(datatype, TRY_FIND_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(dtype))) {
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -76,11 +76,12 @@ int mca_coll_hcoll_allgather(const void *sbuf, int scount,
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL ALLGATHER");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
stype = ompi_dtype_2_dte_dtype(sdtype);
rtype = ompi_dtype_2_dte_dtype(rdtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype)
|| HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
stype = ompi_dtype_2_hcoll_dtype(sdtype, TRY_FIND_DERIVED);
rtype = ompi_dtype_2_hcoll_dtype(rdtype, TRY_FIND_DERIVED);
if (sbuf == MPI_IN_PLACE) {
stype = rtype;
}
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -117,11 +118,9 @@ int mca_coll_hcoll_allgatherv(const void *sbuf, int scount,
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL ALLGATHERV");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
stype = ompi_dtype_2_dte_dtype(sdtype);
rtype = ompi_dtype_2_dte_dtype(rdtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype)
|| HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -161,11 +160,9 @@ int mca_coll_hcoll_gather(const void *sbuf, int scount,
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL GATHER");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
stype = ompi_dtype_2_dte_dtype(sdtype);
rtype = ompi_dtype_2_dte_dtype(rdtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype)
|| HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -201,9 +198,8 @@ int mca_coll_hcoll_allreduce(const void *sbuf, void *rbuf, int count,
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL ALLREDUCE");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
Dtype = ompi_dtype_2_dte_dtype(dtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(Dtype) || HCOL_DTE_IS_COMPLEX(Dtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -250,9 +246,8 @@ int mca_coll_hcoll_reduce(const void *sbuf, void *rbuf, int count,
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL REDUCE");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
Dtype = ompi_dtype_2_dte_dtype(dtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(Dtype) || HCOL_DTE_IS_COMPLEX(Dtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -302,11 +297,9 @@ int mca_coll_hcoll_alltoall(const void *sbuf, int scount,
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL ALLTOALL");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
stype = ompi_dtype_2_dte_dtype(sdtype);
rtype = ompi_dtype_2_dte_dtype(rdtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype)
|| HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -342,11 +335,9 @@ int mca_coll_hcoll_alltoallv(const void *sbuf, const int *scounts, const int *sd
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL ALLTOALLV");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
stype = ompi_dtype_2_dte_dtype(sdtype);
rtype = ompi_dtype_2_dte_dtype(rdtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype)
|| HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
HCOL_VERBOSE(20,"Ompi_datatype is not supported: sdtype = %s, rdtype = %s; calling fallback alltoallv;",
sdtype->super.name,
rdtype->super.name);
@ -380,11 +371,9 @@ int mca_coll_hcoll_gatherv(const void* sbuf, int scount,
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL GATHERV");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
stype = ompi_dtype_2_dte_dtype(sdtype);
rtype = ompi_dtype_2_dte_dtype(rdtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype)
|| HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -436,9 +425,8 @@ int mca_coll_hcoll_ibcast(void *buff, int count,
HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING BCAST");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
rt_handle = (void**) request;
dtype = ompi_dtype_2_dte_dtype(datatype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(dtype) || HCOL_DTE_IS_COMPLEX(dtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
dtype = ompi_dtype_2_hcoll_dtype(datatype, TRY_FIND_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(dtype))){
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -471,11 +459,9 @@ int mca_coll_hcoll_iallgather(const void *sbuf, int scount,
HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING ALLGATHER");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
rt_handle = (void**) request;
stype = ompi_dtype_2_dte_dtype(sdtype);
rtype = ompi_dtype_2_dte_dtype(rdtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype)
|| HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
stype = ompi_dtype_2_hcoll_dtype(sdtype, TRY_FIND_DERIVED);
rtype = ompi_dtype_2_hcoll_dtype(rdtype, TRY_FIND_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -515,12 +501,10 @@ int mca_coll_hcoll_iallgatherv(const void *sbuf, int scount,
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING ALLGATHERV");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
stype = ompi_dtype_2_dte_dtype(sdtype);
rtype = ompi_dtype_2_dte_dtype(rdtype);
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
void **rt_handle = (void **) request;
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype)
|| HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -565,9 +549,8 @@ int mca_coll_hcoll_iallreduce(const void *sbuf, void *rbuf, int count,
HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING ALLREDUCE");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
rt_handle = (void**) request;
Dtype = ompi_dtype_2_dte_dtype(dtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(Dtype) || HCOL_DTE_IS_COMPLEX(Dtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -615,10 +598,9 @@ int mca_coll_hcoll_ireduce(const void *sbuf, void *rbuf, int count,
int rc;
HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING REDUCE");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
Dtype = ompi_dtype_2_dte_dtype(dtype);
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
void **rt_handle = (void**) request;
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(Dtype) || HCOL_DTE_IS_COMPLEX(Dtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */
@ -675,11 +657,9 @@ int mca_coll_hcoll_igatherv(const void* sbuf, int scount,
HCOL_VERBOSE(20,"RUNNING HCOL IGATHERV");
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
rt_handle = (void**) request;
stype = ompi_dtype_2_dte_dtype(sdtype);
rtype = ompi_dtype_2_dte_dtype(rdtype);
if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype)
|| HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype)))
&& mca_coll_hcoll_component.hcoll_datatype_fallback){
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
/*If we are here then datatype is not simple predefined datatype */
/*In future we need to add more complex mapping to the dte_data_representation_t */
/* Now use fallback */

Просмотреть файл

@ -44,6 +44,7 @@
#include "hcoll/api/hcoll_dte.h"
#include "hcoll/api/hcoll_api.h"
#include "hcoll/api/hcoll_constants.h"
#include "coll_hcoll_dtypes.h"
/*
* Local functions
*/
@ -101,6 +102,22 @@ static int group_id(rte_grp_handle_t group);
static int world_rank(rte_grp_handle_t grp_h, rte_ec_handle_t ec);
/* Module Constructors */
#if HCOLL_API >= HCOLL_VERSION(3,6)
static int get_mpi_type_envelope(void *mpi_type, int *num_integers,
int *num_addresses, int *num_datatypes,
hcoll_mpi_type_combiner_t *combiner);
static int get_mpi_type_contents(void *mpi_type, int max_integers, int max_addresses,
int max_datatypes, int *array_of_integers,
void *array_of_addresses, void *array_of_datatypes);
static int get_hcoll_type(void *mpi_type, dte_data_representation_t *hcoll_type);
static int set_hcoll_type(void *mpi_type, dte_data_representation_t hcoll_type);
static int get_mpi_constants(size_t *mpi_datatype_size,
int *mpi_order_c, int *mpi_order_fortran,
int *mpi_distribute_block,
int *mpi_distribute_cyclic,
int *mpi_distribute_none,
int *mpi_distribute_dflt_darg);
#endif
static void init_module_fns(void){
hcoll_rte_functions.send_fn = send_nb;
@ -120,6 +137,13 @@ static void init_module_fns(void){
hcoll_rte_functions.rte_coll_handle_complete_fn = coll_handle_complete;
hcoll_rte_functions.rte_group_id_fn = group_id;
hcoll_rte_functions.rte_world_rank_fn = world_rank;
#if HCOLL_API >= HCOLL_VERSION(3,6)
hcoll_rte_functions.rte_get_mpi_type_envelope_fn = get_mpi_type_envelope;
hcoll_rte_functions.rte_get_mpi_type_contents_fn = get_mpi_type_contents;
hcoll_rte_functions.rte_get_hcoll_type_fn = get_hcoll_type;
hcoll_rte_functions.rte_set_hcoll_type_fn = set_hcoll_type;
hcoll_rte_functions.rte_get_mpi_constants_fn = get_mpi_constants;
#endif
}
@ -148,22 +172,6 @@ void hcoll_rte_fns_setup(void)
);
}
/* This one converts dte_general_representation data into regular iovec array which is
used in rml
*/
static inline int count_total_dte_repeat_entries(struct dte_data_representation_t *data){
unsigned int i;
struct dte_generalized_iovec_t * dte_iovec =
data->rep.general_rep->data_representation.data;
int total_entries_number = 0;
for (i=0; i< dte_iovec->repeat_count; i++){
total_entries_number += dte_iovec->repeat[i].n_elements;
}
return total_entries_number;
}
static int recv_nb(struct dte_data_representation_t data,
uint32_t count ,
void *buffer,
@ -179,55 +187,27 @@ static int recv_nb(struct dte_data_representation_t data,
"ec_h.handle = %p, ec_h.rank = %d\n",ec_h.handle,ec_h.rank);
return 1;
}
if (HCOL_DTE_IS_INLINE(data)){
/*do inline nb recv*/
size_t size;
ompi_request_t *ompi_req;
if (!buffer && !HCOL_DTE_IS_ZERO(data)) {
fprintf(stderr, "***Error in hcolrte_rml_recv_nb: buffer pointer is NULL"
" for non DTE_ZERO INLINE data representation\n");
return 1;
}
size = (size_t)data.rep.in_line_rep.data_handle.in_line.packed_size*count/8;
HCOL_VERBOSE(30,"PML_IRECV: dest = %d: buf = %p: size = %u: comm = %p",
ec_h.rank, buffer, (unsigned int)size, (void *)comm);
if (MCA_PML_CALL(irecv(buffer,size,&(ompi_mpi_unsigned_char.dt),ec_h.rank,
tag,comm,&ompi_req)))
{
return 1;
}
req->data = (void *)ompi_req;
req->status = HCOLRTE_REQUEST_ACTIVE;
}else{
/*do iovec nb recv*/
int total_entries_number;
int i;
unsigned int j;
void *buf;
uint64_t len;
int repeat_count;
struct dte_struct_t * repeat;
if (NULL != buffer) {
/* We have a full data description & buffer pointer simultaneously.
It is ambiguous. Throw a warning since the user might have made a
mistake with data reps*/
fprintf(stderr,"Warning: buffer_pointer != NULL for NON-inline data representation: buffer_pointer is ignored.\n");
}
total_entries_number = count_total_dte_repeat_entries(&data);
repeat = data.rep.general_rep->data_representation.data->repeat;
repeat_count = data.rep.general_rep->data_representation.data->repeat_count;
for (i=0; i< repeat_count; i++){
for (j=0; j<repeat[i].n_elements; j++){
char *repeat_unit = (char *)&repeat[i];
buf = (void *)(repeat_unit+repeat[i].elements[j].base_offset);
len = repeat[i].elements[j].packed_size;
recv_nb(DTE_BYTE,len,buf,ec_h,grp_h,tag,req);
}
}
assert(HCOL_DTE_IS_INLINE(data));
/*do inline nb recv*/
size_t size;
ompi_request_t *ompi_req;
if (!buffer && !HCOL_DTE_IS_ZERO(data)) {
fprintf(stderr, "***Error in hcolrte_rml_recv_nb: buffer pointer is NULL"
" for non DTE_ZERO INLINE data representation\n");
return 1;
}
size = (size_t)data.rep.in_line_rep.data_handle.in_line.packed_size*count/8;
HCOL_VERBOSE(30,"PML_IRECV: dest = %d: buf = %p: size = %u: comm = %p",
ec_h.rank, buffer, (unsigned int)size, (void *)comm);
if (MCA_PML_CALL(irecv(buffer,size,&(ompi_mpi_unsigned_char.dt),ec_h.rank,
tag,comm,&ompi_req)))
{
return 1;
}
req->data = (void *)ompi_req;
req->status = HCOLRTE_REQUEST_ACTIVE;
return HCOLL_SUCCESS;
}
@ -248,51 +228,25 @@ static int send_nb( dte_data_representation_t data,
"ec_h.handle = %p, ec_h.rank = %d\n",ec_h.handle,ec_h.rank);
return 1;
}
if (HCOL_DTE_IS_INLINE(data)){
/*do inline nb recv*/
size_t size;
ompi_request_t *ompi_req;
if (!buffer && !HCOL_DTE_IS_ZERO(data)) {
fprintf(stderr, "***Error in hcolrte_rml_send_nb: buffer pointer is NULL"
" for non DTE_ZERO INLINE data representation\n");
return 1;
}
size = (size_t)data.rep.in_line_rep.data_handle.in_line.packed_size*count/8;
HCOL_VERBOSE(30,"PML_ISEND: dest = %d: buf = %p: size = %u: comm = %p",
ec_h.rank, buffer, (unsigned int)size, (void *)comm);
if (MCA_PML_CALL(isend(buffer,size,&(ompi_mpi_unsigned_char.dt),ec_h.rank,
tag,MCA_PML_BASE_SEND_STANDARD,comm,&ompi_req)))
{
return 1;
}
req->data = (void *)ompi_req;
req->status = HCOLRTE_REQUEST_ACTIVE;
}else{
int total_entries_number;
int i;
unsigned int j;
void *buf;
uint64_t len;
int repeat_count;
struct dte_struct_t * repeat;
if (NULL != buffer) {
/* We have a full data description & buffer pointer simultaneously.
It is ambiguous. Throw a warning since the user might have made a
mistake with data reps*/
fprintf(stderr,"Warning: buffer_pointer != NULL for NON-inline data representation: buffer_pointer is ignored.\n");
}
total_entries_number = count_total_dte_repeat_entries(&data);
repeat = data.rep.general_rep->data_representation.data->repeat;
repeat_count = data.rep.general_rep->data_representation.data->repeat_count;
for (i=0; i< repeat_count; i++){
for (j=0; j<repeat[i].n_elements; j++){
char *repeat_unit = (char *)&repeat[i];
buf = (void *)(repeat_unit+repeat[i].elements[j].base_offset);
len = repeat[i].elements[j].packed_size;
send_nb(DTE_BYTE,len,buf,ec_h,grp_h,tag,req);
}
}
assert(HCOL_DTE_IS_INLINE(data));
/*do inline nb send*/
size_t size;
ompi_request_t *ompi_req;
if (!buffer && !HCOL_DTE_IS_ZERO(data)) {
fprintf(stderr, "***Error in hcolrte_rml_send_nb: buffer pointer is NULL"
" for non DTE_ZERO INLINE data representation\n");
return 1;
}
size = (size_t)data.rep.in_line_rep.data_handle.in_line.packed_size*count/8;
HCOL_VERBOSE(30,"PML_ISEND: dest = %d: buf = %p: size = %u: comm = %p",
ec_h.rank, buffer, (unsigned int)size, (void *)comm);
if (MCA_PML_CALL(isend(buffer,size,&(ompi_mpi_unsigned_char.dt),ec_h.rank,
tag,MCA_PML_BASE_SEND_STANDARD,comm,&ompi_req)))
{
return 1;
}
req->data = (void *)ompi_req;
req->status = HCOLRTE_REQUEST_ACTIVE;
return HCOLL_SUCCESS;
}
@ -306,7 +260,7 @@ static int test( rte_request_handle_t * request ,
}
/*ompi_request_test(&ompi_req,completed,MPI_STATUS_IGNORE); */
*completed = ompi_req->req_complete;
*completed = REQUEST_COMPLETE(ompi_req);
if (*completed){
ompi_request_free(&ompi_req);
request->status = HCOLRTE_REQUEST_DONE;
@ -415,7 +369,7 @@ static void* get_coll_handle(void)
static int coll_handle_test(void* handle)
{
ompi_request_t *ompi_req = (ompi_request_t *)handle;
return ompi_req->req_complete;
return REQUEST_COMPLETE(ompi_req);;
}
static void coll_handle_free(void *handle){
@ -435,3 +389,108 @@ static int world_rank(rte_grp_handle_t grp_h, rte_ec_handle_t ec){
ompi_proc_t *proc = (ompi_proc_t *)ec.handle;
return ((ompi_process_name_t*)&proc->super.proc_name)->vpid;
}
#if HCOLL_API >= HCOLL_VERSION(3,6)
/*
 * Translate an MPI_COMBINER_* value into the corresponding
 * HCOLL_MPI_COMBINER_* value.
 *
 * The *_INTEGER variants of HINDEXED and STRUCT share the hcoll value of
 * their plain counterparts. Any combiner not listed below maps to
 * HCOLL_MPI_COMBINER_LAST.
 */
hcoll_mpi_type_combiner_t ompi_combiner_2_hcoll_combiner(int ompi_combiner) {
    hcoll_mpi_type_combiner_t mapped = HCOLL_MPI_COMBINER_LAST;

    switch (ompi_combiner)
    {
    case MPI_COMBINER_CONTIGUOUS:
        mapped = HCOLL_MPI_COMBINER_CONTIGUOUS;
        break;
    case MPI_COMBINER_VECTOR:
        mapped = HCOLL_MPI_COMBINER_VECTOR;
        break;
    case MPI_COMBINER_HVECTOR:
        mapped = HCOLL_MPI_COMBINER_HVECTOR;
        break;
    case MPI_COMBINER_INDEXED:
        mapped = HCOLL_MPI_COMBINER_INDEXED;
        break;
    case MPI_COMBINER_HINDEXED_INTEGER:
    case MPI_COMBINER_HINDEXED:
        mapped = HCOLL_MPI_COMBINER_HINDEXED;
        break;
    case MPI_COMBINER_DUP:
        mapped = HCOLL_MPI_COMBINER_DUP;
        break;
    case MPI_COMBINER_INDEXED_BLOCK:
        mapped = HCOLL_MPI_COMBINER_INDEXED_BLOCK;
        break;
    case MPI_COMBINER_HINDEXED_BLOCK:
        mapped = HCOLL_MPI_COMBINER_HINDEXED_BLOCK;
        break;
    case MPI_COMBINER_SUBARRAY:
        mapped = HCOLL_MPI_COMBINER_SUBARRAY;
        break;
    case MPI_COMBINER_DARRAY:
        mapped = HCOLL_MPI_COMBINER_DARRAY;
        break;
    case MPI_COMBINER_F90_REAL:
        mapped = HCOLL_MPI_COMBINER_F90_REAL;
        break;
    case MPI_COMBINER_F90_COMPLEX:
        mapped = HCOLL_MPI_COMBINER_F90_COMPLEX;
        break;
    case MPI_COMBINER_F90_INTEGER:
        mapped = HCOLL_MPI_COMBINER_F90_INTEGER;
        break;
    case MPI_COMBINER_RESIZED:
        mapped = HCOLL_MPI_COMBINER_RESIZED;
        break;
    case MPI_COMBINER_STRUCT:
    case MPI_COMBINER_STRUCT_INTEGER:
        mapped = HCOLL_MPI_COMBINER_STRUCT;
        break;
    default:
        /* unmapped combiner: keep HCOLL_MPI_COMBINER_LAST */
        break;
    }
    return mapped;
}
/*
 * RTE callback: report the envelope of an MPI datatype to hcoll.
 *
 * mpi_type      - the datatype, passed as an opaque pointer and treated as
 *                 ompi_datatype_t*.
 * num_integers, num_addresses, num_datatypes
 *               - filled by ompi_datatype_get_args() (query mode 0).
 * combiner      - receives the hcoll combiner translated from the OMPI one.
 *
 * Returns HCOLL_SUCCESS when the query succeeds. On failure returns
 * HCOLL_ERROR and sets *combiner to HCOLL_MPI_COMBINER_LAST, the same value
 * ompi_combiner_2_hcoll_combiner() uses for combiners it cannot map.
 */
static int get_mpi_type_envelope(void *mpi_type, int *num_integers,
                                 int *num_addresses, int *num_datatypes,
                                 hcoll_mpi_type_combiner_t *combiner) {
    int ompi_combiner, rc;

    rc = ompi_datatype_get_args( (ompi_datatype_t*)mpi_type, 0, num_integers, NULL,
                                 num_addresses, NULL,
                                 num_datatypes, NULL, &ompi_combiner);
    if (OMPI_SUCCESS != rc) {
        /* ompi_combiner is only meaningful after a successful query; do not
         * read it here, since it has no initializer. */
        *combiner = HCOLL_MPI_COMBINER_LAST;
        return HCOLL_ERROR;
    }

    *combiner = ompi_combiner_2_hcoll_combiner(ompi_combiner);
    return HCOLL_SUCCESS;
}
static int get_mpi_type_contents(void *mpi_type, int max_integers, int max_addresses,
int max_datatypes, int *array_of_integers,
void *array_of_addresses, void *array_of_datatypes) {
int rc;
rc = ompi_datatype_get_args( (ompi_datatype_t*)mpi_type, 1, &max_integers, array_of_integers,
&max_addresses, array_of_addresses,
&max_datatypes, array_of_datatypes, NULL );
return rc == OMPI_SUCCESS ? HCOLL_SUCCESS : HCOLL_ERROR;
}
/*
 * RTE callback: resolve the hcoll representation of an MPI datatype.
 *
 * Stores the result of ompi_dtype_2_hcoll_dtype(..., TRY_FIND_DERIVED) in
 * *hcoll_type. Returns HCOLL_ERR_NOT_FOUND when that result is the zero
 * representation, HCOLL_SUCCESS otherwise.
 */
static int get_hcoll_type(void *mpi_type, dte_data_representation_t *hcoll_type) {
    ompi_datatype_t *dt = (ompi_datatype_t*)mpi_type;

    *hcoll_type = ompi_dtype_2_hcoll_dtype(dt, TRY_FIND_DERIVED);
    if (HCOL_DTE_IS_ZERO((*hcoll_type))) {
        return HCOLL_ERR_NOT_FOUND;
    }
    return HCOLL_SUCCESS;
}
/*
 * RTE callback: cache an hcoll representation on an MPI datatype.
 *
 * An item is taken from mca_coll_hcoll_component.dtypes, filled with
 * hcoll_type and attached to the datatype as the attribute keyed by
 * hcoll_type_attr_keyval.
 *
 * mpi_type   - the datatype, passed as an opaque pointer and treated as
 *              ompi_datatype_t*.
 * hcoll_type - representation to store.
 *
 * Returns HCOLL_SUCCESS on success. Returns HCOLL_ERROR when no free-list
 * item could be obtained or when ompi_attr_set_c() fails; in the latter case
 * the item is handed back to the free list first. Failures are reported as
 * HCOLL_ERROR (not the raw OMPI code) to match the other RTE callbacks.
 */
static int set_hcoll_type(void *mpi_type, dte_data_representation_t hcoll_type) {
    int rc;
    ompi_datatype_t *dtype = (ompi_datatype_t*)mpi_type;
    mca_coll_hcoll_dtype_t *hcoll_dtype = (mca_coll_hcoll_dtype_t*)
        opal_free_list_get(&mca_coll_hcoll_component.dtypes);

    /* The free list may be unable to supply an item; never dereference NULL. */
    if (NULL == hcoll_dtype) {
        HCOL_VERBOSE(1,"hcoll failed to get a free list item for derived dtype");
        return HCOLL_ERROR;
    }

    hcoll_dtype->type = hcoll_type;
    rc = ompi_attr_set_c(TYPE_ATTR, (void*)dtype, &(dtype->d_keyhash), hcoll_type_attr_keyval, (void *)hcoll_dtype, false);
    if (OMPI_SUCCESS != rc) {
        HCOL_VERBOSE(1,"hcoll ompi_attr_set_c failed for derived dtype");
        /* The attribute was not stored, so the item is still ours to release. */
        opal_free_list_return(&mca_coll_hcoll_component.dtypes,
                              &hcoll_dtype->super);
        return HCOLL_ERROR;
    }
    return HCOLL_SUCCESS;
}
/*
 * RTE callback: export the MPI constants hcoll asks for.
 *
 * Every output parameter is written unconditionally: the size of an
 * MPI_Datatype handle, the two array-order constants and the four
 * MPI_DISTRIBUTE_* constants. Always returns HCOLL_SUCCESS.
 */
static int get_mpi_constants(size_t *mpi_datatype_size,
                             int *mpi_order_c, int *mpi_order_fortran,
                             int *mpi_distribute_block,
                             int *mpi_distribute_cyclic,
                             int *mpi_distribute_none,
                             int *mpi_distribute_dflt_darg) {
    /* handle size */
    *mpi_datatype_size = sizeof(MPI_Datatype);

    /* array storage orders */
    *mpi_order_fortran = MPI_ORDER_FORTRAN;
    *mpi_order_c = MPI_ORDER_C;

    /* distribution kinds and the default distribution argument */
    *mpi_distribute_dflt_darg = MPI_DISTRIBUTE_DFLT_DARG;
    *mpi_distribute_none = MPI_DISTRIBUTE_NONE;
    *mpi_distribute_cyclic = MPI_DISTRIBUTE_CYCLIC;
    *mpi_distribute_block = MPI_DISTRIBUTE_BLOCK;

    return HCOLL_SUCCESS;
}
#endif