1
1

Merge pull request #2218 from yosefe/topic/ucx-pml-spml-update

ucx: adapt pml_ucx and spml_ucx to new UCX APIs
Этот коммит содержится в:
Joshua Ladd 2016-10-13 09:23:37 -04:00 коммит произвёл GitHub
родитель 958e29f929 05ca466c6b
Коммит b661307e6f
7 изменённых файлов: 218 добавлений и 73 удалений

Просмотреть файл

@ -70,8 +70,8 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
1ul << (PML_UCX_TAG_BITS - 1),
1ul << (PML_UCX_CONTEXT_BITS),
},
NULL,
NULL
NULL, /* ucp_context */
NULL /* ucp_worker */
};
static int mca_pml_ucx_send_worker_address(void)
@ -116,6 +116,7 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
int mca_pml_ucx_open(void)
{
ucp_context_attr_t attr;
ucp_params_t params;
ucp_config_t *config;
ucs_status_t status;
@ -128,10 +129,17 @@ int mca_pml_ucx_open(void)
return OMPI_ERROR;
}
/* Initialize UCX context */
params.field_mask = UCP_PARAM_FIELD_FEATURES |
UCP_PARAM_FIELD_REQUEST_SIZE |
UCP_PARAM_FIELD_REQUEST_INIT |
UCP_PARAM_FIELD_REQUEST_CLEANUP |
UCP_PARAM_FIELD_TAG_SENDER_MASK;
params.features = UCP_FEATURE_TAG;
params.request_size = sizeof(ompi_request_t);
params.request_init = mca_pml_ucx_request_init;
params.request_cleanup = mca_pml_ucx_request_cleanup;
params.tag_sender_mask = PML_UCX_SPECIFIC_SOURCE_MASK;
status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
ucp_config_release(config);
@ -140,6 +148,17 @@ int mca_pml_ucx_open(void)
return OMPI_ERROR;
}
/* Query UCX attributes */
attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
if (UCS_OK != status) {
ucp_cleanup(ompi_pml_ucx.ucp_context);
ompi_pml_ucx.ucp_context = NULL;
return OMPI_ERROR;
}
ompi_pml_ucx.request_size = attr.request_size;
return OMPI_SUCCESS;
}
@ -252,6 +271,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
{
ucp_address_t *address;
ucs_status_t status;
ompi_proc_t *proc;
size_t addrlen;
ucp_ep_h ep;
size_t i;
@ -264,47 +284,109 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
}
for (i = 0; i < nprocs; ++i) {
ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
if (ret < 0) {
PML_UCX_ERROR("Failed to receive worker address from proc: %d", procs[i]->super.proc_name.vpid);
PML_UCX_ERROR("Failed to receive worker address from proc: %d",
proc->super.proc_name.vpid);
return ret;
}
if (procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
PML_UCX_VERBOSE(3, "already connected to proc. %d", procs[i]->super.proc_name.vpid);
if (proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
PML_UCX_VERBOSE(3, "already connected to proc. %d", proc->super.proc_name.vpid);
continue;
}
PML_UCX_VERBOSE(2, "connecting to proc. %d", procs[i]->super.proc_name.vpid);
PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);
status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
free(address);
if (UCS_OK != status) {
PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid,
PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->super.proc_name.vpid,
ucs_status_string(status));
return OMPI_ERROR;
}
procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
}
return OMPI_SUCCESS;
}
/**
 * Wait for a batch of outstanding UCX disconnect requests to finish.
 *
 * Each request is polled with ucp_request_test() while opal_progress() is
 * driven, until its status is no longer UCS_INPROGRESS. A request that ends
 * with a status other than UCS_OK is logged and otherwise ignored. Every
 * request is then released and its slot in the array is cleared.
 *
 * @param reqs     Array of request handles; entries [0, *count_p) are
 *                 released and set to NULL.
 * @param count_p  On entry, the number of valid entries in reqs.
 *                 On return, reset to 0.
 */
static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
{
    ucs_status_t status;
    size_t i;

    /* *count_p is a size_t, so it has to be formatted with %zu, not %d */
    PML_UCX_VERBOSE(2, "waiting for %zu disconnect requests", *count_p);

    for (i = 0; i < *count_p; ++i) {
        /* Progress first, then test, so that at least one progress pass is
         * made for every request */
        do {
            opal_progress();
            status = ucp_request_test(reqs[i], NULL);
        } while (status == UCS_INPROGRESS);

        if (status != UCS_OK) {
            PML_UCX_ERROR("disconnect request failed: %s",
                          ucs_status_string(status));
        }

        ucp_request_release(reqs[i]);
        reqs[i] = NULL;
    }

    *count_p = 0;
}
int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
{
ompi_proc_t *proc;
size_t num_reqs, max_reqs;
void *dreq, **dreqs;
ucp_ep_h ep;
size_t i;
max_reqs = ompi_pml_ucx.num_disconnect;
if (max_reqs > nprocs) {
max_reqs = nprocs;
}
dreqs = malloc(sizeof(*dreqs) * max_reqs);
if (dreqs == NULL) {
return OMPI_ERR_OUT_OF_RESOURCE;
}
num_reqs = 0;
for (i = 0; i < nprocs; ++i) {
PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->super.proc_name.vpid);
ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
if (ep != NULL) {
ucp_ep_destroy(ep);
proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
if (ep == NULL) {
continue;
}
procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->super.proc_name.vpid);
dreq = ucp_disconnect_nb(ep);
if (dreq != NULL) {
if (UCS_PTR_IS_ERR(dreq)) {
PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
proc->super.proc_name.vpid,
ucs_status_string(UCS_PTR_STATUS(dreq)));
} else {
dreqs[num_reqs++] = dreq;
}
}
proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
if (num_reqs >= ompi_pml_ucx.num_disconnect) {
mca_pml_ucx_waitall(dreqs, &num_reqs);
}
}
mca_pml_ucx_waitall(dreqs, &num_reqs);
free(dreqs);
opal_pmix.fence(NULL, 0);
return OMPI_SUCCESS;
}
@ -321,14 +403,7 @@ int mca_pml_ucx_enable(bool enable)
/*
 * Drive progress on the PML's UCX worker.
 *
 * A function-local static counter guards against nested invocation: when the
 * function is entered while the counter is non-zero, it returns 0 right away
 * without touching the worker. Otherwise the counter is raised around the
 * call to ucp_worker_progress() and OMPI_SUCCESS is returned.
 */
int mca_pml_ucx_progress(void)
{
    static int inprogress = 0;

    if (0 == inprogress) {
        ++inprogress;
        ucp_worker_progress(ompi_pml_ucx.ucp_worker);
        --inprogress;
        return OMPI_SUCCESS;
    }

    /* Already inside a progress call */
    return 0;
}
@ -393,52 +468,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
return OMPI_SUCCESS;
}
/*
 * Completion callback for a blocking receive request.
 *
 * Logs the completion, stores the receive outcome in the request's status
 * field and then marks the request complete.
 *
 * request - the completed request; treated as an ompi_request_t
 * status  - completion status, forwarded to mca_pml_ucx_set_recv_status()
 * info    - receive info; its sender_tag and length are logged and the
 *           structure is forwarded to mca_pml_ucx_set_recv_status()
 */
static void
mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
ucp_tag_recv_info_t *info)
{
ompi_request_t *req = request;
PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu",
(void*)req, ucs_status_string(status), info->sender_tag,
info->length);
/* Fill in the status before the request is flagged as complete */
mca_pml_ucx_set_recv_status(&req->req_status, status, info);
/* The request is expected not to be complete yet at this point */
PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
ompi_request_complete(req,true);
}
int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
int tag, struct ompi_communicator_t* comm,
ompi_status_public_t* mpi_status)
{
ucp_tag_t ucp_tag, ucp_tag_mask;
ompi_request_t *req;
ucp_tag_recv_info_t info;
ucs_status_t status;
void *req;
PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
req = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
mca_pml_ucx_get_datatype(datatype),
ucp_tag, ucp_tag_mask,
mca_pml_ucx_blocking_recv_completion);
if (UCS_PTR_IS_ERR(req)) {
PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
return OMPI_ERROR;
}
ucp_tag, ucp_tag_mask, req);
ucp_worker_progress(ompi_pml_ucx.ucp_worker);
while ( !REQUEST_COMPLETE(req) ) {
for (;;) {
status = ucp_request_test(req, &info);
if (status != UCS_INPROGRESS) {
mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
return OMPI_SUCCESS;
}
opal_progress();
}
if (mpi_status != MPI_STATUS_IGNORE) {
*mpi_status = req->req_status;
}
req->req_complete = REQUEST_PENDING;
ucp_request_release(req);
return OMPI_SUCCESS;
}
static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
@ -583,6 +638,7 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
*matched = 1;
mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
} else {
opal_progress();
*matched = 0;
}
return OMPI_SUCCESS;
@ -628,7 +684,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
*matched = 1;
mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
} else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE) {
} else {
opal_progress();
*matched = 0;
}
return OMPI_SUCCESS;

Просмотреть файл

@ -40,8 +40,10 @@ struct mca_pml_ucx_module {
/* Requests */
mca_pml_ucx_freelist_t persistent_reqs;
ompi_request_t completed_send_req;
size_t request_size;
int num_disconnect;
/* Convertors pool */
/* Converters pool */
mca_pml_ucx_freelist_t convs;
int priority;

Просмотреть файл

@ -63,6 +63,13 @@ static int mca_pml_ucx_component_register(void)
MCA_BASE_VAR_SCOPE_LOCAL,
&ompi_pml_ucx.priority);
ompi_pml_ucx.num_disconnect = 1;
(void) mca_base_component_var_register(&mca_pml_ucx_component.pmlm_version, "num_disconnect",
"How may disconnects go in parallel",
MCA_BASE_VAR_TYPE_INT, NULL, 0, 0,
OPAL_INFO_LVL_3,
MCA_BASE_VAR_SCOPE_LOCAL,
&ompi_pml_ucx.num_disconnect);
return 0;
}

Просмотреть файл

@ -34,6 +34,9 @@ enum {
#define PML_UCX_TAG_BITS 24
#define PML_UCX_RANK_BITS 24
#define PML_UCX_CONTEXT_BITS 16
#define PML_UCX_ANY_SOURCE_MASK 0x800000000000fffful
#define PML_UCX_SPECIFIC_SOURCE_MASK 0x800000fffffffffful
#define PML_UCX_TAG_MASK 0x7fffff0000000000ul
#define PML_UCX_MAKE_SEND_TAG(_tag, _comm) \
@ -45,16 +48,16 @@ enum {
/*
 * Build the UCX tag and tag mask used to match an incoming message.
 *
 * _ucp_tag      - (out) lvalue receiving the tag to match
 * _ucp_tag_mask - (out) lvalue receiving the mask of tag bits to compare
 * _tag          - MPI tag, or MPI_ANY_TAG to leave the tag bits out of the match
 * _src          - source rank, or MPI_ANY_SOURCE to leave the rank bits out
 * _comm         - communicator pointer; its c_contextid goes into the low bits
 *
 * The masks come from the named PML_UCX_*_MASK constants only; no literal
 * mask values are repeated here.
 */
#define PML_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _src, _comm) \
    { \
        if ((_src) == MPI_ANY_SOURCE) { \
            _ucp_tag_mask = PML_UCX_ANY_SOURCE_MASK; \
        } else { \
            _ucp_tag_mask = PML_UCX_SPECIFIC_SOURCE_MASK; \
        } \
        \
        _ucp_tag = (((uint64_t)(_src) & UCS_MASK(PML_UCX_RANK_BITS)) << PML_UCX_CONTEXT_BITS) | \
                   (_comm)->c_contextid; \
        \
        if ((_tag) != MPI_ANY_TAG) { \
            _ucp_tag_mask |= PML_UCX_TAG_MASK; \
            _ucp_tag |= ((uint64_t)(_tag)) << (PML_UCX_RANK_BITS + PML_UCX_CONTEXT_BITS); \
        } \
    }

Просмотреть файл

@ -65,7 +65,13 @@ mca_spml_ucx_t mca_spml_ucx = {
mca_spml_ucx_rmkey_unpack,
mca_spml_ucx_rmkey_free,
(void*)&mca_spml_ucx
}
},
NULL, /* ucp_context */
NULL, /* ucp_worker */
NULL, /* ucp_peers */
0, /* using_mem_hooks */
1 /* num_disconnect */
};
int mca_spml_ucx_enable(bool enable)
@ -80,10 +86,37 @@ int mca_spml_ucx_enable(bool enable)
return OSHMEM_SUCCESS;
}
/**
 * Wait for a batch of outstanding UCX disconnect requests to finish.
 *
 * Each request is polled with ucp_request_test() while opal_progress() is
 * driven, until its status is no longer UCS_INPROGRESS. A request that ends
 * with a status other than UCS_OK is logged and otherwise ignored. Every
 * request is then released and its slot in the array is cleared.
 *
 * @param reqs     Array of request handles; entries [0, *count_p) are
 *                 released and set to NULL.
 * @param count_p  On entry, the number of valid entries in reqs.
 *                 On return, reset to 0.
 */
static void mca_spml_ucx_waitall(void **reqs, size_t *count_p)
{
    ucs_status_t status;
    size_t i;

    /* *count_p is a size_t, so it has to be formatted with %zu, not %d */
    SPML_VERBOSE(10, "waiting for %zu disconnect requests", *count_p);

    for (i = 0; i < *count_p; ++i) {
        /* Progress first, then test, so that at least one progress pass is
         * made for every request */
        do {
            opal_progress();
            status = ucp_request_test(reqs[i], NULL);
        } while (status == UCS_INPROGRESS);

        if (status != UCS_OK) {
            SPML_ERROR("disconnect request failed: %s",
                       ucs_status_string(status));
        }

        ucp_request_release(reqs[i]);
        reqs[i] = NULL;
    }

    *count_p = 0;
}
int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs)
{
size_t i, n;
int my_rank = oshmem_my_proc_id();
size_t num_reqs, max_reqs;
void *dreq, **dreqs;
ompi_proc_t *proc;
ucp_ep_h ep;
size_t i, n;
oshmem_shmem_barrier();
@ -91,13 +124,47 @@ int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs)
return OSHMEM_SUCCESS;
}
for (n = 0; n < nprocs; n++) {
i = (my_rank + n) % nprocs;
if (mca_spml_ucx.ucp_peers[i].ucp_conn) {
ucp_ep_destroy(mca_spml_ucx.ucp_peers[i].ucp_conn);
max_reqs = mca_spml_ucx.num_disconnect;
if (max_reqs > nprocs) {
max_reqs = nprocs;
}
dreqs = malloc(sizeof(*dreqs) * max_reqs);
if (dreqs == NULL) {
return OMPI_ERR_OUT_OF_RESOURCE;
}
num_reqs = 0;
for (i = 0; i < nprocs; ++i) {
n = (i + my_rank) % nprocs;
ep = mca_spml_ucx.ucp_peers[n].ucp_conn;
if (ep == NULL) {
continue;
}
SPML_VERBOSE(10, "disconnecting from peer %d", n);
dreq = ucp_disconnect_nb(ep);
if (dreq != NULL) {
if (UCS_PTR_IS_ERR(dreq)) {
SPML_ERROR("ucp_disconnect_nb(%d) failed: %s", n,
ucs_status_string(UCS_PTR_STATUS(dreq)));
} else {
dreqs[num_reqs++] = dreq;
}
}
mca_spml_ucx.ucp_peers[n].ucp_conn = NULL;
if (num_reqs >= mca_spml_ucx.num_disconnect) {
mca_spml_ucx_waitall(dreqs, &num_reqs);
}
}
mca_spml_ucx_waitall(dreqs, &num_reqs);
free(dreqs);
opal_pmix.fence(NULL, 0);
free(mca_spml_ucx.ucp_peers);
return OSHMEM_SUCCESS;
}

Просмотреть файл

@ -50,6 +50,7 @@ struct mca_spml_ucx {
ucp_context_h ucp_context;
ucp_worker_h ucp_worker;
ucp_peer_t *ucp_peers;
int num_disconnect;
int priority; /* component priority */
bool enabled;

Просмотреть файл

@ -97,6 +97,10 @@ static int mca_spml_ucx_component_register(void)
"[integer] ucx priority",
&mca_spml_ucx.priority);
mca_spml_ucx_param_register_int("num_disconnect", 1,
"How may disconnects go in parallel",
&mca_spml_ucx.num_disconnect);
return OSHMEM_SUCCESS;
}
@ -118,6 +122,7 @@ static int mca_spml_ucx_component_open(void)
}
memset(&params, 0, sizeof(params));
params.field_mask = UCP_PARAM_FIELD_FEATURES;
params.features = UCP_FEATURE_RMA|UCP_FEATURE_AMO32|UCP_FEATURE_AMO64;
err = ucp_init(&params, ucp_config, &mca_spml_ucx.ucp_context);
@ -131,7 +136,10 @@ static int mca_spml_ucx_component_open(void)
/*
 * Close the SPML UCX component.
 *
 * When a UCX context is present it is released with ucp_cleanup() and the
 * stored handle is reset to NULL. OSHMEM_SUCCESS is returned in every case.
 */
static int mca_spml_ucx_component_close(void)
{
    /* No context stored: nothing to clean up */
    if (!mca_spml_ucx.ucp_context) {
        return OSHMEM_SUCCESS;
    }

    ucp_cleanup(mca_spml_ucx.ucp_context);
    mca_spml_ucx.ucp_context = NULL;

    return OSHMEM_SUCCESS;
}