diff --git a/ompi/mpi/cxx/comm.h b/ompi/mpi/cxx/comm.h index d22d2c13cf..a94d5c9170 100644 --- a/ompi/mpi/cxx/comm.h +++ b/ompi/mpi/cxx/comm.h @@ -445,8 +445,6 @@ public: // JGS hmmm, these used by errhandler_intercept my_errhandler = (Errhandler*)0; } - static Op* current_op; - #endif }; diff --git a/ompi/mpi/cxx/comm_inln.h b/ompi/mpi/cxx/comm_inln.h index d4e52112b0..b69014ec10 100644 --- a/ompi/mpi/cxx/comm_inln.h +++ b/ompi/mpi/cxx/comm_inln.h @@ -436,18 +436,14 @@ MPI::Comm::Reduce(const void *sendbuf, void *recvbuf, int count, const MPI::Datatype & datatype, const MPI::Op& op, int root) const { - current_op = const_cast(&op); (void)MPI_Reduce(const_cast(sendbuf), recvbuf, count, datatype, op, root, mpi_comm); - current_op = (Op*)0; } inline void MPI::Comm::Allreduce(const void *sendbuf, void *recvbuf, int count, const MPI::Datatype & datatype, const MPI::Op& op) const { - current_op = const_cast(&op); (void)MPI_Allreduce (const_cast(sendbuf), recvbuf, count, datatype, op, mpi_comm); - current_op = (Op*)0; } inline void @@ -456,10 +452,8 @@ MPI::Comm::Reduce_scatter(const void *sendbuf, void *recvbuf, const MPI::Datatype & datatype, const MPI::Op& op) const { - current_op = const_cast(&op); (void)MPI_Reduce_scatter(const_cast(sendbuf), recvbuf, recvcounts, datatype, op, mpi_comm); - current_op = (Op*)0; } // diff --git a/ompi/mpi/cxx/intercepts.cc b/ompi/mpi/cxx/intercepts.cc index a5a195af15..f530e5cc8c 100644 --- a/ompi/mpi/cxx/intercepts.cc +++ b/ompi/mpi/cxx/intercepts.cc @@ -61,23 +61,123 @@ void ompi_mpi_cxx_errhandler_intercept(MPI_Comm *mpi_comm, int *err, ...) } } -MPI::Op* MPI::Comm::current_op; - +// This is a bit weird; bear with me. The user-supplied function for +// MPI::Op contains a C++ object reference. So it must be called from +// a C++-compiled function. However, libmpi does not contain any C++ +// code because there are portability and bootstrapping issues +// involved if someone tries to make a 100% C application link against +// a libmpi that contains C++ code. At a minimum, the user will have +// to use the C++ compiler to link. LA-MPI has shown that users don't +// want to do this (there are other problems, but this one is easy to +// cite). +// +// Hence, there are two problems when trying to invoke the user's +// callback funcion from an MPI::Op: +// +// 1. The MPI_Datatype that the C library has must be converted to an +// (MPI::Datatype) +// 2. The C++ callback function must then be called with a +// (MPI::Datatype&) +// +// Some relevant facts for the discussion: +// +// - The main engine for invoking Op callback functions is in libmpi +// (i.e., in C code). +// +// - The C++ bindings are a thin layer on top of the C bindings. +// +// - The C++ bindings are a separate library from the C bindings +// (libmpi_cxx.la). +// +// - As a direct result, the mpiCC wrapper compiler must generate a +// link order thus: "... -lmpi_cxx -lmpi ...", meaning that we cannot +// have a direct function call from the libmpi to libmpi_cxx. We can +// only do it by function pointer. +// +// So the problem remains -- how to invoke a C++ MPI::Op callback +// function (which only occurrs for user-defined datatypes, BTW) from +// within the C Op callback engine in libmpi? +// +// It is easy to cache a function pointer to the +// ompi_mpi_cxx_op_intercept() function on the MPI_Op (that is located +// in the libmpi_cxx library, and is therefore compiled with a C++ +// compiler). But the normal C callback MPI_User_function type +// signature is (void*, void*, int*, MPI_Datatype*) -- so if +// ompi_mpi_cxx_op_intercept() is invoked with these arguments, it has +// no way to deduce what the user-specified callback function is that +// is associated with the MPI::Op. +// +// One can easily imagine a scenario of caching the callback pointer +// of the current MPI::Op in a global variable somewhere, and when +// ompi_mpi_cxx_op_intercept() is invoked, simply use that global +// variable. This is unfortunately not thread safe. +// +// So what we do is as follows: +// +// 1. The C++ dispatch function ompi_mpi_cxx_op_intercept() is *not* +// of type (MPI_User_function*). More specifically, it takes an +// additional argument: a function pointer. its signature is (void*, +// void*, int*, MPI_Datatype*, MPI_Op*, MPI::User_function*). This +// last argument is the function pointer of the user callback function +// to be invoked. +// +// The careful reader will notice that it is impossible for the C Op +// dispatch code in libmpi to call this function properly because the +// last argument is of a type that is not defined in libmpi (i.e., +// it's only in libmpi_cxx). Keep reading -- this is explained below. +// +// 2. When the MPI::Op is created (in MPI::Op::Init()), we call the +// back-end C MPI_Op_create() function as normal (just like the F77 +// bindings, in fact), and pass it the ompi_mpi_cxx_op_intercept() +// function (casting it to (MPI_User_function*) -- it's a function +// pointer, so its size is guaranteed to be the same, even if the +// signature of the real function is different). +// +// 3. The function pointer to ompi_mpi_cxx_op_intercept() will be +// cached in the MPI_Op in op->o_func[0].cxx_intercept_fn. +// +// Recall that MPI_Op is implemented to have an array of function +// pointers so that optimized versions of reduction operations can be +// invoked based on the corresponding datatype. But when an MPI_Op +// represents a user-defined function operation, there is only one +// function, so it is always stored in function pointer array index 0. +// +// 4. When MPI_Op_create() returns, the C++ MPI::Op::Init function +// manually sets OMPI_OP_FLAGS_CXX_FUNC flag on the resulting MPI_Op +// (again, very similar to the F77 MPI_OP_CREATE wrapper). It also +// caches the user's C++ callback function in op->o_func[1].c_fn +// (recall that the array of function pointers is actually a union of +// multiple different function pointer types -- it doesn't matter +// which type the user's callback function pointer is stored in; since +// all the types in the union are function pointers, it's guaranteed +// to be large enough to hold what we need. +// +// Note that we don't have a member of the union for the C++ callback +// function because its signature includes a (MPI::Datatype&), which +// we can't put in the C library libmpi. +// +// 5. When the user invokes an function that uses the MPI::Op (or, +// more specifically, when the Op dispatch engine in ompi/op/op.c [in +// libmpi] tries to dispatch off to it), it will see the +// OMPI_OP_FLAGS_CXX_FUNC flag and know to use the +// op->o_func[0].cxx_intercept_fn and also pass as the 4th argument, +// op->o_func[1].c_fn. +// +// 6. ompi_mpi_cxx_op_intercept() is therefore invoked and receives +// both the (MPI_Datatype*) (which is easy to convert to +// (MPI::Datatype&)) and a pointer to the user's C++ callback function +// (albiet cast as the wrong type). So it casts the callback function +// pointer to (MPI::User_function*) and invokes it. +// +// Wasn't that simple? +// extern "C" void ompi_mpi_cxx_op_intercept(void *invec, void *outvec, int *len, - MPI_Datatype *datatype) + MPI_Datatype *datatype, MPI_User_function *c_fn) { - MPI::Op* op = MPI::Comm::current_op; - MPI::Datatype thedata = *datatype; - ((MPI::User_function*)op->op_user_function)(invec, outvec, *len, thedata); - //JGS the above cast is a bit of a hack, I'll explain: - // the type for the PMPI::Op::op_user_function is PMPI::User_function - // but what it really stores is the user's MPI::User_function supplied when - // the user did an Op::Init. We need to cast the function pointer back to - // the MPI::User_function. The reason the PMPI::Op::op_user_function was - // not declared a MPI::User_function instead of a PMPI::User_function is - // that without namespaces we cannot do forward declarations. - // Anyway, without the cast the code breaks on HP LAM with the aCC compiler. + MPI::Datatype cxx_datatype = *datatype; + MPI::User_function *cxx_callback = (MPI::User_function*) c_fn; + cxx_callback(invec, outvec, *len, cxx_datatype); } extern "C" int diff --git a/ompi/mpi/cxx/intracomm_inln.h b/ompi/mpi/cxx/intracomm_inln.h index d2e4b2de97..d9343da304 100644 --- a/ompi/mpi/cxx/intracomm_inln.h +++ b/ompi/mpi/cxx/intracomm_inln.h @@ -37,9 +37,7 @@ inline void MPI::Intracomm::Scan(const void *sendbuf, void *recvbuf, int count, const MPI::Datatype & datatype, const MPI::Op& op) const { - current_op = const_cast(&op); (void)MPI_Scan(const_cast(sendbuf), recvbuf, count, datatype, op, mpi_comm); - current_op = (Op*)0; } inline void @@ -47,9 +45,7 @@ MPI::Intracomm::Exscan(const void *sendbuf, void *recvbuf, int count, const MPI::Datatype & datatype, const MPI::Op& op) const { - current_op = const_cast(&op); (void)MPI_Exscan(const_cast(sendbuf), recvbuf, count, datatype, op, mpi_comm); - current_op = (Op*)0; } inline MPI::Intracomm diff --git a/ompi/mpi/cxx/mpicxx.h b/ompi/mpi/cxx/mpicxx.h index 7fc73693e8..47373e0df0 100644 --- a/ompi/mpi/cxx/mpicxx.h +++ b/ompi/mpi/cxx/mpicxx.h @@ -39,10 +39,10 @@ // forward declare so that we can still do inlining struct opal_mutex_t; -//JGS: this is used for implementing user functions for MPI::Op +// See lengthy explanation in intercepts.cc about this function. extern "C" void ompi_mpi_cxx_op_intercept(void *invec, void *outvec, int *len, - MPI_Datatype *datatype); + MPI_Datatype *datatype, MPI_User_function *fn); //JGS: this is used as the MPI_Handler_function for // the mpi_errhandler in ERRORS_THROW_EXCEPTIONS diff --git a/ompi/mpi/cxx/op.h b/ompi/mpi/cxx/op.h index b93afcb9f9..455eed5547 100644 --- a/ompi/mpi/cxx/op.h +++ b/ompi/mpi/cxx/op.h @@ -47,7 +47,6 @@ public: virtual void Free(); #if ! 0 /* OMPI_ENABLE_MPI_PROFILING */ - User_function *op_user_function; //JGS move to private protected: MPI_Op mpi_op; #endif diff --git a/ompi/mpi/cxx/op_inln.h b/ompi/mpi/cxx/op_inln.h index d305a8a754..d48e0e7e0c 100644 --- a/ompi/mpi/cxx/op_inln.h +++ b/ompi/mpi/cxx/op_inln.h @@ -69,7 +69,7 @@ MPI::Op::Op(const MPI_Op &i) : mpi_op(i) { } inline MPI::Op::Op(const MPI::Op& op) - : op_user_function(op.op_user_function), mpi_op(op.mpi_op) { } + : mpi_op(op.mpi_op) { } inline MPI::Op::~Op() @@ -83,7 +83,6 @@ MPI::Op::~Op() inline MPI::Op& MPI::Op::operator=(const MPI::Op& op) { mpi_op = op.mpi_op; - op_user_function = op.op_user_function; return *this; } @@ -106,11 +105,21 @@ MPI::Op::operator MPI_Op () const { return mpi_op; } #endif +// Extern this function here rather than include an internal Open MPI +// header file (and therefore force installing the internal Open MPI +// header file so that user apps can #include it) + +extern "C" void ompi_op_set_cxx_callback(MPI_Op op, MPI_User_function*); + +// There is a lengthy comment in ompi/mpi/cxx/intercepts.cc explaining +// what this function is doing. Please read it before modifying this +// function. inline void MPI::Op::Init(MPI::User_function *func, bool commute) { - (void)MPI_Op_create(ompi_mpi_cxx_op_intercept , (int) commute, &mpi_op); - op_user_function = (User_function*)func; + (void)MPI_Op_create((MPI_User_function*) ompi_mpi_cxx_op_intercept, + (int) commute, &mpi_op); + ompi_op_set_cxx_callback(mpi_op, (MPI_User_function*) func); } diff --git a/ompi/op/op.c b/ompi/op/op.c index d169973588..bd421f0830 100644 --- a/ompi/op/op.c +++ b/ompi/op/op.c @@ -721,6 +721,13 @@ ompi_op_t *ompi_op_create(bool commute, } +void ompi_op_set_cxx_callback(ompi_op_t *op, MPI_User_function *fn) +{ + op->o_flags |= OMPI_OP_FLAGS_CXX_FUNC; + op->o_func[1].c_fn = fn; +} + + /************************************************************************** * * Static functions diff --git a/ompi/op/op.h b/ompi/op/op.h index 3e27e2fa4c..a32a0157b1 100644 --- a/ompi/op/op.h +++ b/ompi/op/op.h @@ -188,6 +188,17 @@ typedef void (ompi_op_fortran_handler_fn_t)(void *, void *, MPI_Fint *, MPI_Fint *); +/** + * Typedef for C++ op functions intercept. + * + * See the lengthy explanation for why this is different than the C + * intercept in ompi/mpi/cxx/intercepts.cc in the + * ompi_mpi_cxx_op_intercept() function. + */ +typedef void (ompi_op_cxx_handler_fn_t)(void *, void *, int *, + MPI_Datatype *, MPI_User_function *op); + + /* * Flags for MPI_Op */ @@ -195,18 +206,20 @@ typedef void (ompi_op_fortran_handler_fn_t)(void *, void *, #define OMPI_OP_FLAGS_INTRINSIC 0x0001 /** Set if the callback function is in Fortran */ #define OMPI_OP_FLAGS_FORTRAN_FUNC 0x0002 +/** Set if the callback function is in C++ */ +#define OMPI_OP_FLAGS_CXX_FUNC 0x0004 /** Set if the callback function is associative (MAX and SUM will both have ASSOC set -- in fact, it will only *not* be set if we implement some extensions to MPI, because MPI says that all MPI_Op's should be associative, so this flag is really here for future expansion) */ -#define OMPI_OP_FLAGS_ASSOC 0x0004 +#define OMPI_OP_FLAGS_ASSOC 0x0008 /** Set if the callback function is associative for floating point operands (e.g., MPI_SUM will have ASSOC set, but will *not* have FLOAT_ASSOC set) */ -#define OMPI_OP_FLAGS_FLOAT_ASSOC 0x0008 +#define OMPI_OP_FLAGS_FLOAT_ASSOC 0x0010 /** Set if the callback function is communative */ -#define OMPI_OP_FLAGS_COMMUTE 0x0010 +#define OMPI_OP_FLAGS_COMMUTE 0x0020 /** @@ -223,17 +236,21 @@ struct ompi_op_t { /**< Flags about the op */ union { - ompi_op_c_handler_fn_t *c_fn; - /**< C handler function pointer */ - ompi_op_fortran_handler_fn_t *fort_fn; - /**< Fortran handler function pointer */ + /** C handler function pointer */ + ompi_op_c_handler_fn_t *c_fn; + /** Fortran handler function pointer */ + ompi_op_fortran_handler_fn_t *fort_fn; + /** C++ intercept function pointer -- see lengthy comment in + ompi/mpi/cxx/intercepts.cc::ompi_mpi_cxx_op_intercept() for + an explanation */ + ompi_op_cxx_handler_fn_t *cxx_intercept_fn; } o_func[OMPI_OP_TYPE_MAX]; /**< Array of function pointers, indexed on the operation type. For non-intrinsice MPI_Op's, only the 0th element will be meaningful. */ + /** Index in Fortran <-> C translation array */ int o_f_to_c_index; - /**< Index in Fortran <-> C translation array */ }; /** * Convenience typedef @@ -403,6 +420,14 @@ extern "C" { */ ompi_op_t *ompi_op_create(bool commute, ompi_op_fortran_handler_fn_t *func); + /** + * Mark an MPI_Op as holding a C++ callback function, and cache + * that function in the MPI_Op. See a lenghty comment in + * ompi/mpi/cxx/op.c::ompi_mpi_cxx_op_intercept() for a full + * expalantion. + */ + void ompi_op_set_cxx_callback(ompi_op_t *op, MPI_User_function *fn); + #if defined(c_plusplus) || defined(__cplusplus) } #endif @@ -578,6 +603,9 @@ static inline void ompi_op_reduce(ompi_op_t *op, void *source, void *target, f_dtype = OMPI_INT_2_FINT(dtype->d_f_to_c_index); f_count = OMPI_INT_2_FINT(count); op->o_func[0].fort_fn(source, target, &f_count, &f_dtype); + } else if (0 != (op->o_flags & OMPI_OP_FLAGS_CXX_FUNC)) { + op->o_func[0].cxx_intercept_fn(source, target, &count, &dtype, + op->o_func[1].c_fn); } else { op->o_func[0].c_fn(source, target, &count, &dtype); }