diff --git a/ompi/datatype/dt_create.c b/ompi/datatype/dt_create.c index b9091a0fe0..019fbbebfa 100644 --- a/ompi/datatype/dt_create.c +++ b/ompi/datatype/dt_create.c @@ -45,6 +45,7 @@ static void __get_free_dt_struct( ompi_datatype_t* pData ) pData->ub = LONG_MIN; pData->d_f_to_c_index = ompi_pointer_array_add(ompi_datatype_f_to_c_table, pData); pData->d_keyhash = NULL; + pData->name[0] = '\0'; } static void __destroy_ddt_struct( ompi_datatype_t* pData ) diff --git a/ompi/mpi/c/allreduce.c b/ompi/mpi/c/allreduce.c index 5e31f69d71..5ecfaad5e8 100644 --- a/ompi/mpi/c/allreduce.c +++ b/ompi/mpi/c/allreduce.c @@ -38,6 +38,7 @@ int MPI_Allreduce(void *sendbuf, void *recvbuf, int count, int err; if (MPI_PARAM_CHECK) { + char *msg; /* Unrooted operation -- same checks for all ranks on both intracommunicators and intercommunicators */ @@ -49,10 +50,10 @@ int MPI_Allreduce(void *sendbuf, void *recvbuf, int count, FUNC_NAME); } else if (MPI_OP_NULL == op) { err = MPI_ERR_OP; - } else if (ompi_op_is_intrinsic(op) && - datatype->id < DT_MAX_PREDEFINED && - -1 == ompi_op_ddt_map[datatype->id]) { - err = MPI_ERR_OP; + } else if (!ompi_op_is_valid(op, datatype, &msg, FUNC_NAME)) { + int ret = OMPI_ERRHANDLER_INVOKE(MPI_COMM_WORLD, MPI_ERR_OP, msg); + free(msg); + return ret; } else { OMPI_CHECK_DATATYPE_FOR_SEND(err, datatype, count); } diff --git a/ompi/mpi/c/exscan.c b/ompi/mpi/c/exscan.c index 2d6d1c5bd4..96c1a343f9 100644 --- a/ompi/mpi/c/exscan.c +++ b/ompi/mpi/c/exscan.c @@ -37,6 +37,7 @@ int MPI_Exscan(void *sendbuf, void *recvbuf, int count, int err; if (MPI_PARAM_CHECK) { + char *msg; err = MPI_SUCCESS; OMPI_ERR_INIT_FINALIZE(FUNC_NAME); if (ompi_comm_invalid(comm)) { @@ -49,10 +50,10 @@ int MPI_Exscan(void *sendbuf, void *recvbuf, int count, else if (MPI_OP_NULL == op) { err = MPI_ERR_OP; - } else if (ompi_op_is_intrinsic(op) && - datatype->id < DT_MAX_PREDEFINED && - -1 == ompi_op_ddt_map[datatype->id]) { - err = MPI_ERR_OP; + } else if (!ompi_op_is_valid(op, datatype, &msg, FUNC_NAME)) { + int ret = OMPI_ERRHANDLER_INVOKE(MPI_COMM_WORLD, MPI_ERR_OP, msg); + free(msg); + return ret; } else { OMPI_CHECK_DATATYPE_FOR_SEND(err, datatype, count); } diff --git a/ompi/mpi/c/reduce.c b/ompi/mpi/c/reduce.c index fffe1f1b4f..3115b9d057 100644 --- a/ompi/mpi/c/reduce.c +++ b/ompi/mpi/c/reduce.c @@ -37,6 +37,7 @@ int MPI_Reduce(void *sendbuf, void *recvbuf, int count, int err; if (MPI_PARAM_CHECK) { + char *msg; err = MPI_SUCCESS; OMPI_ERR_INIT_FINALIZE(FUNC_NAME); if (ompi_comm_invalid(comm)) { @@ -48,10 +49,10 @@ int MPI_Reduce(void *sendbuf, void *recvbuf, int count, else if (MPI_OP_NULL == op) { err = MPI_ERR_OP; - } else if (ompi_op_is_intrinsic(op) && - datatype->id < DT_MAX_PREDEFINED && - -1 == ompi_op_ddt_map[datatype->id]) { - err = MPI_ERR_OP; + } else if (!ompi_op_is_valid(op, datatype, &msg, FUNC_NAME)) { + int ret = OMPI_ERRHANDLER_INVOKE(MPI_COMM_WORLD, MPI_ERR_OP, msg); + free(msg); + return ret; } else if ((root != ompi_comm_rank(comm) && MPI_IN_PLACE == sendbuf) || MPI_IN_PLACE == recvbuf) { err = MPI_ERR_ARG; diff --git a/ompi/mpi/c/reduce_scatter.c b/ompi/mpi/c/reduce_scatter.c index 221e3b4a90..127db03e86 100644 --- a/ompi/mpi/c/reduce_scatter.c +++ b/ompi/mpi/c/reduce_scatter.c @@ -37,6 +37,7 @@ int MPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, int i, err, size, count; if (MPI_PARAM_CHECK) { + char *msg; err = MPI_SUCCESS; OMPI_ERR_INIT_FINALIZE(FUNC_NAME); if (ompi_comm_invalid(comm)) { @@ -49,10 +50,10 @@ int MPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, else if (MPI_OP_NULL == op) { err = MPI_ERR_OP; - } else if (ompi_op_is_intrinsic(op) && - datatype->id < DT_MAX_PREDEFINED && - -1 == ompi_op_ddt_map[datatype->id]) { - err = MPI_ERR_OP; + } else if (!ompi_op_is_valid(op, datatype, &msg, FUNC_NAME)) { + int ret = OMPI_ERRHANDLER_INVOKE(MPI_COMM_WORLD, MPI_ERR_OP, msg); + free(msg); + return ret; } else if (NULL == recvcounts) { err = MPI_ERR_COUNT; } else if (MPI_IN_PLACE == recvbuf) { diff --git a/ompi/mpi/c/scan.c b/ompi/mpi/c/scan.c index 9e01a2624e..5608563db1 100644 --- a/ompi/mpi/c/scan.c +++ b/ompi/mpi/c/scan.c @@ -37,6 +37,7 @@ int MPI_Scan(void *sendbuf, void *recvbuf, int count, int err; if (MPI_PARAM_CHECK) { + char *msg; err = MPI_SUCCESS; OMPI_ERR_INIT_FINALIZE(FUNC_NAME); if (ompi_comm_invalid(comm)) { @@ -57,10 +58,10 @@ int MPI_Scan(void *sendbuf, void *recvbuf, int count, err = MPI_ERR_OP; } else if (MPI_IN_PLACE == recvbuf) { err = MPI_ERR_ARG; - } else if (ompi_op_is_intrinsic(op) && - datatype->id < DT_MAX_PREDEFINED && - -1 == ompi_op_ddt_map[datatype->id]) { - err = MPI_ERR_OP; + } else if (!ompi_op_is_valid(op, datatype, &msg, FUNC_NAME)) { + int ret = OMPI_ERRHANDLER_INVOKE(MPI_COMM_WORLD, MPI_ERR_OP, msg); + free(msg); + return ret; } else { OMPI_CHECK_DATATYPE_FOR_SEND(err, datatype, count); } diff --git a/ompi/op/op.h b/ompi/op/op.h index a09d2b58f0..b842bea8d2 100644 --- a/ompi/op/op.h +++ b/ompi/op/op.h @@ -452,6 +452,56 @@ static inline bool ompi_op_is_float_assoc(ompi_op_t *op) } +/** + * Check to see if an op is valid on a given datatype + * + * @param op The op to check + * @param ddt The datatype to check + * + * @returns true If the op is valid on that datatype + * @returns false If the op is not valid on that datatype + * + * Self-explanitory. This is needed in a few top-level MPI functions; + * this function is provided to hide the internal structure field + * names. + */ +static inline bool ompi_op_is_valid(ompi_op_t *op, ompi_datatype_t *ddt, + char **msg, const char *func) +{ + /* Check: + - non-intrinsic ddt's cannot be invoked on intrinsic op's + - if intrinsic ddt invoked on intrinsic op: + - ensure the datatype is defined in the op map + - ensure we have a function pointer for that combination + */ + + if (ompi_op_is_intrinsic(op)) { + if (0 != (ddt->flags & DT_FLAG_PREDEFINED)) { + /* Intrinsic ddt on intrinsic op */ + if ((-1 == ompi_op_ddt_map[ddt->id] || + (0 != (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC) && + NULL == op->o_func[ompi_op_ddt_map[ddt->id]].fort_fn) || + (0 == (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC) && + NULL == op->o_func[ompi_op_ddt_map[ddt->id]].c_fn))) { + asprintf(msg, "%s: the reduction operation %s is not defined on the %s datatype", func, op->o_name, ddt->name); + return false; + } + } else { + /* Non-intrinsic ddt on intrinsic op */ + if ('\0' != ddt->name[0]) { + asprintf(msg, "%s: the reduction operation %s is not defined for non-intrinsic datatypes (attempted with datatype named \"%s\")", func, op->o_name, ddt->name); + } else { + asprintf(msg, "%s: the reduction operation %s is not defined for non-intrinsic datatypes", func, op->o_name); + } + return false; + } + } + + /* All other cases ok */ + return true; +} + + /** * Perform a reduction operation. * @@ -502,7 +552,7 @@ static inline void ompi_op_reduce(ompi_op_t *op, void *source, void *target, */ if (0 != (op->o_flags & OMPI_OP_FLAGS_INTRINSIC) && - dtype->id < DT_MAX_PREDEFINED) { + 0 != (dtype->flags & DT_FLAG_PREDEFINED)) { if (0 != (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC)) { f_dtype = OMPI_INT_2_FINT(dtype->d_f_to_c_index); f_count = OMPI_INT_2_FINT(count);