
Merge pull request #2218 from yosefe/topic/ucx-pml-spml-update

ucx: adapt pml_ucx and spml_ucx to new UCX APIs
This commit is contained in:
Joshua Ladd 2016-10-13 09:23:37 -04:00, committed by GitHub
parents 958e29f929 05ca466c6b
commit b661307e6f
7 changed files: 218 additions and 73 deletions

View file

@@ -70,8 +70,8 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
         1ul << (PML_UCX_TAG_BITS - 1),
         1ul << (PML_UCX_CONTEXT_BITS),
     },
-    NULL,
-    NULL
+    NULL,   /* ucp_context */
+    NULL    /* ucp_worker */
 };

 static int mca_pml_ucx_send_worker_address(void)
@@ -116,6 +116,7 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
 int mca_pml_ucx_open(void)
 {
+    ucp_context_attr_t attr;
     ucp_params_t params;
     ucp_config_t *config;
     ucs_status_t status;
@@ -128,10 +129,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }

+    /* Initialize UCX context */
+    params.field_mask      = UCP_PARAM_FIELD_FEATURES |
+                             UCP_PARAM_FIELD_REQUEST_SIZE |
+                             UCP_PARAM_FIELD_REQUEST_INIT |
+                             UCP_PARAM_FIELD_REQUEST_CLEANUP |
+                             UCP_PARAM_FIELD_TAG_SENDER_MASK;
     params.features        = UCP_FEATURE_TAG;
     params.request_size    = sizeof(ompi_request_t);
     params.request_init    = mca_pml_ucx_request_init;
     params.request_cleanup = mca_pml_ucx_request_cleanup;
+    params.tag_sender_mask = PML_UCX_SPECIFIC_SOURCE_MASK;

     status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
     ucp_config_release(config);
@@ -140,6 +148,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }

+    /* Query UCX attributes */
+    attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
+    status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
+    if (UCS_OK != status) {
+        ucp_cleanup(ompi_pml_ucx.ucp_context);
+        ompi_pml_ucx.ucp_context = NULL;
+        return OMPI_ERROR;
+    }
+
+    ompi_pml_ucx.request_size = attr.request_size;
+
     return OMPI_SUCCESS;
 }
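Note: the two hunks above are the core of the API migration. The new UCP entry points take self-describing structs: field_mask declares which members of ucp_params_t were actually filled in, and per-request sizes come back from ucp_context_query() instead of being assumed. A minimal sketch of the same bootstrap sequence outside the PML (error handling trimmed; all identifiers are from UCX's ucp.h, the helper name is made up):

    #include <ucp/api/ucp.h>

    /* Version-tolerant UCP setup: field_mask tells the library which
     * struct members are valid, so ucp_params_t can grow over time
     * without breaking existing callers. */
    static ucp_context_h init_tag_context(size_t *request_size_p)
    {
        ucp_params_t params = {0};
        ucp_config_t *config;
        ucp_context_h context;
        ucp_context_attr_t attr;

        ucp_config_read(NULL, NULL, &config);    /* pick up UCX_* env vars */
        params.field_mask = UCP_PARAM_FIELD_FEATURES;
        params.features   = UCP_FEATURE_TAG;     /* tag-matching only */
        ucp_init(&params, config, &context);
        ucp_config_release(config);

        /* Ask UCP how much scratch space one request needs; pml_ucx
         * caches this as ompi_pml_ucx.request_size. */
        attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
        ucp_context_query(context, &attr);
        *request_size_p = attr.request_size;
        return context;
    }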
@@ -163,7 +182,7 @@ int mca_pml_ucx_init(void)
     /* TODO check MPI thread mode */
     status = ucp_worker_create(ompi_pml_ucx.ucp_context, UCS_THREAD_MODE_SINGLE,
-                              &ompi_pml_ucx.ucp_worker);
+                               &ompi_pml_ucx.ucp_worker);
     if (UCS_OK != status) {
         return OMPI_ERROR;
     }
@@ -252,6 +271,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
     ucp_address_t *address;
     ucs_status_t status;
+    ompi_proc_t *proc;
     size_t addrlen;
     ucp_ep_h ep;
     size_t i;
@@ -264,47 +284,109 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
     }

     for (i = 0; i < nprocs; ++i) {
-        ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+
+        ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
         if (ret < 0) {
-            PML_UCX_ERROR("Failed to receive worker address from proc: %d", procs[i]->super.proc_name.vpid);
+            PML_UCX_ERROR("Failed to receive worker address from proc: %d",
+                          proc->super.proc_name.vpid);
             return ret;
         }

-        if (procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
-            PML_UCX_VERBOSE(3, "already connected to proc. %d", procs[i]->super.proc_name.vpid);
+        if (proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
+            PML_UCX_VERBOSE(3, "already connected to proc. %d", proc->super.proc_name.vpid);
             continue;
         }

-        PML_UCX_VERBOSE(2, "connecting to proc. %d", procs[i]->super.proc_name.vpid);
+        PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);
         status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
         free(address);

         if (UCS_OK != status) {
-            PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid,
+            PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->super.proc_name.vpid,
                           ucs_status_string(status));
             return OMPI_ERROR;
         }

-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
     }

     return OMPI_SUCCESS;
 }
+static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
+{
+    ucs_status_t status;
+    size_t i;
+
+    PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p);
+    for (i = 0; i < *count_p; ++i) {
+        do {
+            opal_progress();
+            status = ucp_request_test(reqs[i], NULL);
+        } while (status == UCS_INPROGRESS);
+        if (status != UCS_OK) {
+            PML_UCX_ERROR("disconnect request failed: %s",
+                          ucs_status_string(status));
+        }
+        ucp_request_release(reqs[i]);
+        reqs[i] = NULL;
+    }
+
+    *count_p = 0;
+}
+
 int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
+    ompi_proc_t *proc;
+    size_t num_reqs, max_reqs;
+    void *dreq, **dreqs;
     ucp_ep_h ep;
     size_t i;

-    for (i = 0; i < nprocs; ++i) {
-        PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->super.proc_name.vpid);
-        ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
-        if (ep != NULL) {
-            ucp_ep_destroy(ep);
-        }
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
+    max_reqs = ompi_pml_ucx.num_disconnect;
+    if (max_reqs > nprocs) {
+        max_reqs = nprocs;
     }

+    dreqs = malloc(sizeof(*dreqs) * max_reqs);
+    if (dreqs == NULL) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    num_reqs = 0;
+    for (i = 0; i < nprocs; ++i) {
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+        ep   = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
+        if (ep == NULL) {
+            continue;
+        }
+
+        PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->super.proc_name.vpid);
+        dreq = ucp_disconnect_nb(ep);
+        if (dreq != NULL) {
+            if (UCS_PTR_IS_ERR(dreq)) {
+                PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
+                              proc->super.proc_name.vpid,
+                              ucs_status_string(UCS_PTR_STATUS(dreq)));
+            } else {
+                dreqs[num_reqs++] = dreq;
+            }
+        }
+
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
+
+        if (num_reqs >= ompi_pml_ucx.num_disconnect) {
+            mca_pml_ucx_waitall(dreqs, &num_reqs);
+        }
+    }
+
+    mca_pml_ucx_waitall(dreqs, &num_reqs);
+    free(dreqs);
+
     opal_pmix.fence(NULL, 0);
+
     return OMPI_SUCCESS;
 }
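Note: the teardown path changes in two ways. Peers are now visited starting from (i + OMPI_PROC_MY_NAME->vpid) % nprocs, so every rank begins with a different peer instead of the whole job disconnecting from rank 0 first, and the abrupt ucp_ep_destroy() becomes ucp_disconnect_nb(), whose returned request must be progressed to completion. num_disconnect caps how many of these requests may be outstanding before mca_pml_ucx_waitall() drains the batch. A minimal sketch of finishing a single disconnect (same UCX calls as in the hunk; the helper name is made up):

    /* Synchronously complete one non-blocking disconnect.
     * ucp_disconnect_nb() returns NULL when the endpoint is already
     * gone, an error pointer, or a request handle to progress. */
    static void ep_disconnect_sync(ucp_worker_h worker, ucp_ep_h ep)
    {
        void *req = ucp_disconnect_nb(ep);
        if (req == NULL || UCS_PTR_IS_ERR(req)) {
            return;                      /* done immediately, or failed */
        }
        while (ucp_request_test(req, NULL) == UCS_INPROGRESS) {
            ucp_worker_progress(worker); /* drive the flush forward */
        }
        ucp_request_release(req);
    }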
@@ -321,14 +403,7 @@ int mca_pml_ucx_enable(bool enable)
 int mca_pml_ucx_progress(void)
 {
-    static int inprogress = 0;
-    if (inprogress != 0) {
-        return 0;
-    }
-
-    ++inprogress;
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    --inprogress;
     return OMPI_SUCCESS;
 }
@@ -393,52 +468,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
     return OMPI_SUCCESS;
 }

-static void
-mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
-                                     ucp_tag_recv_info_t *info)
-{
-    ompi_request_t *req = request;
-
-    PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu",
-                    (void*)req, ucs_status_string(status), info->sender_tag,
-                    info->length);
-
-    mca_pml_ucx_set_recv_status(&req->req_status, status, info);
-    PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
-    ompi_request_complete(req, true);
-}
-
 int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
                      int tag, struct ompi_communicator_t* comm,
                      ompi_status_public_t* mpi_status)
 {
     ucp_tag_t ucp_tag, ucp_tag_mask;
-    ompi_request_t *req;
+    ucp_tag_recv_info_t info;
+    ucs_status_t status;
+    void *req;

     PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");

     PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
-    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
-                                           mca_pml_ucx_get_datatype(datatype),
-                                           ucp_tag, ucp_tag_mask,
-                                           mca_pml_ucx_blocking_recv_completion);
-    if (UCS_PTR_IS_ERR(req)) {
-        PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
-        return OMPI_ERROR;
-    }
+    req = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
+    status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
+                              mca_pml_ucx_get_datatype(datatype),
+                              ucp_tag, ucp_tag_mask, req);

     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    while ( !REQUEST_COMPLETE(req) ) {
+    for (;;) {
+        status = ucp_request_test(req, &info);
+        if (status != UCS_INPROGRESS) {
+            mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
+            return OMPI_SUCCESS;
+        }
         opal_progress();
     }
-
-    if (mpi_status != MPI_STATUS_IGNORE) {
-        *mpi_status = req->req_status;
-    }
-
-    req->req_complete = REQUEST_PENDING;
-    ucp_request_release(req);
-
-    return OMPI_SUCCESS;
 }

 static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
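Note: the blocking-receive rewrite drops the completion callback altogether. ucp_tag_recv_nbr() places the request in caller-owned memory rather than allocating inside UCP: the req pointer must have at least request_size bytes available *before* it, which is exactly what alloca(...) + ompi_pml_ucx.request_size arranges using the size queried in mca_pml_ucx_open(). Completion is then detected by polling ucp_request_test(). The same pattern with a hypothetical fixed-size buffer (MAX_REQUEST_SIZE and the helper name are made up; include as in the first sketch):

    /* Receive into caller-owned request storage and poll for completion.
     * MAX_REQUEST_SIZE is an assumed compile-time bound on the
     * request_size value reported by ucp_context_query(). */
    #define MAX_REQUEST_SIZE 512

    static ucs_status_t recv_sync(ucp_worker_h worker, void *buf, size_t count,
                                  ucp_datatype_t datatype, ucp_tag_t tag,
                                  ucp_tag_t tag_mask, size_t request_size)
    {
        char storage[MAX_REQUEST_SIZE];
        void *req = storage + request_size;     /* UCP writes *before* req */
        ucp_tag_recv_info_t info;
        ucs_status_t status;

        ucp_tag_recv_nbr(worker, buf, count, datatype, tag, tag_mask, req);
        do {
            ucp_worker_progress(worker);
            status = ucp_request_test(req, &info);
        } while (status == UCS_INPROGRESS);     /* UCS_OK once matched */
        return status;
    }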
@@ -583,6 +638,7 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
     } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
@@ -628,7 +684,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
         PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
-    } else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE) {
+    } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
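Note: both probe paths now call opal_progress() on the miss branch. A blocking MPI_Probe is essentially a spin over the non-blocking probe, so if the probe itself never advances the progress engine, a message still sitting in the transport may never become matchable. Roughly the loop these hunks keep alive (a fragment sketch; src, tag, and comm are assumed to be in scope, signature as in the function above):

    /* Blocking probe built on the non-blocking one. Without the
     * opal_progress() call in iprobe's miss path this could spin
     * forever on a message not yet pulled off the wire. */
    int matched = 0;
    ompi_status_public_t status;
    do {
        mca_pml_ucx_iprobe(src, tag, comm, &matched, &status);
    } while (!matched);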

View file

@@ -40,8 +40,10 @@ struct mca_pml_ucx_module {
     /* Requests */
     mca_pml_ucx_freelist_t persistent_reqs;
     ompi_request_t completed_send_req;
+    size_t request_size;
+    int num_disconnect;

-    /* Convertors pool */
+    /* Converters pool */
     mca_pml_ucx_freelist_t convs;

     int priority;

View file

@@ -63,6 +63,13 @@ static int mca_pml_ucx_component_register(void)
                                           MCA_BASE_VAR_SCOPE_LOCAL,
                                           &ompi_pml_ucx.priority);

+    ompi_pml_ucx.num_disconnect = 1;
+    (void) mca_base_component_var_register(&mca_pml_ucx_component.pmlm_version, "num_disconnect",
+                                           "How many disconnects may proceed in parallel",
+                                           MCA_BASE_VAR_TYPE_INT, NULL, 0, 0,
+                                           OPAL_INFO_LVL_3,
+                                           MCA_BASE_VAR_SCOPE_LOCAL,
+                                           &ompi_pml_ucx.num_disconnect);

     return 0;
 }
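Note: registered this way, the knob is exposed as the MCA parameter pml_ucx_num_disconnect (default 1), assuming the standard framework_component_variable naming scheme, and can be set at launch time, e.g. mpirun --mca pml_ucx_num_disconnect 16 ...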

View file

@@ -34,6 +34,9 @@ enum {
 #define PML_UCX_TAG_BITS             24
 #define PML_UCX_RANK_BITS            24
 #define PML_UCX_CONTEXT_BITS         16
+#define PML_UCX_ANY_SOURCE_MASK      0x800000000000fffful
+#define PML_UCX_SPECIFIC_SOURCE_MASK 0x800000fffffffffful
+#define PML_UCX_TAG_MASK             0x7fffff0000000000ul

 #define PML_UCX_MAKE_SEND_TAG(_tag, _comm) \
@@ -45,16 +48,16 @@
 #define PML_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _src, _comm) \
     { \
         if ((_src) == MPI_ANY_SOURCE) { \
-            _ucp_tag_mask = 0x800000000000fffful; \
+            _ucp_tag_mask = PML_UCX_ANY_SOURCE_MASK; \
         } else { \
-            _ucp_tag_mask = 0x800000fffffffffful; \
+            _ucp_tag_mask = PML_UCX_SPECIFIC_SOURCE_MASK; \
         } \
         \
         _ucp_tag = (((uint64_t)(_src) & UCS_MASK(PML_UCX_RANK_BITS)) << PML_UCX_CONTEXT_BITS) | \
                    (_comm)->c_contextid; \
         \
         if ((_tag) != MPI_ANY_TAG) { \
-            _ucp_tag_mask |= 0x7fffff0000000000ul; \
+            _ucp_tag_mask |= PML_UCX_TAG_MASK; \
             _ucp_tag |= ((uint64_t)(_tag)) << (PML_UCX_RANK_BITS + PML_UCX_CONTEXT_BITS); \
         } \
     }
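Note: naming the three magic constants also documents the 64-bit match-tag layout. Reading the masks against the *_BITS constants above gives roughly the following picture (a reviewer's sketch, not part of the patch):

    /*
     *  bit 63   bits 62..40   bits 39..16    bits 15..0
     * +------+-------------+--------------+-------------+
     * | flag | MPI tag     | source rank  | context id  |
     * |      | (23 bits)   | (24 bits)    | (16 bits)   |
     * +------+-------------+--------------+-------------+
     *
     * Every receive matches bit 63 and the 16 context-id bits
     * (PML_UCX_ANY_SOURCE_MASK); a specific source additionally matches
     * the 24 rank bits (PML_UCX_SPECIFIC_SOURCE_MASK); a specific tag
     * ORs in the remaining 23 tag bits (PML_UCX_TAG_MASK). Bit 63 being
     * set in every mask suggests the upper half of the tag space stays
     * reserved for internal use. Passing PML_UCX_SPECIFIC_SOURCE_MASK as
     * tag_sender_mask at ucp_init() time tells UCP which tag bits
     * uniquely identify the sender.
     */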

View file

@@ -65,7 +65,13 @@ mca_spml_ucx_t mca_spml_ucx = {
         mca_spml_ucx_rmkey_unpack,
         mca_spml_ucx_rmkey_free,
         (void*)&mca_spml_ucx
-    }
+    },
+    NULL, /* ucp_context */
+    NULL, /* ucp_worker */
+    NULL, /* ucp_peers */
+    0,    /* using_mem_hooks */
+    1     /* num_disconnect */
 };

 int mca_spml_ucx_enable(bool enable)
@@ -80,10 +86,37 @@ int mca_spml_ucx_enable(bool enable)
     return OSHMEM_SUCCESS;
 }

+static void mca_spml_ucx_waitall(void **reqs, size_t *count_p)
+{
+    ucs_status_t status;
+    size_t i;
+
+    SPML_VERBOSE(10, "waiting for %d disconnect requests", *count_p);
+    for (i = 0; i < *count_p; ++i) {
+        do {
+            opal_progress();
+            status = ucp_request_test(reqs[i], NULL);
+        } while (status == UCS_INPROGRESS);
+        if (status != UCS_OK) {
+            SPML_ERROR("disconnect request failed: %s",
+                       ucs_status_string(status));
+        }
+        ucp_request_release(reqs[i]);
+        reqs[i] = NULL;
+    }
+
+    *count_p = 0;
+}
+
 int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs)
 {
-    size_t i, n;
     int my_rank = oshmem_my_proc_id();
+    size_t num_reqs, max_reqs;
+    void *dreq, **dreqs;
+    ompi_proc_t *proc;
+    ucp_ep_h ep;
+    size_t i, n;

     oshmem_shmem_barrier();
@@ -91,13 +124,47 @@ int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs)
         return OSHMEM_SUCCESS;
     }

-    for (n = 0; n < nprocs; n++) {
-        i = (my_rank + n) % nprocs;
-        if (mca_spml_ucx.ucp_peers[i].ucp_conn) {
-            ucp_ep_destroy(mca_spml_ucx.ucp_peers[i].ucp_conn);
-        }
-    }
+    max_reqs = mca_spml_ucx.num_disconnect;
+    if (max_reqs > nprocs) {
+        max_reqs = nprocs;
+    }

+    dreqs = malloc(sizeof(*dreqs) * max_reqs);
+    if (dreqs == NULL) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    num_reqs = 0;
+    for (i = 0; i < nprocs; ++i) {
+        n  = (i + my_rank) % nprocs;
+        ep = mca_spml_ucx.ucp_peers[n].ucp_conn;
+        if (ep == NULL) {
+            continue;
+        }
+
+        SPML_VERBOSE(10, "disconnecting from peer %d", n);
+        dreq = ucp_disconnect_nb(ep);
+        if (dreq != NULL) {
+            if (UCS_PTR_IS_ERR(dreq)) {
+                SPML_ERROR("ucp_disconnect_nb(%d) failed: %s", n,
+                           ucs_status_string(UCS_PTR_STATUS(dreq)));
+            } else {
+                dreqs[num_reqs++] = dreq;
+            }
+        }
+
+        mca_spml_ucx.ucp_peers[n].ucp_conn = NULL;
+
+        if (num_reqs >= mca_spml_ucx.num_disconnect) {
+            mca_spml_ucx_waitall(dreqs, &num_reqs);
+        }
+    }
+
+    mca_spml_ucx_waitall(dreqs, &num_reqs);
+    free(dreqs);
+
+    opal_pmix.fence(NULL, 0);
+
     free(mca_spml_ucx.ucp_peers);

     return OSHMEM_SUCCESS;
 }

View file

@@ -50,6 +50,7 @@ struct mca_spml_ucx {
     ucp_context_h ucp_context;
     ucp_worker_h ucp_worker;
     ucp_peer_t *ucp_peers;
+    int num_disconnect;
     int priority; /* component priority */
     bool enabled;

View file

@@ -97,6 +97,10 @@ static int mca_spml_ucx_component_register(void)
                                     "[integer] ucx priority",
                                     &mca_spml_ucx.priority);

+    mca_spml_ucx_param_register_int("num_disconnect", 1,
+                                    "How many disconnects may proceed in parallel",
+                                    &mca_spml_ucx.num_disconnect);

     return OSHMEM_SUCCESS;
 }
@@ -118,7 +122,8 @@ static int mca_spml_ucx_component_open(void)
     }

     memset(&params, 0, sizeof(params));
-    params.features = UCP_FEATURE_RMA|UCP_FEATURE_AMO32|UCP_FEATURE_AMO64;
+    params.field_mask = UCP_PARAM_FIELD_FEATURES;
+    params.features   = UCP_FEATURE_RMA|UCP_FEATURE_AMO32|UCP_FEATURE_AMO64;

     err = ucp_init(&params, ucp_config, &mca_spml_ucx.ucp_context);
     ucp_config_release(ucp_config);
@@ -131,7 +136,10 @@ static int mca_spml_ucx_component_open(void)
 static int mca_spml_ucx_component_close(void)
 {
-    ucp_cleanup(mca_spml_ucx.ucp_context);
+    if (mca_spml_ucx.ucp_context) {
+        ucp_cleanup(mca_spml_ucx.ucp_context);
+        mca_spml_ucx.ucp_context = NULL;
+    }
     return OSHMEM_SUCCESS;
 }