diff --git a/ompi/mca/coll/base/coll_base_bcast.c b/ompi/mca/coll/base/coll_base_bcast.c
index 712aa70805..1490d3a0cf 100644
--- a/ompi/mca/coll/base/coll_base_bcast.c
+++ b/ompi/mca/coll/base/coll_base_bcast.c
@@ -891,3 +891,154 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
 cleanup_and_return:
     return err;
 }
+
+/*
+ * ompi_coll_base_bcast_intra_scatter_allgather_ring
+ *
+ * Function:  Bcast using a binomial tree scatter followed by a ring allgather.
+ * Accepts:   Same arguments as MPI_Bcast, plus segsize (unused by this
+ *            algorithm)
+ * Returns:   MPI_SUCCESS or error code
+ *
+ * Limitations: count >= comm_size; smaller counts are forwarded to
+ *              ompi_coll_base_bcast_intra_basic_linear.
+ * Time complexity: O(\alpha(\log(p) + p) + \beta*m((p-1)/p))
+ *   Binomial tree scatter: \alpha\log(p) + \beta*m((p-1)/p)
+ *   Ring allgather: 2(p-1)(\alpha + m/p\beta)
+ *
+ * The buffer is treated as comm_size blocks of ceil(count / comm_size)
+ * elements; trailing blocks may be short or empty. Block start positions
+ * are computed in ptrdiff_t so that block * scatter_count cannot overflow
+ * int for counts close to INT_MAX.
+ *
+ * Example, p=8, count=8, root=0
+ *    Binomial tree scatter      Ring allgather: p - 1 steps
+ * 0: --+  --+  --+  [0*******] [0******7] [0*****67] [0****567] ... [01234567]
+ * 1:   |   2|  <-+  [*1******] [01******] [01*****7] [01****67] ... [01234567]
+ * 2:  4|  <-+  --+  [**2*****] [*12*****] [012*****] [012****7] ... [01234567]
+ * 3:   |       <-+  [***3****] [**23****] [*123****] [0123****] ... [01234567]
+ * 4: <-+  --+  --+  [****4***] [***34***] [**234***] [*1234***] ... [01234567]
+ * 5:       2|  <-+  [*****5**] [****45**] [***345**] [**2345**] ... [01234567]
+ * 6:      <-+  --+  [******6*] [*****56*] [****456*] [***3456*] ... [01234567]
+ * 7:           <-+  [*******7] [******67] [*****567] [****4567] ... [01234567]
+ */
+int ompi_coll_base_bcast_intra_scatter_allgather_ring(
+    void *buf, int count, struct ompi_datatype_t *datatype, int root,
+    struct ompi_communicator_t *comm, mca_coll_base_module_t *module,
+    uint32_t segsize)
+{
+    int err = MPI_SUCCESS;
+    ptrdiff_t lb, extent;
+    size_t datatype_size;
+    MPI_Status status;
+    ompi_datatype_get_extent(datatype, &lb, &extent);
+    ompi_datatype_type_size(datatype, &datatype_size);
+    int comm_size = ompi_comm_size(comm);
+    int rank = ompi_comm_rank(comm);
+
+    OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
+                 "coll:base:bcast_intra_scatter_allgather_ring: rank %d/%d",
+                 rank, comm_size));
+    /* Nothing to transfer for a single process or a zero-size datatype */
+    if (comm_size < 2 || datatype_size == 0) {
+        return MPI_SUCCESS;
+    }
+
+    /* Fewer elements than processes: use the basic linear algorithm instead */
+    if (count < comm_size) {
+        OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
+                     "coll:base:bcast_intra_scatter_allgather_ring: rank %d/%d "
+                     "count %d switching to basic linear bcast",
+                     rank, comm_size, count));
+        return ompi_coll_base_bcast_intra_basic_linear(buf, count, datatype,
+                                                       root, comm, module);
+    }
+
+    /* Rank relative to root: the root gets virtual rank 0 */
+    int vrank = (rank - root + comm_size) % comm_size;
+    int recv_count = 0, send_count = 0;
+    /* ceil(count / comm_size), written so the numerator cannot overflow int */
+    int scatter_count = count / comm_size + (count % comm_size != 0);
+    int curr_count = (rank == root) ? count : 0;
+
+    /* Scatter by binomial tree: receive data from parent */
+    int mask = 1;
+    while (mask < comm_size) {
+        if (vrank & mask) {
+            int parent = (rank - mask + comm_size) % comm_size;
+            /* Compute an upper bound on recv block size */
+            ptrdiff_t recv_bound = count - (ptrdiff_t)vrank * scatter_count;
+            if (recv_bound <= 0) {
+                curr_count = 0;
+            } else {
+                recv_count = (int)recv_bound;
+                /* Recv data from parent */
+                err = MCA_PML_CALL(recv((char *)buf + (ptrdiff_t)vrank * scatter_count * extent,
+                                        recv_count, datatype, parent,
+                                        MCA_COLL_BASE_TAG_BCAST, comm, &status));
+                if (MPI_SUCCESS != err) { goto cleanup_and_return; }
+                /* Get received count */
+                curr_count = (int)(status._ucount / datatype_size);
+            }
+            break;
+        }
+        mask <<= 1;
+    }
+
+    /* Scatter by binomial tree: send data to child processes */
+    mask >>= 1;
+    while (mask > 0) {
+        if (vrank + mask < comm_size) {
+            /* Elements held beyond this rank's first mask blocks go to the child */
+            ptrdiff_t surplus = curr_count - (ptrdiff_t)scatter_count * mask;
+            if (surplus > 0) {
+                int child = (rank + mask) % comm_size;
+                send_count = (int)surplus;
+                err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
+                                        send_count, datatype, child,
+                                        MCA_COLL_BASE_TAG_BCAST,
+                                        MCA_PML_BASE_SEND_STANDARD, comm));
+                if (MPI_SUCCESS != err) { goto cleanup_and_return; }
+                curr_count -= send_count;
+            }
+        }
+        mask >>= 1;
+    }
+
+    /* Allgather by a ring algorithm */
+    int left = (rank - 1 + comm_size) % comm_size;
+    int right = (rank + 1) % comm_size;
+    int send_block = vrank;
+    int recv_block = (vrank - 1 + comm_size) % comm_size;
+
+    for (int i = 1; i < comm_size; i++) {
+        /* Index of each block's first element, widened before multiplying */
+        ptrdiff_t recv_first = (ptrdiff_t)recv_block * scatter_count;
+        ptrdiff_t send_first = (ptrdiff_t)send_block * scatter_count;
+
+        /* Full blocks hold scatter_count elements; tail blocks hold fewer or none */
+        recv_count = (scatter_count < count - recv_first) ?
+                      scatter_count : (int)(count - recv_first);
+        if (recv_count < 0) {
+            recv_count = 0;
+        }
+        send_count = (scatter_count < count - send_first) ?
+                      scatter_count : (int)(count - send_first);
+        if (send_count < 0) {
+            send_count = 0;
+        }
+
+        err = ompi_coll_base_sendrecv((char *)buf + send_first * extent, send_count,
+                                      datatype, right, MCA_COLL_BASE_TAG_BCAST,
+                                      (char *)buf + recv_first * extent, recv_count,
+                                      datatype, left, MCA_COLL_BASE_TAG_BCAST,
+                                      comm, MPI_STATUS_IGNORE, rank);
+        if (MPI_SUCCESS != err) { goto cleanup_and_return; }
+        /* The block just received is the one forwarded in the next step */
+        send_block = recv_block;
+        recv_block = (recv_block - 1 + comm_size) % comm_size;
+    }
+
+cleanup_and_return:
+    return err;
+}
diff --git a/ompi/mca/coll/base/coll_base_functions.h b/ompi/mca/coll/base/coll_base_functions.h
index d5418a0ee8..40de8762eb 100644
--- a/ompi/mca/coll/base/coll_base_functions.h
+++ b/ompi/mca/coll/base/coll_base_functions.h
@@ -247,6 +247,7 @@ int ompi_coll_base_bcast_intra_bintree(BCAST_ARGS, uint32_t segsize);
 int ompi_coll_base_bcast_intra_split_bintree(BCAST_ARGS, uint32_t segsize);
 int ompi_coll_base_bcast_intra_knomial(BCAST_ARGS, uint32_t segsize, int radix);
 int ompi_coll_base_bcast_intra_scatter_allgather(BCAST_ARGS, uint32_t segsize);
+int ompi_coll_base_bcast_intra_scatter_allgather_ring(BCAST_ARGS, uint32_t segsize);
 
 /* Exscan */
 int ompi_coll_base_exscan_intra_recursivedoubling(EXSCAN_ARGS);
diff --git a/ompi/mca/coll/tuned/coll_tuned_bcast_decision.c b/ompi/mca/coll/tuned/coll_tuned_bcast_decision.c
index 5ce5e93df6..e3b9ae82a0 100644
--- a/ompi/mca/coll/tuned/coll_tuned_bcast_decision.c
+++ b/ompi/mca/coll/tuned/coll_tuned_bcast_decision.c
@@ -44,6 +44,7 @@ static mca_base_var_enum_value_t bcast_algorithms[] = {
     {6, "binomial"},
     {7, "knomial"},
     {8, "scatter_allgather"},
+    {9, "scatter_allgather_ring"},
     {0, NULL}
 };
 
@@ -79,7 +80,7 @@ int ompi_coll_tuned_bcast_intra_check_forced_init (coll_tuned_force_algorithm_mc
     mca_param_indices->algorithm_param_index =
         mca_base_component_var_register(&mca_coll_tuned_component.super.collm_version,
                                         "bcast_algorithm",
-                                        "Which bcast algorithm is used. Can be locked down to choice of: 0 ignore, 1 basic linear, 2 chain, 3: pipeline, 4: split binary tree, 5: binary tree, 6: binomial tree, 7: knomial tree, 8: scatter_allgather.",
+                                        "Which bcast algorithm is used. Can be locked down to choice of: 0 ignore, 1 basic linear, 2 chain, 3: pipeline, 4: split binary tree, 5: binary tree, 6: binomial tree, 7: knomial tree, 8: scatter_allgather, 9: scatter_allgather_ring.",
                                         MCA_BASE_VAR_TYPE_INT, new_enum, 0, MCA_BASE_VAR_FLAG_SETTABLE,
                                         OPAL_INFO_LVL_5,
                                         MCA_BASE_VAR_SCOPE_ALL,
@@ -160,6 +161,8 @@ int ompi_coll_tuned_bcast_intra_do_this(void *buf, int count,
                                              segsize, coll_tuned_bcast_knomial_radix);
     case (8):
         return ompi_coll_base_bcast_intra_scatter_allgather(buf, count, dtype, root, comm, module, segsize);
+    case (9):
+        return ompi_coll_base_bcast_intra_scatter_allgather_ring(buf, count, dtype, root, comm, module, segsize);
     } /* switch */
     OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned:bcast_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
                  algorithm, ompi_coll_tuned_forced_max_algorithms[BCAST]));