Fix the implementation of MPI_Reduce_scatter on intercommunicators.
We still do an interreduce, but it is now followed by an intrascatterv. This fixes trac:1554.

This commit was SVN r19723. The following Trac tickets were found above: Ticket 1554 --> https://svn.open-mpi.org/trac/ompi/ticket/1554
This commit is contained in:
parent f7a94f17b9
commit aad4427caa
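For context, the fixed path can be exercised from plain MPI. The following is a minimal test sketch, not part of the commit: it assumes an even number of ranks (at least two), splits MPI_COMM_WORLD into two halves, builds an intercommunicator, and checks that MPI_Reduce_scatter hands each rank the sum of the remote group's contributions.

/* Minimal sketch (not part of the commit): exercise MPI_Reduce_scatter
 * over an intercommunicator.  Assumes an even number of ranks so that
 * both groups transfer the same total element count. */
#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>

int main(int argc, char **argv)
{
    MPI_Comm half, inter;
    int wrank, wsize, color, lsize, rsize, i, result = -1, err = 0;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &wrank);
    MPI_Comm_size(MPI_COMM_WORLD, &wsize);

    /* Split the world into two equal halves and connect them with an
     * intercommunicator; the remote leader is rank 0 of the other half. */
    color = (wrank < wsize / 2) ? 0 : 1;
    MPI_Comm_split(MPI_COMM_WORLD, color, wrank, &half);
    MPI_Intercomm_create(half, 0, MPI_COMM_WORLD,
                         (0 == color) ? wsize / 2 : 0, 1234, &inter);

    MPI_Comm_size(half, &lsize);
    MPI_Comm_remote_size(inter, &rsize);

    /* rcounts is sized by the local group: every rank receives one
     * element, so every rank contributes lsize elements, all set to 1. */
    int *rcounts = (int *) malloc(sizeof(int) * lsize);
    int *sbuf = (int *) malloc(sizeof(int) * lsize);
    for (i = 0; i < lsize; i++) {
        rcounts[i] = 1;
        sbuf[i] = 1;
    }

    MPI_Reduce_scatter(sbuf, &result, rcounts, MPI_INT, MPI_SUM, inter);

    /* The result scattered in one group is the reduction of the other
     * group's data, so every rank should now hold rsize. */
    if (result != rsize) {
        fprintf(stderr, "rank %d: got %d, expected %d\n", wrank, result, rsize);
        err = 1;
    }

    free(sbuf);
    free(rcounts);
    MPI_Comm_free(&inter);
    MPI_Comm_free(&half);
    MPI_Finalize();
    return err;
}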
@@ -370,40 +370,57 @@ mca_coll_basic_reduce_scatter_inter(void *sbuf, void *rbuf, int *rcounts,
                                     struct ompi_communicator_t *comm,
                                     mca_coll_base_module_t *module)
 {
-    int err, i, rank, root = 0, rsize;
+    int err, i, rank, root = 0, rsize, lsize;
     int totalcounts, tcount;
     ptrdiff_t lb, extent;
     char *tmpbuf = NULL, *tmpbuf2 = NULL, *tbuf = NULL;
     ompi_request_t *req;
     mca_coll_basic_module_t *basic_module = (mca_coll_basic_module_t*) module;
     ompi_request_t **reqs = basic_module->mccb_reqs;
+    int *disps = NULL;
 
     rank = ompi_comm_rank(comm);
     rsize = ompi_comm_remote_size(comm);
+    lsize = ompi_comm_size(comm);
 
-    /* According to MPI-2, the total sum of elements transfered has to
-     * be identical in both groups. Thus, it is enough to calculate
-     * that locally.
-     */
-    for (totalcounts = 0, i = 0; i < rsize; i++) {
+    /* Figure out the total amount of data for the reduction. */
+    for (totalcounts = 0, i = 0; i < lsize; i++) {
         totalcounts += rcounts[i];
     }
 
-    /* determine result of the remote group, you cannot
-     * use coll_reduce for inter-communicators, since than
-     * you would need to determine an order between the
-     * two groups (e.g. which group is providing the data
-     * and which one enters coll_reduce with providing
-     * MPI_PROC_NULL as root argument etc.) Here,
-     * we execute the data exchange for both groups
-     * simultaniously. */
     /*****************************************************************/
+    /*
+     * The following code basically does an interreduce followed by a
+     * intrascatterv.  This is implemented by having the roots of each
+     * group exchange their sbuf.  Then, the roots receive the data
+     * from each of the remote ranks and execute the reduce.  When
+     * this is complete, they have the reduced data available to them
+     * for doing the scatterv.  They do this on the local communicator
+     * associated with the intercommunicator.
+     *
+     * Note: There are other ways to implement MPI_Reduce_scatter on
+     * intercommunicators.  For example, one could do a MPI_Reduce locally,
+     * then send the results to the other root which could scatter it.
+     *
+     * Note: It is also worth pointing out that the rcounts argument
+     * represents how the data is going to be scatter locally.  Therefore,
+     * its size is the same as the local communicator size.
+     */
     if (rank == root) {
         err = ompi_ddt_get_extent(dtype, &lb, &extent);
         if (OMPI_SUCCESS != err) {
             return OMPI_ERROR;
         }
 
+        /* Generate displacements for the scatterv part */
+        disps = (int*) malloc(sizeof(int) * lsize);
+        if (NULL == disps) {
+            return OMPI_ERR_OUT_OF_RESOURCE;
+        }
+        disps[0] = 0;
+        for (i = 0; i < (lsize - 1); ++i) {
+            disps[i + 1] = disps[i] + rcounts[i];
+        }
+
         tmpbuf = (char *) malloc(totalcounts * extent);
         tmpbuf2 = (char *) malloc(totalcounts * extent);
         if (NULL == tmpbuf || NULL == tmpbuf2) {
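The displacement generation added above is just an exclusive prefix sum of the receive counts. As a standalone sketch (the function name is local to this example, not the component's):

/* Standalone sketch of the displacement generation: scatterv
 * displacements are the exclusive prefix sum of the receive counts,
 * e.g. rcounts {2, 1, 3} yields disps {0, 2, 3}. */
#include <stdlib.h>

static int *build_disps(const int *rcounts, int lsize)
{
    int i;
    int *disps = (int *) malloc(sizeof(int) * lsize);
    if (NULL == disps) {
        return NULL;   /* caller maps this to an out-of-resource error */
    }
    disps[0] = 0;
    for (i = 0; i < lsize - 1; ++i) {
        disps[i + 1] = disps[i] + rcounts[i];
    }
    return disps;
}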
@@ -456,72 +473,13 @@ mca_coll_basic_reduce_scatter_inter(void *sbuf, void *rbuf, int *rcounts,
         }
     }
 
-    /* now we have on one process the result of the remote group. To distribute
-     * the data to all processes in the local group, we exchange the data between
-     * the two root processes. They then send it to every other process in the
-     * remote group.
-     */
-    /***************************************************************************/
-    if (rank == root) {
-        /* sendrecv between the two roots */
-        err = MCA_PML_CALL(irecv(tmpbuf, totalcounts, dtype, 0,
-                                 MCA_COLL_BASE_TAG_REDUCE_SCATTER,
-                                 comm, &req));
-        if (OMPI_SUCCESS != err) {
-            goto exit;
-        }
-
-        err = MCA_PML_CALL(send(tmpbuf2, totalcounts, dtype, 0,
-                                MCA_COLL_BASE_TAG_REDUCE_SCATTER,
-                                MCA_PML_BASE_SEND_STANDARD, comm));
-        if (OMPI_SUCCESS != err) {
-            goto exit;
-        }
-
-        err = ompi_request_wait_all(1, &req, MPI_STATUS_IGNORE);
-        if (OMPI_SUCCESS != err) {
-            goto exit;
-        }
-
-        /* distribute the data to other processes in remote group.
-         * Note that we start from 1 (not from zero), since zero
-         * has already the correct data AND we avoid a potential
-         * deadlock here.
-         */
-        err = MCA_PML_CALL(irecv(rbuf, rcounts[rank], dtype, root,
-                                 MCA_COLL_BASE_TAG_REDUCE_SCATTER,
-                                 comm, &req));
-
-        tcount = 0;
-        for (i = 0; i < rsize; i++) {
-            tbuf = (char *) tmpbuf + tcount * extent;
-            err = MCA_PML_CALL(isend(tbuf, rcounts[i], dtype, i,
-                                     MCA_COLL_BASE_TAG_REDUCE_SCATTER,
-                                     MCA_PML_BASE_SEND_STANDARD, comm,
-                                     reqs++));
-            if (OMPI_SUCCESS != err) {
-                goto exit;
-            }
-            tcount += rcounts[i];
-        }
-
-        err =
-            ompi_request_wait_all(rsize,
-                                  basic_module->mccb_reqs,
-                                  MPI_STATUSES_IGNORE);
-        if (OMPI_SUCCESS != err) {
-            goto exit;
-        }
-
-        err = ompi_request_wait_all(1, &req, MPI_STATUS_IGNORE);
-        if (OMPI_SUCCESS != err) {
-            goto exit;
-        }
-    } else {
-        err = MCA_PML_CALL(recv(rbuf, rcounts[rank], dtype, root,
-                                MCA_COLL_BASE_TAG_REDUCE_SCATTER,
-                                comm, MPI_STATUS_IGNORE));
-    }
+    /* Now do a scatterv on the local communicator */
+    err = comm->c_local_comm->c_coll.coll_scatterv(tmpbuf2, rcounts, disps, dtype,
+                                rbuf, rcounts[rank], dtype, 0,
+                                comm->c_local_comm,
+                                comm->c_local_comm->c_coll.coll_scatterv_module);
+    if (OMPI_SUCCESS != err) {
+        goto exit;
+    }
 
  exit:
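The hunk above swaps the root's hand-rolled isend loop (and the matching recv on every other rank) for a single collective on the intracommunicator underlying the intercomm. In plain MPI terms the step looks like the sketch below, where local_comm stands in for comm->c_local_comm; this is an illustration, not the component code:

#include <mpi.h>

/* Sketch of the new distribution step: the root of each group scatters
 * the reduced data across its own group.  The send arguments (reduced,
 * rcounts, disps) are significant at the root only; every other rank
 * just supplies its receive buffer. */
static int scatter_reduced(void *reduced, int *rcounts, int *disps,
                           MPI_Datatype dtype, void *rbuf,
                           MPI_Comm local_comm)
{
    int my_rank;
    MPI_Comm_rank(local_comm, &my_rank);
    return MPI_Scatterv(reduced, rcounts, disps, dtype,
                        rbuf, rcounts[my_rank], dtype,
                        0 /* root */, local_comm);
}

Because the scatterv is collective over the local group, it also replaces the old irecv into rbuf on the root itself, which is why the whole if/else block could go.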
@@ -533,5 +491,9 @@ mca_coll_basic_reduce_scatter_inter(void *sbuf, void *rbuf, int *rcounts,
         free(tmpbuf2);
     }
 
+    if (NULL != disps) {
+        free(disps);
+    }
+
     return err;
 }