493 строки
12 KiB
C
493 строки
12 KiB
C
/*
|
|
* Copyright (c) 2013 Mellanox Technologies, Inc.
|
|
* All rights reserved.
|
|
* Copyright (c) 2014 Research Organization for Information Science
|
|
* and Technology (RIST). All rights reserved.
|
|
* $COPYRIGHT$
|
|
*
|
|
* Additional copyrights may follow
|
|
*
|
|
* $HEADER$
|
|
*/
|
|
|
|
#define _GNU_SOURCE
|
|
#include <stdio.h>
|
|
|
|
#include <sys/types.h>
|
|
#include <unistd.h>
|
|
#include <stdint.h>
|
|
|
|
#include "oshmem_config.h"
|
|
#include "opal/datatype/opal_convertor.h"
|
|
#include "orte/include/orte/types.h"
|
|
#include "orte/runtime/orte_globals.h"
|
|
#include "ompi/datatype/ompi_datatype.h"
|
|
#include "ompi/mca/pml/pml.h"
|
|
|
|
|
|
#include "oshmem/mca/spml/ucx/spml_ucx.h"
|
|
#include "oshmem/include/shmem.h"
|
|
#include "oshmem/mca/memheap/memheap.h"
|
|
#include "oshmem/mca/memheap/base/base.h"
|
|
#include "oshmem/proc/proc.h"
|
|
#include "oshmem/mca/spml/base/base.h"
|
|
#include "oshmem/mca/spml/base/spml_base_putreq.h"
|
|
#include "oshmem/runtime/runtime.h"
|
|
#include "orte/util/show_help.h"
|
|
|
|
#include "oshmem/mca/spml/ucx/spml_ucx_component.h"
|
|
|
|
/*
 * Build-time switch for debug output. It is off (0) unless the build
 * supplies its own value, e.g. -DSPML_UCX_PUT_DEBUG=1.
 */
#if !defined(SPML_UCX_PUT_DEBUG)
#define SPML_UCX_PUT_DEBUG 0
#endif
|
|
|
|
|
|
mca_spml_ucx_t mca_spml_ucx = {
|
|
{
|
|
/* Init mca_spml_base_module_t */
|
|
mca_spml_ucx_add_procs,
|
|
mca_spml_ucx_del_procs,
|
|
mca_spml_ucx_enable,
|
|
mca_spml_ucx_register,
|
|
mca_spml_ucx_deregister,
|
|
mca_spml_base_oob_get_mkeys,
|
|
mca_spml_ucx_put,
|
|
mca_spml_ucx_put_nb,
|
|
mca_spml_ucx_get,
|
|
mca_spml_ucx_get_nb,
|
|
mca_spml_ucx_recv,
|
|
mca_spml_ucx_send,
|
|
mca_spml_base_wait,
|
|
mca_spml_base_wait_nb,
|
|
mca_spml_ucx_quiet, /* At the moment fence is the same as quite for
|
|
every spml */
|
|
mca_spml_ucx_rmkey_unpack,
|
|
mca_spml_ucx_rmkey_free,
|
|
(void*)&mca_spml_ucx
|
|
}
|
|
};
|
|
|
|
int mca_spml_ucx_enable(bool enable)
|
|
{
|
|
SPML_VERBOSE(50, "*** ucx ENABLED ****");
|
|
if (false == enable) {
|
|
return OSHMEM_SUCCESS;
|
|
}
|
|
|
|
mca_spml_ucx.enabled = true;
|
|
|
|
return OSHMEM_SUCCESS;
|
|
}
|
|
|
|
/*
 * Destroy the UCP endpoints created by mca_spml_ucx_add_procs() and release
 * the peer table.
 *
 * All PEs first meet in a barrier. Presumably this ensures no peer still
 * targets our endpoints; confirm. The endpoints are then destroyed, starting
 * at this PE's own rank and wrapping around.
 *
 * After the call mca_spml_ucx.ucp_peers is NULL, so calling this again is
 * harmless.
 *
 * @param procs   unused here
 * @param nprocs  number of entries in mca_spml_ucx.ucp_peers
 * @return OSHMEM_SUCCESS
 */
int mca_spml_ucx_del_procs(oshmem_proc_t** procs, size_t nprocs)
{
    size_t i, n;
    int my_rank = oshmem_my_proc_id();

    oshmem_shmem_barrier();

    if (!mca_spml_ucx.ucp_peers) {
        return OSHMEM_SUCCESS;
    }

    for (n = 0; n < nprocs; n++) {
        i = (my_rank + n) % nprocs;
        if (mca_spml_ucx.ucp_peers[i].ucp_conn) {
            ucp_ep_destroy(mca_spml_ucx.ucp_peers[i].ucp_conn);
        }
    }

    free(mca_spml_ucx.ucp_peers);
    /* Clear the global so that a repeated call (or any later user of the
     * peer table) does not read or double-free the released array. */
    mca_spml_ucx.ucp_peers = NULL;
    return OSHMEM_SUCCESS;
}
|
|
|
|
/* TODO: move func into common place, use it with rkey exchng too */
|
|
static int oshmem_shmem_xchng(
|
|
void *local_data, int local_size, int nprocs,
|
|
void **rdata_p, int **roffsets_p, int **rsizes_p)
|
|
{
|
|
int *rcv_sizes = NULL;
|
|
int *rcv_offsets = NULL;
|
|
void *rcv_buf = NULL;
|
|
int rc;
|
|
int i;
|
|
|
|
/* do llgatherv */
|
|
rcv_offsets = malloc(nprocs * sizeof(*rcv_offsets));
|
|
if (NULL == rcv_offsets) {
|
|
goto err;
|
|
}
|
|
|
|
/* todo: move into separate function. do allgatherv */
|
|
rcv_sizes = malloc(nprocs * sizeof(*rcv_sizes));
|
|
if (NULL == rcv_sizes) {
|
|
goto err;
|
|
}
|
|
|
|
rc = oshmem_shmem_allgather(&local_size, rcv_sizes, sizeof(int));
|
|
if (MPI_SUCCESS != rc) {
|
|
goto err;
|
|
}
|
|
|
|
/* calculate displacements */
|
|
rcv_offsets[0] = 0;
|
|
for (i = 1; i < nprocs; i++) {
|
|
rcv_offsets[i] = rcv_offsets[i - 1] + rcv_sizes[i - 1];
|
|
}
|
|
|
|
rcv_buf = malloc(rcv_offsets[nprocs - 1] + rcv_sizes[nprocs - 1]);
|
|
if (NULL == rcv_buf) {
|
|
goto err;
|
|
}
|
|
|
|
rc = oshmem_shmem_allgatherv(local_data, rcv_buf, local_size, rcv_sizes, rcv_offsets);
|
|
if (MPI_SUCCESS != rc) {
|
|
goto err;
|
|
}
|
|
|
|
*rdata_p = rcv_buf;
|
|
*roffsets_p = rcv_offsets;
|
|
*rsizes_p = rcv_sizes;
|
|
return OSHMEM_SUCCESS;
|
|
|
|
err:
|
|
if (rcv_buf)
|
|
free(rcv_buf);
|
|
if (rcv_offsets)
|
|
free(rcv_offsets);
|
|
if (rcv_sizes)
|
|
free(rcv_sizes);
|
|
return OSHMEM_ERROR;
|
|
}
|
|
|
|
/*
 * Debug helper: print a UCX worker address belonging to PE `pe` as hex bytes.
 *
 * It is a no-op unless SPML_UCX_DEBUG is defined at build time.
 * NOTE(review): that is a different macro from SPML_UCX_PUT_DEBUG, which is
 * defined near the top of this file. Confirm which switch is intended.
 *
 * @param pe    PE the address belongs to (printed only)
 * @param addr  address bytes
 * @param len   number of bytes in addr
 */
static void dump_address(int pe, const char *addr, size_t len)
{
#ifdef SPML_UCX_DEBUG
    int my_rank = oshmem_my_proc_id();
    size_t i;

    /* len is a size_t, so it needs %zu; passing it for %d is undefined
     * behavior. %p requires a void pointer. */
    printf("me=%d dest_pe=%d addr=%p len=%zu\n",
           my_rank, pe, (const void *)addr, len);
    for (i = 0; i < len; i++) {
        printf("%02X ", (unsigned)(unsigned char)addr[i]);
    }
    printf("\n");
#else
    /* Silence unused-parameter warnings in non-debug builds. */
    (void)pe;
    (void)addr;
    (void)len;
#endif
}
|
|
|
|
/* Transport-id table attached to every proc by add_procs. This SPML
 * exposes exactly one transport, whose id is 0. */
static char spml_ucx_transport_ids[] = { 0 };
|
|
|
|
/*
 * Connect this PE to every PE in procs[0..nprocs-1] over UCX.
 *
 * Steps:
 *   1. Exchange the local worker address with all PEs. This is collective,
 *      via oshmem_shmem_xchng().
 *   2. Create one UCP endpoint per peer in mca_spml_ucx.ucp_peers[].
 *   3. Tag each proc with this SPML's single transport id.
 *
 * @return OSHMEM_SUCCESS, or OSHMEM_ERR_OUT_OF_RESOURCE on any failure.
 *         On failure every resource acquired here is released exactly once
 *         and mca_spml_ucx.ucp_peers is reset to NULL.
 */
int mca_spml_ucx_add_procs(oshmem_proc_t** procs, size_t nprocs)
{
    size_t i, n;
    int rc = OSHMEM_ERROR;
    int my_rank = oshmem_my_proc_id();
    ucs_status_t err;
    ucp_address_t *wk_local_addr = NULL;
    size_t wk_addr_len = 0;
    int *wk_roffs = NULL;
    int *wk_rsizes = NULL;
    char *wk_raddrs = NULL;

    mca_spml_ucx.ucp_peers = calloc(nprocs, sizeof(*(mca_spml_ucx.ucp_peers)));
    if (NULL == mca_spml_ucx.ucp_peers) {
        goto error;
    }

    err = ucp_worker_get_address(mca_spml_ucx.ucp_worker, &wk_local_addr, &wk_addr_len);
    if (err != UCS_OK) {
        goto error_free_peers;
    }
    dump_address(my_rank, (char *)wk_local_addr, wk_addr_len);

    /* Collective: every PE receives every other PE's worker address. */
    rc = oshmem_shmem_xchng(wk_local_addr, wk_addr_len, nprocs,
                            (void **)&wk_raddrs, &wk_roffs, &wk_rsizes);
    if (rc != OSHMEM_SUCCESS) {
        goto error_release_addr;
    }

    /* NOTE(review): this callback stays registered if endpoint creation
     * fails below. */
    opal_progress_register(spml_ucx_progress);

    /* Get the EP connection requests for all the processes from modex.
     * The loop starts at our own rank and wraps around; presumably this is
     * so that PEs do not all connect to PE 0 first (confirm). */
    for (n = 0; n < nprocs; ++n) {
        i = (my_rank + n) % nprocs;
        dump_address(i, (char *)(wk_raddrs + wk_roffs[i]), wk_rsizes[i]);
        err = ucp_ep_create(mca_spml_ucx.ucp_worker,
                            (ucp_address_t *)(wk_raddrs + wk_roffs[i]),
                            &mca_spml_ucx.ucp_peers[i].ucp_conn);
        if (UCS_OK != err) {
            SPML_ERROR("ucp_ep_create failed!!!\n");
            /* never hand a failed create's output to ucp_ep_destroy */
            mca_spml_ucx.ucp_peers[i].ucp_conn = NULL;
            goto error_destroy_eps;
        }
        procs[i]->num_transports = 1;
        procs[i]->transport_ids = spml_ucx_transport_ids;
    }

    ucp_worker_release_address(mca_spml_ucx.ucp_worker, wk_local_addr);
    free(wk_raddrs);
    free(wk_rsizes);
    free(wk_roffs);

    SPML_VERBOSE(50, "*** ADDED PROCS ***");
    return OSHMEM_SUCCESS;

    /* Unwind in reverse order of acquisition. Each label releases only what
     * was acquired before the failing step, so nothing is freed twice. */
error_destroy_eps:
    for (i = 0; i < nprocs; ++i) {
        if (mca_spml_ucx.ucp_peers[i].ucp_conn) {
            ucp_ep_destroy(mca_spml_ucx.ucp_peers[i].ucp_conn);
        }
    }
    free(wk_raddrs);
    free(wk_rsizes);
    free(wk_roffs);
error_release_addr:
    ucp_worker_release_address(mca_spml_ucx.ucp_worker, wk_local_addr);
error_free_peers:
    free(mca_spml_ucx.ucp_peers);
    mca_spml_ucx.ucp_peers = NULL;
error:
    rc = OSHMEM_ERR_OUT_OF_RESOURCE;
    SPML_ERROR("add procs FAILED rc=%d", rc);
    return rc;
}
|
|
|
|
/*
 * Release the UCX rkey wrapper attached to a remote mkey by
 * mca_spml_ucx_rmkey_unpack().
 *
 * Safe to call on a NULL mkey, on an mkey without a context, or twice on
 * the same mkey: spml_context is cleared after the wrapper is freed.
 */
void mca_spml_ucx_rmkey_free(sshmem_mkey_t *mkey)
{
    spml_ucx_mkey_t *ucx_mkey;

    if (NULL == mkey || NULL == mkey->spml_context) {
        return;
    }

    ucx_mkey = (spml_ucx_mkey_t *)(mkey->spml_context);
    ucp_rkey_destroy(ucx_mkey->rkey);
    free(ucx_mkey);
    /* Prevent a dangling context and a double free on a repeated call. */
    mkey->spml_context = NULL;
}
|
|
|
|
/*
 * Turn the packed remote key received from PE `pe` (mkey->u.data) into a
 * UCX rkey bound to that PE's endpoint. The result is stored in
 * mkey->spml_context, to be released by mca_spml_ucx_rmkey_free().
 *
 * Any failure is fatal: the job is aborted with oshmem_shmem_abort(-1).
 */
void mca_spml_ucx_rmkey_unpack(sshmem_mkey_t *mkey, int pe)
{
    spml_ucx_mkey_t *ucx_mkey;
    ucs_status_t err;

    ucx_mkey = malloc(sizeof(*ucx_mkey));
    if (!ucx_mkey) {
        SPML_ERROR("not enough memory to allocate mkey");
        goto error_fatal;
    }

    err = ucp_ep_rkey_unpack(mca_spml_ucx.ucp_peers[pe].ucp_conn,
                             mkey->u.data,
                             &ucx_mkey->rkey);
    if (UCS_OK != err) {
        SPML_ERROR("failed to unpack rkey");
        goto error_fatal;
    }

    mkey->spml_context = ucx_mkey;
    return;

error_fatal:
    /* Do not leak the wrapper if abort ever returns. free(NULL) covers the
     * allocation-failure path. */
    free(ucx_mkey);
    oshmem_shmem_abort(-1);
}
|
|
|
|
sshmem_mkey_t *mca_spml_ucx_register(void* addr,
|
|
size_t size,
|
|
uint64_t shmid,
|
|
int *count)
|
|
{
|
|
sshmem_mkey_t *mkeys;
|
|
ucs_status_t err;
|
|
spml_ucx_mkey_t *ucx_mkey;
|
|
size_t len;
|
|
|
|
*count = 0;
|
|
mkeys = (sshmem_mkey_t *) calloc(1, sizeof(*mkeys));
|
|
if (!mkeys) {
|
|
return NULL ;
|
|
}
|
|
|
|
ucx_mkey = (spml_ucx_mkey_t *)malloc(sizeof(*ucx_mkey));
|
|
if (!ucx_mkey) {
|
|
goto error_out;
|
|
}
|
|
|
|
mkeys[0].spml_context = ucx_mkey;
|
|
err = ucp_mem_map(mca_spml_ucx.ucp_context,
|
|
&addr, size, 0, &ucx_mkey->mem_h);
|
|
if (UCS_OK != err) {
|
|
goto error_out1;
|
|
}
|
|
|
|
err = ucp_rkey_pack(mca_spml_ucx.ucp_context, ucx_mkey->mem_h,
|
|
&mkeys[0].u.data, &len);
|
|
if (UCS_OK != err) {
|
|
goto error_unmap;
|
|
}
|
|
if (len >= 0xffff) {
|
|
SPML_ERROR("packed rkey is too long: %llu >= %d",
|
|
(unsigned long long)len,
|
|
0xffff);
|
|
oshmem_shmem_abort(-1);
|
|
}
|
|
|
|
err = ucp_ep_rkey_unpack(mca_spml_ucx.ucp_peers[oshmem_group_self->my_pe].ucp_conn,
|
|
mkeys[0].u.data,
|
|
&ucx_mkey->rkey);
|
|
if (UCS_OK != err) {
|
|
SPML_ERROR("failed to unpack rkey");
|
|
goto error_unmap;
|
|
}
|
|
|
|
mkeys[0].len = len;
|
|
mkeys[0].va_base = addr;
|
|
*count = 1;
|
|
return mkeys;
|
|
|
|
error_unmap:
|
|
ucp_mem_unmap(mca_spml_ucx.ucp_context, ucx_mkey->mem_h);
|
|
error_out1:
|
|
free(ucx_mkey);
|
|
error_out:
|
|
free(mkeys);
|
|
|
|
return NULL ;
|
|
}
|
|
|
|
/*
 * Undo mca_spml_ucx_register() for a local mkey array.
 *
 * Outstanding operations are completed first (fence). The self rkey
 * unpacked at registration is then destroyed, the memory is unmapped, and
 * the packed rkey buffer and the UCX wrapper are released.
 *
 * The mkeys array itself is not freed here. Its spml_context is cleared,
 * so a second call is a no-op.
 *
 * @return OSHMEM_SUCCESS
 */
int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
{
    spml_ucx_mkey_t *ucx_mkey;

    MCA_SPML_CALL(fence());
    if (!mkeys) {
        return OSHMEM_SUCCESS;
    }

    if (!mkeys[0].spml_context) {
        return OSHMEM_SUCCESS;
    }

    ucx_mkey = (spml_ucx_mkey_t *)mkeys[0].spml_context;

    /* register() unpacked an rkey for this PE. Without this call that rkey
     * leaks on every deregistration. */
    ucp_rkey_destroy(ucx_mkey->rkey);
    ucp_mem_unmap(mca_spml_ucx.ucp_context, ucx_mkey->mem_h);

    if (0 < mkeys[0].len) {
        ucp_rkey_buffer_release(mkeys[0].u.data);
    }

    free(ucx_mkey);
    mkeys[0].spml_context = NULL;
    return OSHMEM_SUCCESS;
}
|
|
|
|
int mca_spml_ucx_get(void *src_addr, size_t size, void *dst_addr, int src)
|
|
{
|
|
void *rva;
|
|
ucs_status_t status;
|
|
spml_ucx_mkey_t *ucx_mkey;
|
|
|
|
ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva);
|
|
status = ucp_get(mca_spml_ucx.ucp_peers[src].ucp_conn, dst_addr, size,
|
|
(uint64_t)rva, ucx_mkey->rkey);
|
|
|
|
return ucx_status_to_oshmem(status);
|
|
}
|
|
|
|
int mca_spml_ucx_get_nb(void *src_addr, size_t size, void *dst_addr, int src, void **handle)
|
|
{
|
|
void *rva;
|
|
ucs_status_t status;
|
|
spml_ucx_mkey_t *ucx_mkey;
|
|
|
|
ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva);
|
|
status = ucp_get_nbi(mca_spml_ucx.ucp_peers[src].ucp_conn, dst_addr, size,
|
|
(uint64_t)rva, ucx_mkey->rkey);
|
|
|
|
return ucx_status_to_oshmem(status);
|
|
}
|
|
|
|
int mca_spml_ucx_put(void* dst_addr, size_t size, void* src_addr, int dst)
|
|
{
|
|
void *rva;
|
|
ucs_status_t status;
|
|
spml_ucx_mkey_t *ucx_mkey;
|
|
|
|
ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva);
|
|
status = ucp_put(mca_spml_ucx.ucp_peers[dst].ucp_conn, src_addr, size,
|
|
(uint64_t)rva, ucx_mkey->rkey);
|
|
|
|
return ucx_status_to_oshmem(status);
|
|
}
|
|
|
|
int mca_spml_ucx_put_nb(void* dst_addr, size_t size, void* src_addr, int dst, void **handle)
|
|
{
|
|
void *rva;
|
|
ucs_status_t status;
|
|
spml_ucx_mkey_t *ucx_mkey;
|
|
|
|
ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva);
|
|
status = ucp_put_nbi(mca_spml_ucx.ucp_peers[dst].ucp_conn, src_addr, size,
|
|
(uint64_t)rva, ucx_mkey->rkey);
|
|
|
|
return ucx_status_to_oshmem(status);
|
|
}
|
|
|
|
int mca_spml_ucx_fence(void)
|
|
{
|
|
ucs_status_t err;
|
|
|
|
err = ucp_worker_flush(mca_spml_ucx.ucp_worker);
|
|
if (UCS_OK != err) {
|
|
SPML_ERROR("fence failed");
|
|
oshmem_shmem_abort(-1);
|
|
return OSHMEM_ERROR;
|
|
}
|
|
return OSHMEM_SUCCESS;
|
|
}
|
|
|
|
int mca_spml_ucx_quiet(void)
|
|
{
|
|
ucs_status_t err;
|
|
|
|
err = ucp_worker_flush(mca_spml_ucx.ucp_worker);
|
|
if (UCS_OK != err) {
|
|
SPML_ERROR("fence failed");
|
|
oshmem_shmem_abort(-1);
|
|
return OSHMEM_ERROR;
|
|
}
|
|
return OSHMEM_SUCCESS;
|
|
}
|
|
|
|
/* blocking receive */
|
|
int mca_spml_ucx_recv(void* buf, size_t size, int src)
|
|
{
|
|
int rc = OSHMEM_SUCCESS;
|
|
|
|
rc = MCA_PML_CALL(recv(buf,
|
|
size,
|
|
&(ompi_mpi_unsigned_char.dt),
|
|
src,
|
|
0,
|
|
&(ompi_mpi_comm_world.comm),
|
|
NULL));
|
|
|
|
return rc;
|
|
}
|
|
|
|
/* for now only do blocking copy send */
|
|
int mca_spml_ucx_send(void* buf,
|
|
size_t size,
|
|
int dst,
|
|
mca_spml_base_put_mode_t mode)
|
|
{
|
|
int rc = OSHMEM_SUCCESS;
|
|
|
|
rc = MCA_PML_CALL(send(buf,
|
|
size,
|
|
&(ompi_mpi_unsigned_char.dt),
|
|
dst,
|
|
0,
|
|
(mca_pml_base_send_mode_t)mode,
|
|
&(ompi_mpi_comm_world.comm)));
|
|
|
|
return rc;
|
|
}
|