Use a recursive halving communication algorithm, similar to the one used by MPICH2, for "small" commutative operations in the basic reduce_scatter implementation. "Small" is currently pretty big, as it doesn't take much to beat reduce/scatterv. Much more than this is needed for better all-around MPI_Reduce_scatter performance, but this was enough to solve the problems I was having.

This commit was SVN r13348.
This commit is contained in:
parent 61bc9fed3c
commit 93a2f31932
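The change below is to the basic coll component's reduce_scatter. The recursive-halving path is taken only when the operation is commutative and the packed reduction buffer is smaller than COMMUTATIVE_LONG_MSG (8 * 1024 * 1024 bytes, i.e. 8 MB); otherwise it falls back to the existing reduce + scatterv path. As a rough illustration of the communication pattern (this sketch is not part of the commit, makes no MPI calls, and uses a made-up communicator size of 6), the following standalone program prints, for each rank, the fold-in step for the non-power-of-two remainder and the peer plus the surviving block range at every halving step:

/* Sketch only: simulates the communication pattern of the recursive-halving
 * reduce_scatter without any MPI calls.  "size" is a made-up communicator
 * size; block indices are in the compacted numbering where, for i < remain,
 * block i combines the counts of original ranks 2*i and 2*i+1. */
#include <stdio.h>

int main(void)
{
    const int size = 6;                 /* hypothetical communicator size */

    /* largest power of two <= size, as in the algorithm's setup */
    int tmp_size = 1;
    while (tmp_size <= size) tmp_size <<= 1;
    tmp_size >>= 1;
    const int remain = size - tmp_size; /* ranks folded into a neighbor */

    for (int rank = 0; rank < size; ++rank) {
        int tmp_rank;
        if (rank < 2 * remain) {
            /* even ranks in the first 2*remain send their data to rank + 1
             * and sit out the halving phase */
            tmp_rank = (rank & 1) ? rank / 2 : -1;
        } else {
            tmp_rank = rank - remain;
        }

        if (tmp_rank < 0) {
            printf("rank %d: folds its buffer into rank %d and drops out\n",
                   rank, rank + 1);
            continue;
        }

        printf("rank %d (tmp_rank %d):\n", rank, tmp_rank);
        int send_index = 0, recv_index = 0;
        for (int mask = tmp_size >> 1; mask > 0; mask >>= 1) {
            int tmp_peer = tmp_rank ^ mask;
            int peer = (tmp_peer < remain) ? tmp_peer * 2 + 1
                                           : tmp_peer + remain;
            if (tmp_rank < tmp_peer) {
                send_index = recv_index + mask;   /* keep the lower half */
            } else {
                recv_index = send_index + mask;   /* keep the upper half */
            }
            printf("  mask %d: exchange with rank %d, keep blocks [%d, %d)\n",
                   mask, peer, recv_index, recv_index + mask);
            send_index = recv_index;              /* shrink to the kept half */
        }
    }
    return 0;
}

After the last step each surviving rank holds only its own block (in the compacted numbering), and a final point-to-point exchange hands the even "remainder" ranks their results back, which is what the fix-up loop near the end of the new code does.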
@@ -26,9 +26,11 @@
 #include "ompi/constants.h"
 #include "ompi/mca/coll/coll.h"
 #include "ompi/mca/coll/base/coll_tags.h"
 #include "ompi/datatype/datatype.h"
 #include "coll_basic.h"
 #include "ompi/op/op.h"
 
+#define COMMUTATIVE_LONG_MSG 8 * 1024 * 1024
+
 /*
  * reduce_scatter
@@ -36,6 +38,21 @@
  * Function:    - reduce then scatter
  * Accepts:     - same as MPI_Reduce_scatter()
  * Returns:     - MPI_SUCCESS or error code
+ *
+ * Algorithm:
+ *   Commutative, reasonably sized messages:
+ *     recursive halving algorithm
+ *   Others:
+ *     reduce and scatterv (needs to be cleaned
+ *     up at some point)
+ *
+ * NOTE: the recursive halving algorithm should be faster than
+ * the reduce/scatter for all message sizes.  However, the memory
+ * usage for the recursive halving is msg_size + 2 * comm_size
+ * greater than for reduce/scatterv, so I've limited where the
+ * recursive halving is used to be nice to the app memory-wise.
+ * There are much better algorithms for large messages with
+ * commutative operations, so this should be investigated further.
  */
 int
 mca_coll_basic_reduce_scatter_intra(void *sbuf, void *rbuf, int *rcounts,
@@ -43,78 +60,284 @@ mca_coll_basic_reduce_scatter_intra(void *sbuf, void *rbuf, int *rcounts,
                                     struct ompi_op_t *op,
                                     struct ompi_communicator_t *comm)
 {
-    int i, err, rank, size, count;
-    ptrdiff_t true_lb, true_extent, lb, extent;
+    int i, rank, size, count, err = OMPI_SUCCESS;
+    ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
     int *disps = NULL;
-    char *free_buffer = NULL, *pml_buffer = NULL;
+    char *recv_buf = NULL, *recv_buf_free = NULL;
+    char *result_buf = NULL, *result_buf_free = NULL;
 
     /* Initialize */
 
     rank = ompi_comm_rank(comm);
     size = ompi_comm_size(comm);
 
-    /* Initialize reduce & scatterv info at the root (rank 0). */
+    /* Find displacements and the like */
+    disps = (int*) malloc(sizeof(int) * size);
+    if (NULL == disps) return OMPI_ERR_OUT_OF_RESOURCE;
 
-    for (i = 0, count = 0; i < size; ++i) {
-        if (rcounts[i] < 0) {
-            return MPI_ERR_ARG;
-        }
-        count += rcounts[i];
+    disps[0] = 0;
+    for (i = 0; i < (size - 1); ++i) {
+        disps[i + 1] = disps[i] + rcounts[i];
     }
+    count = disps[size - 1] + rcounts[size - 1];
+
+    /* short cut the trivial case */
+    if (0 == count) {
+        free(disps);
+        return OMPI_SUCCESS;
+    }
 
-    if (0 == rank) {
-        disps = (int*)malloc((unsigned) size * sizeof(int));
-        if (NULL == disps) {
-            return OMPI_ERR_OUT_OF_RESOURCE;
-        }
-
-        /* There is lengthy rationale about how this malloc works in
-         * coll_basic_reduce.c */
-
-        ompi_ddt_get_extent(dtype, &lb, &extent);
-        ompi_ddt_get_true_extent(dtype, &true_lb, &true_extent);
-
-        free_buffer = (char*)malloc(true_extent + (count - 1) * extent);
-        if (NULL == free_buffer) {
-            free(disps);
-            return OMPI_ERR_OUT_OF_RESOURCE;
-        }
-        pml_buffer = free_buffer - lb;
-
-        disps[0] = 0;
-        for (i = 0; i < (size - 1); ++i) {
-            disps[i + 1] = disps[i] + rcounts[i];
-        }
-    }
+    /* get datatype information */
+    ompi_ddt_get_extent(dtype, &lb, &extent);
+    ompi_ddt_get_true_extent(dtype, &true_lb, &true_extent);
+    buf_size = true_extent + (count - 1) * extent;
 
     /* Handle MPI_IN_PLACE */
 
     if (MPI_IN_PLACE == sbuf) {
         sbuf = rbuf;
     }
 
-    /* reduction */
+    if ((op->o_flags & OMPI_OP_FLAGS_COMMUTE) &&
+        (buf_size < COMMUTATIVE_LONG_MSG)) {
+        int tmp_size = 1, remain = 0, tmp_rank;
 
-    err =
-        comm->c_coll.coll_reduce(sbuf, pml_buffer, count, dtype, op, 0,
-                                 comm);
+        /* temporary receive buffer.  See coll_basic_reduce.c for details on sizing */
+        recv_buf_free = (char*) malloc(buf_size);
+        recv_buf = recv_buf_free - lb;
+        if (NULL == recv_buf_free) {
+            err = OMPI_ERR_OUT_OF_RESOURCE;
+            goto cleanup;
+        }
 
-    /* scatter */
+        /* allocate temporary buffer for results */
+        result_buf_free = (char*) malloc(buf_size);
+        result_buf = result_buf_free - lb;
 
-    if (MPI_SUCCESS == err) {
-        err = comm->c_coll.coll_scatterv(pml_buffer, rcounts, disps, dtype,
-                                         rbuf, rcounts[rank], dtype, 0,
-                                         comm);
+        /* copy local buffer into the temporary results */
+        err = ompi_ddt_sndrcv(sbuf, count, dtype, result_buf, count, dtype);
+        if (OMPI_SUCCESS != err) goto cleanup;
+
+        /* figure out power of two mapping: grow until larger than
+           comm size, then go back one, to get the largest power of
+           two less than comm size */
+        while (tmp_size <= size) tmp_size <<= 1;
+        tmp_size >>= 1;
+        remain = size - tmp_size;
+
+        /* If comm size is not a power of two, have the first "remain"
+           procs with an even rank send to rank + 1, leaving a power of
+           two procs to do the rest of the algorithm */
+        if (rank < 2 * remain) {
+            if ((rank & 1) == 0) {
+                err = MCA_PML_CALL(send(result_buf, count, dtype, rank + 1,
+                                        MCA_COLL_BASE_TAG_REDUCE_SCATTER,
+                                        MCA_PML_BASE_SEND_STANDARD,
+                                        comm));
+                if (OMPI_SUCCESS != err) goto cleanup;
+
+                /* we don't participate from here on out */
+                tmp_rank = -1;
+            } else {
+                err = MCA_PML_CALL(recv(recv_buf, count, dtype, rank - 1,
+                                        MCA_COLL_BASE_TAG_REDUCE_SCATTER,
+                                        comm, MPI_STATUS_IGNORE));
+
+                /* integrate their results into our temp results */
+                ompi_op_reduce(op, recv_buf, result_buf, count, dtype);
+
+                /* adjust rank to be the bottom "remain" ranks */
+                tmp_rank = rank / 2;
+            }
+        } else {
+            /* just need to adjust rank to show that the bottom "even
+               remain" ranks dropped out */
+            tmp_rank = rank - remain;
+        }
+
+        /* For ranks not kicked out by the above code, perform the
+           recursive halving */
+        if (tmp_rank >= 0) {
+            int *tmp_disps = NULL, *tmp_rcounts = NULL;
+            int mask, send_index, recv_index, last_index;
+
+            /* recalculate disps and rcounts to account for the
+               special "remainder" processes that are no longer doing
+               anything */
+            tmp_rcounts = (int*) malloc(tmp_size * sizeof(int));
+            if (NULL == tmp_rcounts) {
+                err = OMPI_ERR_OUT_OF_RESOURCE;
+                goto cleanup;
+            }
+            tmp_disps = (int*) malloc(tmp_size * sizeof(int));
+            if (NULL == tmp_disps) {
+                free(tmp_rcounts);
+                err = OMPI_ERR_OUT_OF_RESOURCE;
+                goto cleanup;
+            }
+
+            for (i = 0 ; i < tmp_size ; ++i) {
+                if (i < remain) {
+                    /* need to include old neighbor as well */
+                    tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
+                } else {
+                    tmp_rcounts[i] = rcounts[i + remain];
+                }
+            }
+
+            tmp_disps[0] = 0;
+            for (i = 0; i < tmp_size - 1; ++i) {
+                tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
+            }
+
+            /* do the recursive halving communication.  Don't use the
+               dimension information on the communicator because I
+               think the information is invalidated by our "shrinking"
+               of the communicator */
+            mask = tmp_size >> 1;
+            send_index = recv_index = 0;
+            last_index = tmp_size;
+            while (mask > 0) {
+                int tmp_peer, peer, send_count, recv_count;
+                struct ompi_request_t *request;
+
+                tmp_peer = tmp_rank ^ mask;
+                peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;
+
+                /* figure out if we're sending, receiving, or both */
+                send_count = recv_count = 0;
+                if (tmp_rank < tmp_peer) {
+                    send_index = recv_index + mask;
+                    for (i = send_index ; i < last_index ; ++i) {
+                        send_count += tmp_rcounts[i];
+                    }
+                    for (i = recv_index ; i < send_index ; ++i) {
+                        recv_count += tmp_rcounts[i];
+                    }
+                } else {
+                    recv_index = send_index + mask;
+                    for (i = send_index ; i < recv_index ; ++i) {
+                        send_count += tmp_rcounts[i];
+                    }
+                    for (i = recv_index ; i < last_index ; ++i) {
+                        recv_count += tmp_rcounts[i];
+                    }
+                }
+
+                /* actual data transfer.  Send from result_buf,
+                   receive into recv_buf */
+                if (send_count > 0 && recv_count != 0) {
+                    err = MCA_PML_CALL(irecv(recv_buf + tmp_disps[recv_index] * extent,
+                                             recv_count, dtype, peer,
+                                             MCA_COLL_BASE_TAG_REDUCE_SCATTER,
+                                             comm, &request));
+                    if (OMPI_SUCCESS != err) {
+                        free(tmp_rcounts);
+                        free(tmp_disps);
+                        goto cleanup;
+                    }
+                }
+                if (recv_count > 0 && send_count != 0) {
+                    err = MCA_PML_CALL(send(result_buf + tmp_disps[send_index] * extent,
+                                            send_count, dtype, peer,
+                                            MCA_COLL_BASE_TAG_REDUCE_SCATTER,
+                                            MCA_PML_BASE_SEND_STANDARD,
+                                            comm));
+                    if (OMPI_SUCCESS != err) {
+                        free(tmp_rcounts);
+                        free(tmp_disps);
+                        goto cleanup;
+                    }
+                }
+                if (send_count > 0 && recv_count != 0) {
+                    err = ompi_request_wait(&request, MPI_STATUS_IGNORE);
+                    if (OMPI_SUCCESS != err) {
+                        free(tmp_rcounts);
+                        free(tmp_disps);
+                        goto cleanup;
+                    }
+                }
+
+                /* if we received something on this step, push it into
+                   the results buffer */
+                if (recv_count > 0) {
+                    ompi_op_reduce(op,
+                                   recv_buf + tmp_disps[recv_index] * extent,
+                                   result_buf + tmp_disps[recv_index] * extent,
+                                   recv_count, dtype);
+                }
+
+                /* update for next iteration */
+                send_index = recv_index;
+                last_index = recv_index + mask;
+                mask >>= 1;
+            }
+
+            /* copy local results from results buffer into real receive buffer */
+            if (0 != rcounts[rank]) {
+                err = ompi_ddt_sndrcv(result_buf + disps[rank] * extent,
+                                      rcounts[rank], dtype,
+                                      rbuf, rcounts[rank], dtype);
+                if (OMPI_SUCCESS != err) {
+                    free(tmp_rcounts);
+                    free(tmp_disps);
+                    goto cleanup;
+                }
+            }
+
+            free(tmp_rcounts);
+            free(tmp_disps);
+        }
+
+        /* Now fix up the non-power of two case, by having the odd
+           procs send the even procs the proper results */
+        if (rank < 2 * remain) {
+            if ((rank & 1) == 0) {
+                if (rcounts[rank]) {
+                    err = MCA_PML_CALL(recv(rbuf, rcounts[rank], dtype, rank + 1,
+                                            MCA_COLL_BASE_TAG_REDUCE_SCATTER,
+                                            comm, MPI_STATUS_IGNORE));
+                    if (OMPI_SUCCESS != err) goto cleanup;
+                }
+            } else {
+                if (rcounts[rank - 1]) {
+                    err = MCA_PML_CALL(send(result_buf + disps[rank - 1] * extent,
+                                            rcounts[rank - 1], dtype, rank - 1,
+                                            MCA_COLL_BASE_TAG_REDUCE_SCATTER,
+                                            MCA_PML_BASE_SEND_STANDARD,
+                                            comm));
+                    if (OMPI_SUCCESS != err) goto cleanup;
+                }
+            }
+        }
+
+    } else {
+        if (0 == rank) {
+            /* temporary receive buffer.  See coll_basic_reduce.c for
+               details on sizing */
+            recv_buf_free = (char*) malloc(buf_size);
+            recv_buf = recv_buf_free - lb;
+            if (NULL == recv_buf_free) {
+                err = OMPI_ERR_OUT_OF_RESOURCE;
+                goto cleanup;
+            }
+        }
+
+        /* reduction */
+        err =
+            comm->c_coll.coll_reduce(sbuf, recv_buf, count, dtype, op, 0,
+                                     comm);
+
+        /* scatter */
+        if (MPI_SUCCESS == err) {
+            err = comm->c_coll.coll_scatterv(recv_buf, rcounts, disps, dtype,
+                                             rbuf, rcounts[rank], dtype, 0,
+                                             comm);
+        }
     }
 
     /* All done */
 
-    if (NULL != disps) {
-        free(disps);
-    }
-    if (NULL != free_buffer) {
-        free(free_buffer);
-    }
+ cleanup:
+    if (NULL != disps) free(disps);
+    if (NULL != recv_buf_free) free(recv_buf_free);
+    if (NULL != result_buf_free) free(result_buf_free);
 
     return err;
 }
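For reference, this code sits behind the standard API call: when the basic component is selected, MPI_Reduce_scatter on a commutative op such as MPI_SUM should end up in mca_coll_basic_reduce_scatter_intra above. A minimal caller might look like the following (illustration only, not part of the commit; one integer per destination rank and MPI_SUM are arbitrary choices):

/* Minimal MPI_Reduce_scatter example: every rank contributes one element per
 * destination rank and receives the element-wise sum of its own block.
 * Compile with mpicc; the data values are arbitrary test data. */
#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>

int main(int argc, char **argv)
{
    int rank, size;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    /* rcounts[i] is how many elements rank i receives; they may differ,
     * but one element each keeps the example small */
    int *rcounts = (int *) malloc(size * sizeof(int));
    int *sendbuf = (int *) malloc(size * sizeof(int));
    for (int i = 0; i < size; ++i) {
        rcounts[i] = 1;
        sendbuf[i] = rank + i;
    }
    int recvbuf = 0;

    /* MPI_SUM is commutative, so an implementation like the recursive
     * halving path above may be used when the message is small enough */
    MPI_Reduce_scatter(sendbuf, &recvbuf, rcounts, MPI_INT, MPI_SUM,
                       MPI_COMM_WORLD);

    printf("rank %d received %d\n", rank, recvbuf);

    free(sendbuf);
    free(rcounts);
    MPI_Finalize();
    return 0;
}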