From 19803d660548baaa45e09732ee490736d430e2e9 Mon Sep 17 00:00:00 2001 From: Oscar Vega-Gisbert Date: Sat, 19 Apr 2014 11:12:38 +0000 Subject: [PATCH] Java - neighborhood collective communication: get buffers according topology information This commit was SVN r31452. --- ompi/mpi/java/c/mpi_Comm.c | 70 +++++++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/ompi/mpi/java/c/mpi_Comm.c b/ompi/mpi/java/c/mpi_Comm.c index 80ca0d09c5..94ae161080 100644 --- a/ompi/mpi/java/c/mpi_Comm.c +++ b/ompi/mpi/java/c/mpi_Comm.c @@ -119,6 +119,39 @@ static int getRank(JNIEnv *env, MPI_Comm comm) return rank; } +static int getTopo(JNIEnv *env, MPI_Comm comm) +{ + int rc, status; + rc = MPI_Topo_test(comm, &status); + ompi_java_exceptionCheck(env, rc); + return status; +} + +static void getNeighbors(JNIEnv *env, MPI_Comm comm, int *out, int *in) +{ + int rc, weighted; + + switch(getTopo(env, comm)) + { + case MPI_CART: + rc = MPI_Cartdim_get(comm, in); + *in *= 2; + *out = *in; + break; + case MPI_GRAPH: + rc = MPI_Graph_neighbors_count(comm, getRank(env, comm), in); + *out = *in; + break; + case MPI_DIST_GRAPH: + rc = MPI_Dist_graph_neighbors_count(comm, in, out, &weighted); + break; + default: + assert(0); + } + + ompi_java_exceptionCheck(env, rc); +} + static int getSum(int *counts, int size) { int i, s = 0; @@ -624,10 +657,7 @@ JNIEXPORT void JNICALL Java_mpi_Comm_probe( JNIEXPORT jint JNICALL Java_mpi_Comm_getTopology( JNIEnv *env, jobject jthis, jlong comm) { - int rc, status; - rc = MPI_Topo_test((MPI_Comm)comm, &status); - ompi_java_exceptionCheck(env, rc); - return status; + return getTopo(env, (MPI_Comm)comm); } JNIEXPORT void JNICALL Java_mpi_Comm_abort( @@ -1559,9 +1589,9 @@ JNIEXPORT void JNICALL Java_mpi_Comm_neighborAllGather( MPI_Datatype sType = (MPI_Datatype)sjType; MPI_Datatype rType = (MPI_Datatype)rjType; - int inter = isInter(env, comm), - size = getSize(env, comm, inter), - rTotal = rCount * size; + int sSize, rSize; + getNeighbors(env, comm, &sSize, &rSize); + int rTotal = rCount * rSize; void *sPtr, *rPtr; ompi_java_buffer_t *sItem, *rItem; @@ -1608,8 +1638,8 @@ JNIEXPORT void JNICALL Java_mpi_Comm_neighborAllGatherv( MPI_Datatype sType = (MPI_Datatype)sjType; MPI_Datatype rType = (MPI_Datatype)rjType; - int inter = isInter(env, comm), - size = getSize(env, comm, inter); + int sSize, rSize; + getNeighbors(env, comm, &sSize, &rSize); jint *jRCount, *jDispls; int *cRCount, *cDispls; @@ -1621,7 +1651,7 @@ JNIEXPORT void JNICALL Java_mpi_Comm_neighborAllGatherv( ompi_java_getReadPtr(&sPtr,&sItem,env,sBuf,sdb,sOff,sCount,sType,sBType); ompi_java_getWritePtrv(&rPtr, &rItem, env, rBuf, rdb, - cRCount, cDispls, size, rType); + cRCount, cDispls, rSize, rType); int rc = MPI_Neighbor_allgatherv( sPtr, sCount, sType, rPtr, cRCount, cDispls, rType, comm); @@ -1630,7 +1660,7 @@ JNIEXPORT void JNICALL Java_mpi_Comm_neighborAllGatherv( ompi_java_releaseReadPtr(sPtr, sItem, sBuf, sdb); ompi_java_releaseWritePtrv(rPtr, rItem, env, rBuf, rdb, rOff, - cRCount, cDispls, size, rType, rBType); + cRCount, cDispls, rSize, rType, rBType); ompi_java_forgetIntArray(env, rCount, jRCount, cRCount); ompi_java_forgetIntArray(env, displs, jDispls, cDispls); @@ -1675,10 +1705,10 @@ JNIEXPORT void JNICALL Java_mpi_Comm_neighborAllToAll( MPI_Datatype sType = (MPI_Datatype)sjType; MPI_Datatype rType = (MPI_Datatype)rjType; - int inter = isInter(env, comm), - size = getSize(env, comm, inter), - sTotal = sCount * size, - rTotal = rCount * size; + int sSize, rSize; + getNeighbors(env, comm, &sSize, &rSize); + int sTotal = sCount * sSize; + int rTotal = rCount * rSize; void *sPtr, *rPtr; ompi_java_buffer_t *sItem, *rItem; @@ -1725,8 +1755,8 @@ JNIEXPORT void JNICALL Java_mpi_Comm_neighborAllToAllv( MPI_Datatype sType = (MPI_Datatype)sjType; MPI_Datatype rType = (MPI_Datatype)rjType; - int inter = isInter(env, comm), - size = getSize(env, comm, inter); + int sSize, rSize; + getNeighbors(env, comm, &sSize, &rSize); jint *jSCount, *jRCount, *jSDispl, *jRDispl; int *cSCount, *cRCount, *cSDispl, *cRDispl; @@ -1739,9 +1769,9 @@ JNIEXPORT void JNICALL Java_mpi_Comm_neighborAllToAllv( ompi_java_buffer_t *sItem, *rItem; ompi_java_getReadPtrv(&sPtr, &sItem, env, sBuf, sdb, sOff, - cSCount, cSDispl, size, -1, sType, sBType); + cSCount, cSDispl, sSize, -1, sType, sBType); ompi_java_getWritePtrv(&rPtr, &rItem, env, rBuf, rdb, - cRCount, cRDispl, size, rType); + cRCount, cRDispl, rSize, rType); int rc = MPI_Neighbor_alltoallv(sPtr, cSCount, cSDispl, sType, rPtr, cRCount, cRDispl, rType, comm); @@ -1750,7 +1780,7 @@ JNIEXPORT void JNICALL Java_mpi_Comm_neighborAllToAllv( ompi_java_releaseReadPtr(sPtr, sItem, sBuf, sdb); ompi_java_releaseWritePtrv(rPtr, rItem, env, rBuf, rdb, rOff, - cRCount, cRDispl, size, rType, rBType); + cRCount, cRDispl, rSize, rType, rBType); ompi_java_forgetIntArray(env, sCount, jSCount, cSCount); ompi_java_forgetIntArray(env, rCount, jRCount, cRCount);