diff --git a/ompi/mca/coll/hcoll/coll_hcoll.h b/ompi/mca/coll/hcoll/coll_hcoll.h index a151615ee5..e7c11b375b 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll.h +++ b/ompi/mca/coll/hcoll/coll_hcoll.h @@ -102,6 +102,14 @@ struct mca_coll_hcoll_module_t { mca_coll_base_module_t *previous_gatherv_module; mca_coll_base_module_reduce_scatter_fn_t previous_reduce_scatter; mca_coll_base_module_t *previous_reduce_scatter_module; + mca_coll_base_module_ibcast_fn_t previous_ibcast; + mca_coll_base_module_t *previous_ibcast_module; + mca_coll_base_module_ibarrier_fn_t previous_ibarrier; + mca_coll_base_module_t *previous_ibarrier_module; + mca_coll_base_module_iallgather_fn_t previous_iallgather; + mca_coll_base_module_t *previous_iallgather_module; + mca_coll_base_module_iallreduce_fn_t previous_iallreduce; + mca_coll_base_module_t *previous_iallreduce_module; }; typedef struct mca_coll_hcoll_module_t mca_coll_hcoll_module_t; @@ -143,6 +151,31 @@ int mca_coll_hcoll_alltoall(void *sbuf, int scount, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); +int mca_coll_hcoll_ibarrier(struct ompi_communicator_t *comm, + ompi_request_t** request, + mca_coll_base_module_t *module); + +int mca_coll_hcoll_ibcast(void *buff, int count, + struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, + ompi_request_t** request, + mca_coll_base_module_t *module); + +int mca_coll_hcoll_iallgather(void *sbuf, int scount, + struct ompi_datatype_t *sdtype, + void *rbuf, int rcount, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + ompi_request_t** request, + mca_coll_base_module_t *module); + +int mca_coll_hcoll_iallreduce(void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + struct ompi_communicator_t *comm, + ompi_request_t** request, + mca_coll_base_module_t *module); + END_C_DECLS #endif diff --git a/ompi/mca/coll/hcoll/coll_hcoll_module.c b/ompi/mca/coll/hcoll/coll_hcoll_module.c index 85ba484bab..35bf71af40 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_module.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_module.c @@ -35,6 +35,10 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module) hcoll_module->previous_alltoallv = NULL; hcoll_module->previous_alltoallw = NULL; hcoll_module->previous_reduce_scatter = NULL; + hcoll_module->previous_ibarrier = NULL; + hcoll_module->previous_ibcast = NULL; + hcoll_module->previous_iallreduce = NULL; + hcoll_module->previous_iallgather = NULL; } static void mca_coll_hcoll_module_construct(mca_coll_hcoll_module_t *hcoll_module) @@ -56,6 +60,10 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module OBJ_RELEASE(hcoll_module->previous_alltoallv_module); OBJ_RELEASE(hcoll_module->previous_alltoallw_module); OBJ_RELEASE(hcoll_module->previous_reduce_scatter_module); + OBJ_RELEASE(hcoll_module->previous_ibarrier_module); + OBJ_RELEASE(hcoll_module->previous_ibcast_module); + OBJ_RELEASE(hcoll_module->previous_iallreduce_module); + OBJ_RELEASE(hcoll_module->previous_iallgather_module); hcoll_destroy_context(hcoll_module->hcoll_context, (rte_grp_handle_t) hcoll_module->comm); mca_coll_hcoll_module_clear(hcoll_module); @@ -87,6 +95,10 @@ static int __save_coll_handlers(mca_coll_hcoll_module_t *hcoll_module) HCOL_SAVE_PREV_COLL_API(alltoallv); HCOL_SAVE_PREV_COLL_API(alltoallw); HCOL_SAVE_PREV_COLL_API(reduce_scatter); + HCOL_SAVE_PREV_COLL_API(ibarrier); + HCOL_SAVE_PREV_COLL_API(ibcast); + HCOL_SAVE_PREV_COLL_API(iallreduce); + HCOL_SAVE_PREV_COLL_API(iallgather); return OMPI_SUCCESS; } @@ -214,6 +226,10 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority) hcoll_module->super.coll_allgather = hcoll_collectives.coll_allgather ? mca_coll_hcoll_allgather : NULL; hcoll_module->super.coll_allreduce = hcoll_collectives.coll_allreduce ? mca_coll_hcoll_allreduce : NULL; hcoll_module->super.coll_alltoall = /*hcoll_collectives.coll_alltoall ? mca_coll_hcoll_alltoall : */ NULL; + hcoll_module->super.coll_ibarrier = hcoll_collectives.coll_ibarrier ? mca_coll_hcoll_ibarrier : NULL; + hcoll_module->super.coll_ibcast = hcoll_collectives.coll_ibcast ? mca_coll_hcoll_ibcast : NULL; + hcoll_module->super.coll_iallgather = hcoll_collectives.coll_iallgather ? mca_coll_hcoll_iallgather : NULL; + hcoll_module->super.coll_iallreduce = hcoll_collectives.coll_iallreduce ? mca_coll_hcoll_iallreduce : NULL; *priority = mca_coll_hcoll_component.hcoll_priority; module = &hcoll_module->super; diff --git a/ompi/mca/coll/hcoll/coll_hcoll_ops.c b/ompi/mca/coll/hcoll/coll_hcoll_ops.c index d2c436471f..d10c736f2e 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_ops.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_ops.c @@ -78,7 +78,7 @@ int mca_coll_hcoll_allgather(void *sbuf, int scount, rc = hcoll_module->previous_allgather(sbuf,scount,sdtype, rbuf,rcount,rdtype, comm, - hcoll_module->previous_bcast_module); + hcoll_module->previous_allgather_module); return rc; } rc = hcoll_collectives.coll_allgather(sbuf,scount,stype,rbuf,rcount,rtype,hcoll_module->hcoll_context); @@ -87,7 +87,7 @@ int mca_coll_hcoll_allgather(void *sbuf, int scount, rc = hcoll_module->previous_allgather(sbuf,scount,sdtype, rbuf,rcount,rdtype, comm, - hcoll_module->previous_bcast_module); + hcoll_module->previous_allgather_module); } return rc; } @@ -112,7 +112,7 @@ int mca_coll_hcoll_allreduce(void *sbuf, void *rbuf, int count, dtype->super.name); rc = hcoll_module->previous_allreduce(sbuf,rbuf, count,dtype,op, - comm, hcoll_module->previous_bcast_module); + comm, hcoll_module->previous_allreduce_module); return rc; } @@ -125,7 +125,7 @@ int mca_coll_hcoll_allreduce(void *sbuf, void *rbuf, int count, op->o_name); rc = hcoll_module->previous_allreduce(sbuf,rbuf, count,dtype,op, - comm, hcoll_module->previous_bcast_module); + comm, hcoll_module->previous_allreduce_module); return rc; } @@ -134,7 +134,7 @@ int mca_coll_hcoll_allreduce(void *sbuf, void *rbuf, int count, HCOL_VERBOSE(20,"RUNNING FALLBACK ALLREDUCE"); rc = hcoll_module->previous_allreduce(sbuf,rbuf, count,dtype,op, - comm, hcoll_module->previous_bcast_module); + comm, hcoll_module->previous_allreduce_module); } return rc; } @@ -163,7 +163,7 @@ int mca_coll_hcoll_alltoall(void *sbuf, int scount, rc = hcoll_module->previous_alltoall(sbuf,scount,sdtype, rbuf,rcount,rdtype, comm, - hcoll_module->previous_bcast_module); + hcoll_module->previous_alltoall_module); return rc; } rc = hcoll_collectives.coll_alltoall(sbuf,scount,stype,rbuf,rcount,rtype,hcoll_module->hcoll_context); @@ -172,7 +172,149 @@ int mca_coll_hcoll_alltoall(void *sbuf, int scount, rc = hcoll_module->previous_alltoall(sbuf,scount,sdtype, rbuf,rcount,rdtype, comm, - hcoll_module->previous_bcast_module); + hcoll_module->previous_alltoall_module); } return rc; } + +int mca_coll_hcoll_ibarrier(struct ompi_communicator_t *comm, + ompi_request_t ** request, + mca_coll_base_module_t *module) +{ + int rc; + void** rt_handle; + HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING BARRIER"); + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; + rt_handle = (void**) request; + rc = hcoll_collectives.coll_ibarrier(hcoll_module->hcoll_context, rt_handle); + if (HCOLL_SUCCESS != rc){ + HCOL_VERBOSE(20,"RUNNING FALLBACK NON-BLOCKING BARRIER"); + rc = hcoll_module->previous_ibarrier(comm, request, hcoll_module->previous_ibarrier_module); + } + return rc; +} + +int mca_coll_hcoll_ibcast(void *buff, int count, + struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, + ompi_request_t ** request, + mca_coll_base_module_t *module) +{ + dte_data_representation_t dtype; + int rc; + void** rt_handle; + HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING BCAST"); + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; + rt_handle = (void**) request; + dtype = ompi_dtype_2_dte_dtype(datatype); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(dtype))){ + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"Ompi_datatype is not supported: %s; calling fallback non-blocking bcast;",datatype->super.name); + rc = hcoll_module->previous_ibcast(buff,count,datatype,root, + comm, request, hcoll_module->previous_ibcast_module); + return rc; + } + rc = hcoll_collectives.coll_ibcast(buff, count, dtype, root, rt_handle, hcoll_module->hcoll_context); + if (HCOLL_SUCCESS != rc){ + HCOL_VERBOSE(20,"RUNNING FALLBACK NON-BLOCKING BCAST"); + rc = hcoll_module->previous_ibcast(buff,count,datatype,root, + comm, request, hcoll_module->previous_ibcast_module); + } + return rc; +} + +int mca_coll_hcoll_iallgather(void *sbuf, int scount, + struct ompi_datatype_t *sdtype, + void *rbuf, int rcount, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + ompi_request_t ** request, + mca_coll_base_module_t *module) +{ + dte_data_representation_t stype; + dte_data_representation_t rtype; + int rc; + void** rt_handle; + HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING ALLGATHER"); + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; + rt_handle = (void**) request; + stype = ompi_dtype_2_dte_dtype(sdtype); + rtype = ompi_dtype_2_dte_dtype(rdtype); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))){ + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"Ompi_datatype is not supported: sdtype = %s, rdtype = %s; calling fallback non-blocking allgather;", + sdtype->super.name, + rdtype->super.name); + rc = hcoll_module->previous_iallgather(sbuf,scount,sdtype, + rbuf,rcount,rdtype, + comm, + request, + hcoll_module->previous_iallgather_module); + return rc; + } + rc = hcoll_collectives.coll_iallgather(sbuf, scount, stype, rbuf, rcount, rtype, hcoll_module->hcoll_context, rt_handle); + if (HCOLL_SUCCESS != rc){ + HCOL_VERBOSE(20,"RUNNING FALLBACK NON-BLOCKING ALLGATHER"); + rc = hcoll_module->previous_iallgather(sbuf,scount,sdtype, + rbuf,rcount,rdtype, + comm, + request, + hcoll_module->previous_iallgather_module); + } + return rc; +} + +int mca_coll_hcoll_iallreduce(void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + struct ompi_communicator_t *comm, + ompi_request_t ** request, + mca_coll_base_module_t *module) +{ + dte_data_representation_t Dtype; + hcoll_dte_op_t *Op; + int rc; + void** rt_handle; + HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING ALLREDUCE"); + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; + rt_handle = (void**) request; + Dtype = ompi_dtype_2_dte_dtype(dtype); + if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){ + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback non-blocking allreduce;", + dtype->super.name); + rc = hcoll_module->previous_iallreduce(sbuf,rbuf, + count,dtype,op, + comm, request, hcoll_module->previous_iallreduce_module); + return rc; + } + + Op = ompi_op_2_hcolrte_op(op); + if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){ + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback non-blocking allreduce;", + op->o_name); + rc = hcoll_module->previous_iallreduce(sbuf,rbuf, + count,dtype,op, + comm, request, hcoll_module->previous_iallreduce_module); + return rc; + } + + rc = hcoll_collectives.coll_iallreduce(sbuf, rbuf, count, Dtype, Op, hcoll_module->hcoll_context, rt_handle); + if (HCOLL_SUCCESS != rc){ + HCOL_VERBOSE(20,"RUNNING FALLBACK NON-BLOCKING ALLREDUCE"); + rc = hcoll_module->previous_iallreduce(sbuf,rbuf, + count,dtype,op, + comm, request, hcoll_module->previous_iallreduce_module); + } + return rc; +} + diff --git a/ompi/mca/coll/hcoll/coll_hcoll_rte.c b/ompi/mca/coll/hcoll/coll_hcoll_rte.c index 6b014ec244..a93f589031 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_rte.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_rte.c @@ -390,6 +390,18 @@ static int group_id(rte_grp_handle_t group){ return ((ompi_communicator_t *)group)->c_contextid; } +static int +request_free(struct ompi_request_t **ompi_req) +{ + ompi_request_t *req = *ompi_req; + if (!coll_handle_test(req)) { + return OMPI_ERROR; + } + coll_handle_free(req); + *ompi_req = &ompi_request_empty; + return OMPI_SUCCESS; +} + static void* get_coll_handle(void) { ompi_request_t *ompi_req; @@ -403,6 +415,7 @@ static void* get_coll_handle(void) OMPI_REQUEST_INIT(ompi_req,false); ompi_req->req_complete_cb = NULL; ompi_req->req_status.MPI_ERROR = MPI_SUCCESS; + ompi_req->req_free = request_free; return (void *)ompi_req; } diff --git a/ompi/request/request.h b/ompi/request/request.h index 20737ae0ae..a8623a64fc 100644 --- a/ompi/request/request.h +++ b/ompi/request/request.h @@ -351,7 +351,9 @@ static inline int ompi_request_cancel(ompi_request_t* request) */ static inline int ompi_request_free(ompi_request_t** request) { - return (*request)->req_free(request); + if ((*request)->req_free) { + return (*request)->req_free(request); + } } #define ompi_request_test (ompi_request_functions.req_test)