
Merge pull request #2061 from hjelmn/cid_inter

comm/cid: use ibcast to distribute result in intercomm case
This commit is contained in:
Nathan Hjelm 2016-09-07 16:36:00 -06:00 committed by GitHub
parent fd829ac389 54cc829aab
commit 63d73a5dd0
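In MPI terms, the change swaps one distribution scheme for another: instead of an iallreduce over the intercommunicator followed by an iallgatherv, each group now ireduces to its local rank 0, the two leaders exchange and combine their partial results, and each leader ibcasts the final value within its own group. Below is a minimal blocking sketch of the new pipeline, written against the public MPI API rather than Open MPI's internal c_coll interfaces; the explicit localcomm parameter stands in for intercomm->c_local_comm and is an assumption of this sketch.

#include <stdlib.h>
#include <mpi.h>

/* Blocking sketch of the commit's scheme: local reduce -> leader
 * exchange -> local bcast. Not Open MPI's actual (non-blocking,
 * callback-chained) implementation. */
static int intercomm_allreduce_sketch (const int *inbuf, int *outbuf, int count,
                                       MPI_Op op, MPI_Comm localcomm, MPI_Comm intercomm)
{
    int local_rank, *tmpbuf = NULL;

    MPI_Comm_rank (localcomm, &local_rank);

    /* only the local leader needs scratch space for the group result */
    if (0 == local_rank) {
        tmpbuf = (int *) calloc (count, sizeof (int));
        if (NULL == tmpbuf) {
            return MPI_ERR_NO_MEM;
        }
    }

    /* step 1: reduce each group's contributions to its local rank 0 */
    MPI_Reduce (inbuf, tmpbuf, count, MPI_INT, op, 0, localcomm);

    if (0 == local_rank) {
        /* step 2: the two leaders swap partial results across the
         * intercommunicator (peer rank 0 in the remote group) and
         * combine them, so both leaders end up holding the identical
         * global result */
        MPI_Sendrecv (tmpbuf, count, MPI_INT, 0, 0,
                      outbuf, count, MPI_INT, 0, 0,
                      intercomm, MPI_STATUS_IGNORE);
        MPI_Reduce_local (tmpbuf, outbuf, count, MPI_INT, op);
        free (tmpbuf);
    }

    /* step 3: each leader broadcasts the result inside its own group */
    return MPI_Bcast (outbuf, count, MPI_INT, 0, localcomm);
}

The real code runs these three steps as chained non-blocking collectives hanging off a communicator request, so CID allocation never blocks inside the runtime.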


@@ -103,10 +103,6 @@ struct ompi_comm_allreduce_context_t {
     ompi_comm_cid_context_t *cid_context;
     int *tmpbuf;
 
-    /* for intercomm allreduce */
-    int *rcounts;
-    int *rdisps;
-
     /* for group allreduce */
     int peers_comm[3];
 };
@@ -121,8 +117,6 @@ static void ompi_comm_allreduce_context_construct (ompi_comm_allreduce_context_t
 static void ompi_comm_allreduce_context_destruct (ompi_comm_allreduce_context_t *context)
 {
     free (context->tmpbuf);
-    free (context->rcounts);
-    free (context->rdisps);
 }
 
 OBJ_CLASS_INSTANCE (ompi_comm_allreduce_context_t, opal_object_t,
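A side effect of the new allocation policy (tmpbuf is calloc'd only on the local leader, see the hunks below) is that the slimmed-down destructor stays correct on every rank: free(NULL) is defined as a no-op in C, so unconditionally freeing a pointer is safe as long as the constructor NULL-initializes it. Whether ompi_comm_allreduce_context_construct zeroes the pointers is not visible in this hunk, so treat that as an assumption. A tiny stand-alone analogue of the construct/destruct pairing, in plain C rather than OPAL's object system:

#include <stdlib.h>
#include <string.h>

typedef struct allreduce_ctx {
    int *tmpbuf;    /* allocated lazily, and possibly never */
} allreduce_ctx_t;

static void ctx_construct (allreduce_ctx_t *ctx)
{
    memset (ctx, 0, sizeof (*ctx));   /* tmpbuf starts out NULL */
}

static void ctx_destruct (allreduce_ctx_t *ctx)
{
    free (ctx->tmpbuf);   /* no-op when tmpbuf was never allocated */
}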
@@ -602,7 +596,7 @@ static int ompi_comm_allreduce_intra_nb (int *inbuf, int *outbuf, int count, str
 /* Non-blocking version of ompi_comm_allreduce_inter */
 static int ompi_comm_allreduce_inter_leader_exchange (ompi_comm_request_t *request);
 static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request);
-static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t *request);
+static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t *request);
 
 static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
                                          int count, struct ompi_op_t *op,
@@ -636,18 +630,19 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
     rsize = ompi_comm_remote_size (intercomm);
     local_rank = ompi_comm_rank (intercomm);
 
-    context->tmpbuf = (int *) calloc (count, sizeof(int));
-    context->rdisps = (int *) calloc (rsize, sizeof(int));
-    context->rcounts = (int *) calloc (rsize, sizeof(int));
-    if (OPAL_UNLIKELY (NULL == context->tmpbuf || NULL == context->rdisps || NULL == context->rcounts)) {
-        ompi_comm_request_return (request);
-        return OMPI_ERR_OUT_OF_RESOURCE;
+    if (0 == local_rank) {
+        context->tmpbuf = (int *) calloc (count, sizeof(int));
+        if (OPAL_UNLIKELY (NULL == context->tmpbuf)) {
+            ompi_comm_request_return (request);
+            return OMPI_ERR_OUT_OF_RESOURCE;
+        }
     }
 
     /* Execute the inter-allreduce: the result from the local will be in the buffer of the remote group
      * and vise-versa. */
-    rc = intercomm->c_coll.coll_iallreduce (inbuf, context->tmpbuf, count, MPI_INT, op, intercomm,
-                                            &subreq, intercomm->c_coll.coll_iallreduce_module);
+    rc = intercomm->c_local_comm->c_coll.coll_ireduce (inbuf, context->tmpbuf, count, MPI_INT, op, 0,
+                                                       intercomm->c_local_comm, &subreq,
+                                                       intercomm->c_local_comm->c_coll.coll_ireduce_module);
     if (OPAL_UNLIKELY(OMPI_SUCCESS != rc)) {
         ompi_comm_request_return (request);
         return rc;
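Allocating tmpbuf only when local_rank is 0 works because a reduction delivers data only at the root: MPI defines the receive buffer as significant only there, so non-root ranks may pass NULL, and Open MPI's internal ireduce follows the same convention. The same rule expressed with the public API, where localcomm is an assumed intracommunicator:

#include <stdlib.h>
#include <mpi.h>

/* recvbuf is significant only at the root of a reduction, so non-root
 * ranks may pass NULL -- the rule the conditional calloc above relies on. */
static int sum_at_root (int contribution, MPI_Comm localcomm)
{
    int rank, *sum = NULL;

    MPI_Comm_rank (localcomm, &rank);
    if (0 == rank) {
        sum = (int *) malloc (sizeof (int));
    }
    MPI_Reduce (&contribution, sum, 1, MPI_INT, MPI_SUM, 0, localcomm);

    int result = (0 == rank) ? *sum : 0;
    free (sum);   /* free(NULL) is a no-op on non-roots */
    return result;
}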
@@ -656,7 +651,7 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
     if (0 == local_rank) {
         ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_leader_exchange, &subreq, 1);
     } else {
-        ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_allgather, &subreq, 1);
+        ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_bcast, &subreq, 1);
     }
 
     ompi_comm_request_start (request);
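ompi_comm_request_schedule_append registers the function to run once the pending subrequest completes, so each rank picks its continuation up front: the local leader proceeds to the leader exchange, while every other rank waits directly for the broadcast step. A rough, self-contained sketch of such completion-driven chaining using plain MPI requests; the step and driver types here are hypothetical and not Open MPI's request machinery:

#include <stddef.h>
#include <mpi.h>

/* One pending operation plus the callback to run when it completes.
 * The callback may start another operation and name the next step. */
struct step;
typedef int (*step_cb) (struct step *s);
struct step {
    MPI_Request req;
    step_cb     cb;
};

/* Drive the chain to completion: wait, run the continuation, repeat.
 * Open MPI instead polls such requests from its progress loop so
 * nothing blocks; MPI_Wait keeps the sketch simple. */
static int run_chain (struct step *s)
{
    while (NULL != s->cb) {
        step_cb cb = s->cb;
        s->cb = NULL;                      /* consumed unless cb re-arms */
        MPI_Wait (&s->req, MPI_STATUS_IGNORE);
        int rc = cb (s);                   /* may set s->req and s->cb again */
        if (MPI_SUCCESS != rc) {
            return rc;
        }
    }
    return MPI_SUCCESS;
}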
@@ -696,33 +691,20 @@ static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request
     ompi_op_reduce (context->op, context->tmpbuf, context->outbuf, context->count, MPI_INT);
 
-    return ompi_comm_allreduce_inter_allgather (request);
+    return ompi_comm_allreduce_inter_bcast (request);
 }
 
-static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t *request)
+static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t *request)
 {
     ompi_comm_allreduce_context_t *context = (ompi_comm_allreduce_context_t *) request->context;
-    ompi_communicator_t *intercomm = context->cid_context->comm;
+    ompi_communicator_t *comm = context->cid_context->comm->c_local_comm;
     ompi_request_t *subreq;
-    int scount = 0, rc;
+    int rc;
 
-    /* distribute the overall result to all processes in the other group.
-       Instead of using bcast, we are using here allgatherv, to avoid the
-       possible deadlock. Else, we need an algorithm to determine,
-       which group sends first in the inter-bcast and which receives
-       the result first.
-    */
-    if (0 != ompi_comm_rank (intercomm)) {
-        context->rcounts[0] = context->count;
-    } else {
-        scount = context->count;
-    }
-
-    rc = intercomm->c_coll.coll_iallgatherv (context->outbuf, scount, MPI_INT, context->outbuf,
-                                             context->rcounts, context->rdisps, MPI_INT, intercomm,
-                                             &subreq, intercomm->c_coll.coll_iallgatherv_module);
+    /* both roots have the same result. broadcast to the local group */
+    rc = comm->c_coll.coll_ibcast (context->outbuf, context->count, MPI_INT, 0, comm,
+                                   &subreq, comm->c_coll.coll_ibcast_module);
     if (OMPI_SUCCESS != rc) {
         return rc;
     }
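For contrast, the retired path emulated a bidirectional inter-group broadcast with a single allgatherv over the intercommunicator: each group's rank 0 sent count elements, everyone else sent nothing, and non-roots expected count elements from the remote root only. That sidestepped the question of which group's broadcast would send first, at the cost of the rcounts/rdisps bookkeeping removed above. A blocking sketch of that pattern with the public API; buffer handling is illustrative:

#include <stdlib.h>
#include <mpi.h>

/* Retired scheme: both rank-0 leaders already hold the identical
 * result; one MPI_Allgatherv over the intercomm hands it to every
 * non-root in the opposite group. */
static int cross_bcast_via_allgatherv (int *result, int count, MPI_Comm intercomm)
{
    int local_rank, rsize, scount = 0;

    MPI_Comm_rank (intercomm, &local_rank);   /* rank within the local group */
    MPI_Comm_remote_size (intercomm, &rsize);

    int *rcounts = (int *) calloc (rsize, sizeof (int));
    int *rdisps  = (int *) calloc (rsize, sizeof (int));
    if (NULL == rcounts || NULL == rdisps) {
        free (rcounts);
        free (rdisps);
        return MPI_ERR_NO_MEM;
    }

    if (0 != local_rank) {
        rcounts[0] = count;   /* receive only from the remote root */
    } else {
        scount = count;       /* only the root contributes data */
    }

    /* the zero counts guarantee send and receive never overlap in
     * `result`, which is why the original code reused one buffer */
    int rc = MPI_Allgatherv (result, scount, MPI_INT,
                             result, rcounts, rdisps, MPI_INT, intercomm);
    free (rcounts);
    free (rdisps);
    return rc;
}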