diff --git a/ompi/mca/coll/tuned/coll_tuned_decision_fixed.c b/ompi/mca/coll/tuned/coll_tuned_decision_fixed.c index d01cbbcb3c..a8350b3d84 100644 --- a/ompi/mca/coll/tuned/coll_tuned_decision_fixed.c +++ b/ompi/mca/coll/tuned/coll_tuned_decision_fixed.c @@ -416,9 +416,12 @@ int ompi_coll_tuned_reduce_scatter_intra_dec_fixed( void *sbuf, void *rbuf, struct ompi_op_t *op, struct ompi_communicator_t *comm) { - int comm_size, i; + int comm_size, i, pow2; size_t total_message_size, dsize; - const size_t large_message_size = 512 * 1024; + const double a = 0.0012; + const double b = 8.0; + const size_t small_message_size = 12 * 1024; + const size_t large_message_size = 256 * 1024; OPAL_OUTPUT((ompi_coll_tuned_stream, "ompi_coll_tuned_reduce_scatter_intra_dec_fixed")); @@ -438,7 +441,12 @@ int ompi_coll_tuned_reduce_scatter_intra_dec_fixed( void *sbuf, void *rbuf, for (i = 0; i < comm_size; i++) { total_message_size += rcounts[i]; } total_message_size *= dsize; - if (total_message_size <= large_message_size) { + /* compute the nearest power of 2 */ + for (pow2 = 1; pow2 < comm_size; pow2 <<= 1); + + if ((total_message_size <= small_message_size) || + ((total_message_size <= large_message_size) && (pow2 == comm_size)) || + (comm_size >= a * total_message_size + b)) { return ompi_coll_tuned_reduce_scatter_intra_basic_recursivehalving(sbuf, rbuf,