diff --git a/oshmem/mca/scoll/mpi/scoll_mpi_ops.c b/oshmem/mca/scoll/mpi/scoll_mpi_ops.c index eb03dfec2d..2aa87a0222 100644 --- a/oshmem/mca/scoll/mpi/scoll_mpi_ops.c +++ b/oshmem/mca/scoll/mpi/scoll_mpi_ops.c @@ -107,16 +107,18 @@ int mca_scoll_mpi_collect(struct oshmem_group_t *group, bool nlong_type, int alg) { + ompi_datatype_t* stype = &ompi_mpi_char.dt; + ompi_datatype_t* rtype = &ompi_mpi_char.dt; mca_scoll_mpi_module_t *mpi_module; - ompi_datatype_t* stype; - ompi_datatype_t* rtype; int rc; + int len; + int i; void *sbuf, *rbuf; + int *disps, *recvcounts; MPI_COLL_VERBOSE(20,"RUNNING MPI ALLGATHER"); mpi_module = (mca_scoll_mpi_module_t *) group->g_scoll.scoll_collect_module; if (nlong_type == true) { - /* Do nothing on zero-length request */ if (OPAL_UNLIKELY(!nlong)) { return OSHMEM_SUCCESS; @@ -124,8 +126,6 @@ int mca_scoll_mpi_collect(struct oshmem_group_t *group, sbuf = (void *) source; rbuf = target; - stype = &ompi_mpi_char.dt; - rtype = &ompi_mpi_char.dt; /* Open SHMEM specification has the following constrains (page 85): * "If using C/C++, nelems must be of type integer. If you are using Fortran, it must be a * default integer value". And also fortran signature says "INTEGER". @@ -159,15 +159,52 @@ int mca_scoll_mpi_collect(struct oshmem_group_t *group, SCOLL_DEFAULT_ALG); } } else { - MPI_COLL_VERBOSE(20,"RUNNING FALLBACK COLLECT"); - PREVIOUS_SCOLL_FN(mpi_module, collect, group, - target, - source, - nlong, - pSync, - nlong_type, - SCOLL_DEFAULT_ALG); + if (INT_MAX < nlong) { + MPI_COLL_VERBOSE(20,"RUNNING FALLBACK COLLECT"); + PREVIOUS_SCOLL_FN(mpi_module, collect, group, + target, + source, + nlong, + pSync, + nlong_type, + SCOLL_DEFAULT_ALG); + return rc; + } + + len = nlong; + disps = malloc(group->proc_count * sizeof(*disps)); + if (disps == NULL) { + rc = OSHMEM_ERR_OUT_OF_RESOURCE; + goto complete; + } + + recvcounts = malloc(group->proc_count * sizeof(*recvcounts)); + if (recvcounts == NULL) { + rc = OSHMEM_ERR_OUT_OF_RESOURCE; + goto failed_mem; + } + + rc = mpi_module->comm->c_coll->coll_allgather(&len, sizeof(len), stype, recvcounts, + sizeof(len), rtype, mpi_module->comm, + mpi_module->comm->c_coll->coll_allgather_module); + if (rc != OSHMEM_SUCCESS) { + goto failed_allgather; + } + + disps[0] = 0; + for (i = 1; i < group->proc_count; i++) { + disps[i] = disps[i - 1] + recvcounts[i - 1]; + } + + rc = mpi_module->comm->c_coll->coll_allgatherv(source, nlong, stype, target, recvcounts, + disps, rtype, mpi_module->comm, + mpi_module->comm->c_coll->coll_allgatherv_module); +failed_allgather: + free(recvcounts); +failed_mem: + free(disps); } +complete: return rc; }