Merge pull request #5163 from mkurnosov/reduce-scatter-block-rhalving
coll: reduce_scatter_block: add recursive halving algorithm
Этот коммит содержится в:
Коммит
9f353fe3d4
@ -252,6 +252,7 @@ int ompi_coll_base_reduce_scatter_intra_ring(REDUCESCATTER_ARGS);
|
||||
/* Reduce_scatter_block */
|
||||
int ompi_coll_base_reduce_scatter_block_basic(REDUCESCATTERBLOCK_ARGS);
|
||||
int ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(REDUCESCATTERBLOCK_ARGS);
|
||||
int ompi_coll_base_reduce_scatter_block_intra_recursivehalving(REDUCESCATTERBLOCK_ARGS);
|
||||
|
||||
/* Scan */
|
||||
int ompi_coll_base_scan_intra_recursivedoubling(SCAN_ARGS);
|
||||
|
@ -32,14 +32,15 @@
|
||||
#include "ompi/datatype/ompi_datatype.h"
|
||||
#include "ompi/communicator/communicator.h"
|
||||
#include "ompi/mca/coll/coll.h"
|
||||
#include "ompi/mca/coll/base/coll_tags.h"
|
||||
#include "ompi/mca/coll/base/coll_base_functions.h"
|
||||
#include "ompi/mca/coll/basic/coll_basic.h"
|
||||
#include "ompi/mca/pml/pml.h"
|
||||
#include "ompi/op/op.h"
|
||||
#include "ompi/mca/coll/base/coll_base_functions.h"
|
||||
#include "coll_tags.h"
|
||||
#include "coll_base_functions.h"
|
||||
#include "coll_base_topo.h"
|
||||
#include "coll_base_util.h"
|
||||
|
||||
|
||||
/*
|
||||
* ompi_reduce_scatter_block_basic
|
||||
*
|
||||
@ -303,3 +304,210 @@ cleanup_and_return:
|
||||
free(tmprecv_raw);
|
||||
return err;
|
||||
}
|
||||
|
||||
/*
 * ompi_range_sum: Returns the sum of the per-index weights over the closed
 *                 index range [a, b], where every index i <= r has weight 2
 *                 and every index i > r has weight 1:
 *
 *   index: 0 1 2 3 4 ... r r+1 r+2 ... nproc_pof2
 *   value: 2 2 2 2 2 ... 2  1   1  ...  1
 *
 * Accepts:  a - first index of the range (inclusive)
 *           b - last index of the range (inclusive)
 *           r - last index that carries weight 2; a value below a
 *               (e.g. -1) means no index in the range has weight 2
 * Returns:  the weighted number of elements in [a, b];
 *           0 if the range is empty (b < a)
 */
static int ompi_range_sum(int a, int b, int r)
{
    /* An empty range contributes nothing (never report a negative count) */
    if (b < a) {
        return 0;
    }

    /* The whole range lies to the right of r: every index has weight 1 */
    if (r < a) {
        return b - a + 1;
    }

    /* The whole range lies at or to the left of r: every index has weight 2 */
    if (r >= b) {
        return 2 * (b - a + 1);
    }

    /* a <= r < b: indexes [a, r] have weight 2, indexes [r + 1, b] have weight 1 */
    return 2 * (r - a + 1) + (b - r);
}
|
||||
|
||||
/*
|
||||
* ompi_coll_base_reduce_scatter_block_intra_recursivehalving
|
||||
*
|
||||
* Function: Recursive halving algorithm for reduce_scatter_block
|
||||
* Accepts: Same as MPI_Reduce_scatter_block
|
||||
* Returns: MPI_SUCCESS or error code
|
||||
*
|
||||
* Description: Implements recursive halving algorithm for MPI_Reduce_scatter_block.
|
||||
* The algorithm can be used by commutative operations only.
|
||||
*
|
||||
* Limitations: commutative operations only
|
||||
* Memory requirements (per process): 2 * rcount * comm_size * typesize
|
||||
*/
|
||||
int
|
||||
ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
|
||||
const void *sbuf, void *rbuf, int rcount, struct ompi_datatype_t *dtype,
|
||||
struct ompi_op_t *op, struct ompi_communicator_t *comm,
|
||||
mca_coll_base_module_t *module)
|
||||
{
|
||||
char *tmprecv_raw = NULL, *tmpbuf_raw = NULL, *tmprecv, *tmpbuf;
|
||||
ptrdiff_t span, gap, totalcount, extent;
|
||||
int err = MPI_SUCCESS;
|
||||
int comm_size = ompi_comm_size(comm);
|
||||
int rank = ompi_comm_rank(comm);
|
||||
|
||||
OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
|
||||
"coll:base:reduce_scatter_block_intra_recursivehalving: rank %d/%d",
|
||||
rank, comm_size));
|
||||
if (rcount == 0 || comm_size < 2)
|
||||
return MPI_SUCCESS;
|
||||
|
||||
if (!ompi_op_is_commute(op)) {
|
||||
OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
|
||||
"coll:base:reduce_scatter_block_intra_recursivehalving: rank %d/%d "
|
||||
"switching to basic reduce_scatter_block", rank, comm_size));
|
||||
return ompi_coll_base_reduce_scatter_block_basic(sbuf, rbuf, rcount, dtype,
|
||||
op, comm, module);
|
||||
}
|
||||
totalcount = comm_size * rcount;
|
||||
ompi_datatype_type_extent(dtype, &extent);
|
||||
span = opal_datatype_span(&dtype->super, totalcount, &gap);
|
||||
tmpbuf_raw = malloc(span);
|
||||
tmprecv_raw = malloc(span);
|
||||
if (NULL == tmpbuf_raw || NULL == tmprecv_raw) {
|
||||
err = OMPI_ERR_OUT_OF_RESOURCE;
|
||||
goto cleanup_and_return;
|
||||
}
|
||||
tmpbuf = tmpbuf_raw - gap;
|
||||
tmprecv = tmprecv_raw - gap;
|
||||
|
||||
if (sbuf != MPI_IN_PLACE) {
|
||||
err = ompi_datatype_copy_content_same_ddt(dtype, totalcount, tmpbuf, (char *)sbuf);
|
||||
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
|
||||
} else {
|
||||
err = ompi_datatype_copy_content_same_ddt(dtype, totalcount, tmpbuf, rbuf);
|
||||
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
|
||||
}
|
||||
|
||||
/*
|
||||
* Step 1. Reduce the number of processes to the nearest lower power of two
|
||||
* p' = 2^{\floor{\log_2 p}} by removing r = p - p' processes.
|
||||
* In the first 2r processes (ranks 0 to 2r - 1), all the even ranks send
|
||||
* the input vector to their neighbor (rank + 1) and all the odd ranks recv
|
||||
* the input vector and perform local reduction.
|
||||
* The odd ranks (0 to 2r - 1) contain the reduction with the input
|
||||
* vector on their neighbors (the even ranks). The first r odd
|
||||
* processes and the p - 2r last processes are renumbered from
|
||||
* 0 to 2^{\floor{\log_2 p}} - 1. Even ranks do not participate in the
|
||||
* rest of the algorithm.
|
||||
*/
|
||||
|
||||
/* Find nearest power-of-two less than or equal to comm_size */
|
||||
int nprocs_pof2 = opal_next_poweroftwo(comm_size);
|
||||
nprocs_pof2 >>= 1;
|
||||
int nprocs_rem = comm_size - nprocs_pof2;
|
||||
|
||||
int vrank = -1;
|
||||
if (rank < 2 * nprocs_rem) {
|
||||
if ((rank % 2) == 0) {
|
||||
/* Even process */
|
||||
err = MCA_PML_CALL(send(tmpbuf, totalcount, dtype, rank + 1,
|
||||
MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
|
||||
MCA_PML_BASE_SEND_STANDARD, comm));
|
||||
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
|
||||
/* This process does not pariticipate in the rest of the algorithm */
|
||||
vrank = -1;
|
||||
} else {
|
||||
/* Odd process */
|
||||
err = MCA_PML_CALL(recv(tmprecv, totalcount, dtype, rank - 1,
|
||||
MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
|
||||
comm, MPI_STATUS_IGNORE));
|
||||
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
|
||||
ompi_op_reduce(op, tmprecv, tmpbuf, totalcount, dtype);
|
||||
/* Adjust rank to be the bottom "remain" ranks */
|
||||
vrank = rank / 2;
|
||||
}
|
||||
} else {
|
||||
/* Adjust rank to show that the bottom "even remain" ranks dropped out */
|
||||
vrank = rank - nprocs_rem;
|
||||
}
|
||||
|
||||
if (vrank != -1) {
|
||||
/*
|
||||
* Step 2. Recursive vector halving. We have p' = 2^{\floor{\log_2 p}}
|
||||
* power-of-two number of processes with new ranks (vrank) and partial
|
||||
* result in tmpbuf.
|
||||
* All processes then compute the reduction between the local
|
||||
* buffer and the received buffer. In the next \log_2(p') - 1 steps, the
|
||||
* buffers are recursively halved. At the end, each of the p' processes
|
||||
* has 1 / p' of the total reduction result.
|
||||
*/
|
||||
int send_index = 0, recv_index = 0, last_index = nprocs_pof2;
|
||||
for (int mask = nprocs_pof2 >> 1; mask > 0; mask >>= 1) {
|
||||
int vpeer = vrank ^ mask;
|
||||
int peer = (vpeer < nprocs_rem) ? vpeer * 2 + 1 : vpeer + nprocs_rem;
|
||||
|
||||
/*
|
||||
* Calculate the recv_count and send_count because the
|
||||
* even-numbered processes who no longer participate will
|
||||
* have their result calculated by the process to their
|
||||
* right (rank + 1).
|
||||
*/
|
||||
int send_count = 0, recv_count = 0;
|
||||
if (vrank < vpeer) {
|
||||
/* Send the right half of the buffer, recv the left half */
|
||||
send_index = recv_index + mask;
|
||||
send_count = rcount * ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
|
||||
recv_count = rcount * ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
|
||||
} else {
|
||||
/* Send the left half of the buffer, recv the right half */
|
||||
recv_index = send_index + mask;
|
||||
send_count = rcount * ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
|
||||
recv_count = rcount * ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
|
||||
}
|
||||
ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
|
||||
2 * recv_index : nprocs_rem + recv_index);
|
||||
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
|
||||
2 * send_index : nprocs_rem + send_index);
|
||||
struct ompi_request_t *request = NULL;
|
||||
|
||||
if (recv_count > 0) {
|
||||
err = MCA_PML_CALL(irecv(tmprecv + rdispl * extent, recv_count,
|
||||
dtype, peer, MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
|
||||
comm, &request));
|
||||
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
|
||||
}
|
||||
if (send_count > 0) {
|
||||
err = MCA_PML_CALL(send(tmpbuf + sdispl * extent, send_count,
|
||||
dtype, peer, MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
|
||||
MCA_PML_BASE_SEND_STANDARD,
|
||||
comm));
|
||||
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
|
||||
}
|
||||
if (recv_count > 0) {
|
||||
err = ompi_request_wait(&request, MPI_STATUS_IGNORE);
|
||||
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
|
||||
ompi_op_reduce(op, tmprecv + rdispl * extent,
|
||||
tmpbuf + rdispl * extent, recv_count, dtype);
|
||||
}
|
||||
send_index = recv_index;
|
||||
last_index = recv_index + mask;
|
||||
}
|
||||
err = ompi_datatype_copy_content_same_ddt(dtype, rcount, rbuf,
|
||||
tmpbuf + (ptrdiff_t)rank * rcount * extent);
|
||||
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
|
||||
}
|
||||
|
||||
/* Step 3. Send the result to excluded even ranks */
|
||||
if (rank < 2 * nprocs_rem) {
|
||||
if ((rank % 2) == 0) {
|
||||
/* Even process */
|
||||
err = MCA_PML_CALL(recv(rbuf, rcount, dtype, rank + 1,
|
||||
MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK, comm,
|
||||
MPI_STATUS_IGNORE));
|
||||
if (OMPI_SUCCESS != err) { goto cleanup_and_return; }
|
||||
} else {
|
||||
/* Odd process */
|
||||
err = MCA_PML_CALL(send(tmpbuf + (ptrdiff_t)(rank - 1) * rcount * extent,
|
||||
rcount, dtype, rank - 1,
|
||||
MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
|
||||
MCA_PML_BASE_SEND_STANDARD, comm));
|
||||
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
|
||||
}
|
||||
}
|
||||
|
||||
cleanup_and_return:
|
||||
if (tmpbuf_raw)
|
||||
free(tmpbuf_raw);
|
||||
if (tmprecv_raw)
|
||||
free(tmprecv_raw);
|
||||
return err;
|
||||
}
|
||||
|
@ -35,6 +35,7 @@ static mca_base_var_enum_value_t reduce_scatter_block_algorithms[] = {
|
||||
{0, "ignore"},
|
||||
{1, "basic"},
|
||||
{2, "recursive_doubling"},
|
||||
{3, "recursive_halving"},
|
||||
{0, NULL}
|
||||
};
|
||||
|
||||
@ -125,6 +126,8 @@ int ompi_coll_tuned_reduce_scatter_block_intra_do_this(const void *sbuf, void *r
|
||||
dtype, op, comm, module);
|
||||
case (2): return ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(sbuf, rbuf, rcount,
|
||||
dtype, op, comm, module);
|
||||
case (3): return ompi_coll_base_reduce_scatter_block_intra_recursivehalving(sbuf, rbuf, rcount,
|
||||
dtype, op, comm, module);
|
||||
} /* switch */
|
||||
OPAL_OUTPUT((ompi_coll_tuned_stream, "coll:tuned:reduce_scatter_block_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
|
||||
algorithm, ompi_coll_tuned_forced_max_algorithms[REDUCESCATTERBLOCK]));
|
||||
|
Загрузка…
x
Ссылка в новой задаче
Block a user