From 7ece56497827178002c94e30dfa3c415aa40b0eb Mon Sep 17 00:00:00 2001 From: "Kurita, Takehiro" Date: Tue, 21 May 2019 15:31:56 +0900 Subject: [PATCH] java: Fix compilation error in allToAllw using Java arrays Java bindings in Open MPI support Java arrays and direct buffers as buffers. All non-blocking methods must use direct buffers and only blocking methods can choose between Java arrays and direct buffers. Though Comm.allToAllw() is a blocking method, Java applications using Java arrays as buffers get compilation errors. This fix enables using Java arrays in Comm.allToAllw(). Signed-off-by: Kurita, Takehiro --- ompi/mpi/java/c/mpiJava.h | 24 +++++ ompi/mpi/java/c/mpi_Comm.c | 82 ++++++++++------ ompi/mpi/java/c/mpi_MPI.c | 182 +++++++++++++++++++++++++++++++++++ ompi/mpi/java/java/Comm.java | 62 ++++++++++-- 4 files changed, 314 insertions(+), 36 deletions(-) diff --git a/ompi/mpi/java/c/mpiJava.h b/ompi/mpi/java/c/mpiJava.h index 6f20cf943b..319536e22d 100644 --- a/ompi/mpi/java/c/mpiJava.h +++ b/ompi/mpi/java/c/mpiJava.h @@ -11,6 +11,7 @@ * All rights reserved. * Copyright (c) 2015 Los Alamos National Security, LLC. All rights * reserved. + * Copyright (c) 2019 FUJITSU LIMITED. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -94,6 +95,15 @@ void ompi_java_getReadPtrv( jobject buf, jboolean db, int off, int *counts, int *displs, int size, int rank, MPI_Datatype type, int baseType); +/* Gets a buffer pointer for reading, but it + * 'size' is the number of processes. + * if rank == -1 it copies all data from Java. + * if rank != -1 it only copies from Java the rank data. */ +void ompi_java_getReadPtrw( + void **ptr, ompi_java_buffer_t **item, JNIEnv *env, + jobject buf, jboolean db, int *offs, int *counts, int *displs, + int size, int rank, MPI_Datatype *types, int *baseTypes); + /* Releases a buffer used for reading. */ void ompi_java_releaseReadPtr( void *ptr, ompi_java_buffer_t *item, jobject buf, jboolean db); @@ -109,6 +119,12 @@ void ompi_java_getWritePtrv( void **ptr, ompi_java_buffer_t **item, JNIEnv *env, jobject buf, jboolean db, int *counts, int *displs, int size, MPI_Datatype type); +/* Gets a buffer pointer for writing. + * 'size' is the number of processes. */ +void ompi_java_getWritePtrw( + void **ptr, ompi_java_buffer_t **item, JNIEnv *env, jobject buf, + jboolean db, int *counts, int *displs, int size, MPI_Datatype *types); + /* Releases a buffer used for writing. * It copies data to Java. */ void ompi_java_releaseWritePtr( @@ -123,6 +139,14 @@ void ompi_java_releaseWritePtrv( jobject buf, jboolean db, int off, int *counts, int *displs, int size, MPI_Datatype type, int baseType); +/* Releases a buffer used for writing. + * It copies data to Java. + * 'size' is the number of processes. */ +void ompi_java_releaseWritePtrw( + void *ptr, ompi_java_buffer_t *item, JNIEnv *env, + jobject buf, jboolean db, int *offs, int *counts, int *displs, + int size, MPI_Datatype *types, int *baseTypes); + void ompi_java_setStaticLongField(JNIEnv *env, jclass c, char *field, jlong value); diff --git a/ompi/mpi/java/c/mpi_Comm.c b/ompi/mpi/java/c/mpi_Comm.c index fbf15c73b4..8dd0f23663 100644 --- a/ompi/mpi/java/c/mpi_Comm.c +++ b/ompi/mpi/java/c/mpi_Comm.c @@ -13,7 +13,7 @@ * and Technology (RIST). All rights reserved. * Copyright (c) 2016 Los Alamos National Security, LLC. All rights * reserved. - * Copyright (c) 2017 FUJITSU LIMITED. All rights reserved. + * Copyright (c) 2017-2019 FUJITSU LIMITED. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -1612,39 +1612,67 @@ JNIEXPORT jlong JNICALL Java_mpi_Comm_iAllToAllv( } JNIEXPORT void JNICALL Java_mpi_Comm_allToAllw( - JNIEnv *env, jobject jthis, jlong jComm, - jobject sendBuf, jintArray sCount, jintArray sDispls, jlongArray sTypes, - jobject recvBuf, jintArray rCount, jintArray rDispls, jlongArray rTypes) + JNIEnv *env, jobject jthis, jlong jComm, + jobject sBuf, jboolean sdb, jintArray sOffs, jintArray sCount, + jintArray sDispls, jlongArray sTypes, jintArray sBtypes, + jobject rBuf, jboolean rdb, jintArray rOffs, jintArray rCount, + jintArray rDispls, jlongArray rTypes, jintArray rBtypes) { - MPI_Comm comm = (MPI_Comm)jComm; + MPI_Comm comm = (MPI_Comm)jComm; - jlong* jSTypes, *jRTypes; - MPI_Datatype *cSTypes, *cRTypes; + int inter = isInter(env, comm), + size = getSize(env, comm, inter); - ompi_java_getDatatypeArray(env, sTypes, &jSTypes, &cSTypes); - ompi_java_getDatatypeArray(env, rTypes, &jRTypes, &cRTypes); + jlong* jSTypes, *jRTypes; + MPI_Datatype *cSTypes, *cRTypes; - jint *jSCount, *jRCount, *jSDispls, *jRDispls; - int *cSCount, *cRCount, *cSDispls, *cRDispls; - ompi_java_getIntArray(env, sCount, &jSCount, &cSCount); - ompi_java_getIntArray(env, rCount, &jRCount, &cRCount); - ompi_java_getIntArray(env, sDispls, &jSDispls, &cSDispls); - ompi_java_getIntArray(env, rDispls, &jRDispls, &cRDispls); + ompi_java_getDatatypeArray(env, sTypes, &jSTypes, &cSTypes); + ompi_java_getDatatypeArray(env, rTypes, &jRTypes, &cRTypes); - void *sPtr = ompi_java_getDirectBufferAddress(env, sendBuf), - *rPtr = ompi_java_getDirectBufferAddress(env, recvBuf); + jint *jSCount, *jRCount, *jSDispls, *jRDispls; + int *cSCount, *cRCount, *cSDispls, *cRDispls; + jint *jSBtypes, *jRBtypes; + int *cSBtypes, *cRBtypes; + jint *jSOffs, *jROffs; + int *cSOffs, *cROffs; - int rc = MPI_Alltoallw( - sPtr, cSCount, cSDispls, cSTypes, - rPtr, cRCount, cRDispls, cRTypes, comm); + ompi_java_getIntArray(env, sCount, &jSCount, &cSCount); + ompi_java_getIntArray(env, rCount, &jRCount, &cRCount); + ompi_java_getIntArray(env, sDispls, &jSDispls, &cSDispls); + ompi_java_getIntArray(env, rDispls, &jRDispls, &cRDispls); + ompi_java_getIntArray(env, sBtypes, &jSBtypes, &cSBtypes); + ompi_java_getIntArray(env, rBtypes, &jRBtypes, &cRBtypes); + ompi_java_getIntArray(env, sOffs, &jSOffs, &cSOffs); + ompi_java_getIntArray(env, rOffs, &jROffs, &cROffs); - ompi_java_exceptionCheck(env, rc); - ompi_java_forgetIntArray(env, sCount, jSCount, cSCount); - ompi_java_forgetIntArray(env, rCount, jRCount, cRCount); - ompi_java_forgetIntArray(env, sDispls, jSDispls, cSDispls); - ompi_java_forgetIntArray(env, rDispls, jRDispls, cRDispls); - ompi_java_forgetDatatypeArray(env, sTypes, jSTypes, cSTypes); - ompi_java_forgetDatatypeArray(env, rTypes, jRTypes, cRTypes); + void *sPtr, *rPtr; + ompi_java_buffer_t *sItem, *rItem; + + ompi_java_getReadPtrw(&sPtr, &sItem, env, sBuf, sdb, cSOffs, + cSCount, cSDispls, size, -1, cSTypes, cSBtypes); + ompi_java_getWritePtrw(&rPtr, &rItem, env, rBuf, rdb, + cRCount, cRDispls, size, cRTypes); + + int rc = MPI_Alltoallw(sPtr, cSCount, cSDispls, cSTypes, + rPtr, cRCount, cRDispls, cRTypes, comm); + + ompi_java_exceptionCheck(env, rc); + ompi_java_releaseReadPtr(sPtr, sItem, sBuf, sdb); + + ompi_java_releaseWritePtrw(rPtr, rItem, env, rBuf, rdb, cROffs, + cRCount, cRDispls, size, cRTypes, cRBtypes); + + ompi_java_exceptionCheck(env, rc); + ompi_java_forgetIntArray(env, sCount, jSCount, cSCount); + ompi_java_forgetIntArray(env, rCount, jRCount, cRCount); + ompi_java_forgetIntArray(env, sDispls, jSDispls, cSDispls); + ompi_java_forgetIntArray(env, rDispls, jRDispls, cRDispls); + ompi_java_forgetIntArray(env, sBtypes, jSBtypes, cSBtypes); + ompi_java_forgetIntArray(env, rBtypes, jRBtypes, cRBtypes); + ompi_java_forgetIntArray(env, sOffs, jSOffs, cSOffs); + ompi_java_forgetIntArray(env, rOffs, jROffs, cROffs); + ompi_java_forgetDatatypeArray(env, sTypes, jSTypes, cSTypes); + ompi_java_forgetDatatypeArray(env, rTypes, jRTypes, cRTypes); } JNIEXPORT jlong JNICALL Java_mpi_Comm_iAllToAllw( diff --git a/ompi/mpi/java/c/mpi_MPI.c b/ompi/mpi/java/c/mpi_MPI.c index a2d4c4e672..d4c20ae044 100644 --- a/ompi/mpi/java/c/mpi_MPI.c +++ b/ompi/mpi/java/c/mpi_MPI.c @@ -17,6 +17,7 @@ * Copyright (c) 2015 Research Organization for Information Science * and Technology (RIST). All rights reserved. * Copyright (c) 2016-2017 IBM Corporation. All rights reserved. + * Copyright (c) 2019 FUJITSU LIMITED. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -672,6 +673,39 @@ static void* getReadPtrvRank( return ptr; } +static void* getReadPtrwRank( + ompi_java_buffer_t **item, JNIEnv *env, jobject buf, + int *offsets, int *counts, int *displs, int size, + int rank, MPI_Datatype *types, int *baseTypes) +{ + int extent = getTypeExtent(env, types[rank]), + length = getCountv(counts, displs, size); + void *ptr = getBuffer(env, item, length); + int rootOff = offsets[rank] + displs[rank]; + + if(opal_datatype_is_contiguous_memory_layout(&types[rank]->super, counts[rank])) + { + int rootLength = extent * counts[rank]; + void *rootPtr = (char*)ptr + displs[rank]; + getArrayRegion(env, buf, baseTypes[rank], rootOff, rootLength, rootPtr); + } + else + { + void *inBuf, *inBase; + inBuf = ompi_java_getArrayCritical(&inBase, env, buf, rootOff); + + int rc = opal_datatype_copy_content_same_ddt( + &types[rank]->super, counts[rank], ptr, inBuf); + + ompi_java_exceptionCheck(env, + rc==OPAL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR); + + (*env)->ReleasePrimitiveArrayCritical(env, buf, inBase, JNI_ABORT); + } + + return ptr; +} + static void* getReadPtrvAll( ompi_java_buffer_t **item, JNIEnv *env, jobject buf, int offset, int *counts, int *displs, int size, @@ -716,6 +750,49 @@ static void* getReadPtrvAll( return ptr; } +static void* getReadPtrwAll( + ompi_java_buffer_t **item, JNIEnv *env, jobject buf, + int *offsets, int *counts, int *displs, int size, + MPI_Datatype *types, int *baseTypes) +{ + + int length = getCountv(counts, displs, size); + void *ptr = getBuffer(env, item, length); + + for(int i = 0; i < size; i++) + { + int extent = getTypeExtent(env, types[i]); + + if(opal_datatype_is_contiguous_memory_layout(&types[i]->super, 2)) + { + int iOff = offsets[i] + displs[i], + iLen = extent * counts[i]; + void *iPtr = (char*)ptr + displs[i]; + getArrayRegion(env, buf, baseTypes[i], iOff, iLen, iPtr); + } + else + { + void *bufPtr, *bufBase; + bufPtr = ompi_java_getArrayCritical(&bufBase, env, buf, offsets[i]); + + int iOff = displs[i]; + char *iBuf = iOff + (char*)bufPtr, + *iPtr = iOff + (char*)ptr; + + int rc = opal_datatype_copy_content_same_ddt( + &types[i]->super, counts[i], iPtr, iBuf); + + ompi_java_exceptionCheck(env, + rc==OPAL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR); + + (*env)->ReleasePrimitiveArrayCritical(env, buf, bufBase, JNI_ABORT); + } + + } + + return ptr; +} + static void* getWritePtr(ompi_java_buffer_t **item, JNIEnv *env, int count, MPI_Datatype type) { @@ -735,6 +812,14 @@ static void* getWritePtrv(ompi_java_buffer_t **item, JNIEnv *env, return getBuffer(env, item, length); } +static void* getWritePtrw(ompi_java_buffer_t **item, JNIEnv *env, + int *counts, int *displs, int size, MPI_Datatype *types) +{ + int length = getCountv(counts, displs, size); + + return getBuffer(env, item, length); +} + void ompi_java_getReadPtr( void **ptr, ompi_java_buffer_t **item, JNIEnv *env, jobject buf, jboolean db, int offset, int count, MPI_Datatype type, int baseType) @@ -810,6 +895,39 @@ void ompi_java_getReadPtrv( } } +void ompi_java_getReadPtrw( + void **ptr, ompi_java_buffer_t **item, JNIEnv *env, + jobject buf, jboolean db, int *offsets, int *counts, int *displs, + int size, int rank, MPI_Datatype *types, int *baseTypes) +{ + int i; + + if(buf == NULL) + { + /* Allow NULL buffers to send/recv 0 items as control messages. */ + *ptr = NULL; + *item = NULL; + } + else if(db) + { + for(i = 0; i < size; i++){ + assert(offsets[i] == 0); + } + *ptr = (*env)->GetDirectBufferAddress(env, buf); + *item = NULL; + } + else if(rank == -1) + { + *ptr = getReadPtrwAll(item, env, buf, offsets, counts, + displs, size, types, baseTypes); + } + else + { + *ptr = getReadPtrwRank(item, env, buf, offsets, counts, + displs, size, rank, types, baseTypes); + } +} + void ompi_java_releaseReadPtr( void *ptr, ompi_java_buffer_t *item, jobject buf, jboolean db) { @@ -859,6 +977,27 @@ void ompi_java_getWritePtrv( } } +void ompi_java_getWritePtrw( + void **ptr, ompi_java_buffer_t **item, JNIEnv *env, jobject buf, + jboolean db, int *counts, int *displs, int size, MPI_Datatype *types) +{ + if(buf == NULL) + { + /* Allow NULL buffers to send/recv 0 items as control messages. */ + *ptr = NULL; + *item = NULL; + } + else if(db) + { + *ptr = (*env)->GetDirectBufferAddress(env, buf); + *item = NULL; + } + else + { + *ptr = getWritePtrw(item, env, counts, displs, size, types); + } +} + void ompi_java_releaseWritePtr( void *ptr, ompi_java_buffer_t *item, JNIEnv *env, jobject buf, jboolean db, int offset, int count, MPI_Datatype type, int baseType) @@ -933,6 +1072,49 @@ void ompi_java_releaseWritePtrv( releaseBuffer(ptr, item); } +void ompi_java_releaseWritePtrw( + void *ptr, ompi_java_buffer_t *item, JNIEnv *env, + jobject buf, jboolean db, int *offsets, int *counts, int *displs, + int size, MPI_Datatype *types, int *baseTypes) +{ + if(db || !buf || !ptr) + return; + + int i; + + for(i = 0; i < size; i++) + { + int extent = getTypeExtent(env, types[i]); + + if(opal_datatype_is_contiguous_memory_layout(&types[i]->super, 2)) + { + int iOff = offsets[i] + displs[i], + iLen = extent * counts[i]; + void *iPtr = (char*)ptr + displs[i]; + setArrayRegion(env, buf, baseTypes[i], iOff, iLen, iPtr); + } + else + { + void *bufPtr, *bufBase; + + bufPtr = ompi_java_getArrayCritical(&bufBase, env, buf, offsets[i]); + int iOff = displs[i]; + char *iBuf = iOff + (char*)bufPtr, + *iPtr = iOff + (char*)ptr; + + int rc = opal_datatype_copy_content_same_ddt( + &types[i]->super, counts[i], iBuf, iPtr); + + ompi_java_exceptionCheck(env, + rc==OPAL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR); + + (*env)->ReleasePrimitiveArrayCritical(env, buf, bufBase, 0); + } + + } + releaseBuffer(ptr, item); +} + jobject ompi_java_Integer_valueOf(JNIEnv *env, jint i) { return (*env)->CallStaticObjectMethod(env, diff --git a/ompi/mpi/java/java/Comm.java b/ompi/mpi/java/java/Comm.java index f51c28c798..3d973163d9 100644 --- a/ompi/mpi/java/java/Comm.java +++ b/ompi/mpi/java/java/Comm.java @@ -13,7 +13,7 @@ * and Technology (RIST). All rights reserved. * Copyright (c) 2015 Los Alamos National Security, LLC. All rights * reserved. - * Copyright (c) 2017-2018 FUJITSU LIMITED. All rights reserved. + * Copyright (c) 2017-2019 FUJITSU LIMITED. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -2400,24 +2400,51 @@ public class Comm implements Freeable, Cloneable * @throws MPIException Signals that an MPI exception of some sort has occurred. */ public final void allToAllw( - Buffer sendBuf, int[] sendCount, int[] sDispls, Datatype[] sendTypes, - Buffer recvBuf, int[] recvCount, int[] rDispls, Datatype[] recvTypes) + Object sendBuf, int[] sendCount, int[] sDispls, Datatype[] sendTypes, + Object recvBuf, int[] recvCount, int[] rDispls, Datatype[] recvTypes) throws MPIException { MPI.check(); - assertDirectBuffer(sendBuf, recvBuf); + + int[] sendoffs = new int[sendTypes.length]; + int[] recvoffs = new int[recvTypes.length]; + + boolean sdb = false, + rdb = false; + + if(sendBuf instanceof Buffer && !(sdb = ((Buffer)sendBuf).isDirect())) + { + + for (int i = 0; i < sendTypes.length; i++){ + sendoffs[i] = sendTypes[i].getOffset(sendBuf); + } + sendBuf = ((Buffer)sendBuf).array(); + } + + if(recvBuf instanceof Buffer && !(rdb = ((Buffer)recvBuf).isDirect())) + { + for (int i = 0; i < recvTypes.length; i++){ + recvoffs[i] = recvTypes[i].getOffset(recvBuf); + } + recvBuf = ((Buffer)recvBuf).array(); + } long[] sendHandles = convertTypeArray(sendTypes); long[] recvHandles = convertTypeArray(recvTypes); + int[] sendHandles_btypes = convertTypeArrayBtype(sendTypes); + int[] recvHandles_btypes = convertTypeArrayBtype(recvTypes); - allToAllw(handle, sendBuf, sendCount, sDispls, - sendHandles, recvBuf, recvCount, rDispls, - recvHandles); + allToAllw(handle, sendBuf, sdb, sendoffs, sendCount, sDispls, + sendHandles, sendHandles_btypes, + recvBuf, rdb, recvoffs, recvCount, rDispls, + recvHandles, recvHandles_btypes); } private native void allToAllw(long comm, - Buffer sendBuf, int[] sendCount, int[] sDispls, long[] sendTypes, - Buffer recvBuf, int[] recvCount, int[] rDispls, long[] recvTypes) + Object sendBuf, boolean sdb, int[] sendOffsets, + int[] sendCount, int[] sDispls, long[] sendTypes, int[] sendBaseTypes, + Object recvBuf, boolean rdb, int[] recvOffsets, + int[] recvCount, int[] rDispls, long[] recvTypes, int[] recvBaseTypes) throws MPIException; /** @@ -3421,4 +3448,21 @@ public class Comm implements Freeable, Cloneable return lArray; } + /** + * A helper method to convert an array of Datatypes to + * an array of ints (basetypes). + * @param dArray Array of Datatypes + * @return converted basetypes + */ + private int[] convertTypeArrayBtype(Datatype[] dArray) { + int[] lArray = new int[dArray.length]; + + for(int i = 0; i < lArray.length; i++) { + if(dArray[i] != null) { + lArray[i] = dArray[i].baseType; + } + } + return lArray; + } + } // Comm