diff --git a/oshmem/mca/atomic/ucx/Makefile.am b/oshmem/mca/atomic/ucx/Makefile.am index d922456e72..a350eb02f3 100644 --- a/oshmem/mca/atomic/ucx/Makefile.am +++ b/oshmem/mca/atomic/ucx/Makefile.am @@ -35,7 +35,8 @@ mcacomponentdir = $(ompilibdir) mcacomponent_LTLIBRARIES = $(component_install) mca_atomic_ucx_la_SOURCES = $(ucx_sources) mca_atomic_ucx_la_LIBADD = $(top_builddir)/oshmem/liboshmem.la \ - $(atomic_ucx_LIBS) + $(atomic_ucx_LIBS) \ + $(OPAL_TOP_BUILDDIR)/opal/mca/common/ucx/lib@OPAL_LIB_PREFIX@mca_common_ucx.la mca_atomic_ucx_la_LDFLAGS = -module -avoid-version $(atomic_ucx_LDFLAGS) noinst_LTLIBRARIES = $(component_noinst) diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c index f6740f3897..1dbef56a2b 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c @@ -37,7 +37,7 @@ int mca_atomic_ucx_cswap_inner(void *target, uint64_t cmp; val = (4 == nlong) ? *(uint32_t*)value : *(uint64_t*)value; - ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva); + ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva, mca_spml_self); if (NULL == cond) { status_ptr = ucp_atomic_fetch_nb(mca_spml_self->ucp_peers[pe].ucp_conn, UCP_ATOMIC_FETCH_OP_SWAP, val, prev, nlong, diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c b/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c index f92c53ed22..37d8be0974 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_fadd.c @@ -40,7 +40,7 @@ int mca_atomic_ucx_fadd(void *target, return OSHMEM_ERROR; } - ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva); + ucx_mkey = mca_spml_ucx_get_mkey(pe, target, (void *)&rva, mca_spml_self); if (NULL == prev) { status = ucp_atomic_post(mca_spml_self->ucp_peers[pe].ucp_conn, UCP_ATOMIC_POST_OP_ADD, val, nlong, rva, diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index b549357b12..c9068fafad 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -44,6 +44,9 @@ #define SPML_UCX_PUT_DEBUG 0 #endif +static +spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(int pe, void *va, void **rva); + mca_spml_ucx_t mca_spml_ucx = { { /* Init mca_spml_base_module_t */ @@ -75,7 +78,9 @@ mca_spml_ucx_t mca_spml_ucx = { NULL, /* ucp_peers */ 0, /* using_mem_hooks */ 1, /* num_disconnect */ - 0 /* heap_reg_nb */ + 0, /* heap_reg_nb */ + 0, /* enabled */ + mca_spml_ucx_get_mkey_slow }; int mca_spml_ucx_enable(bool enable) @@ -330,6 +335,7 @@ error: } +static spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(int pe, void *va, void **rva) { sshmem_mkey_t *r_mkey; @@ -555,7 +561,7 @@ int mca_spml_ucx_get(void *src_addr, size_t size, void *dst_addr, int src) ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; - ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva); + ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva, &mca_spml_ucx); status = ucp_get(mca_spml_ucx.ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey); @@ -568,7 +574,7 @@ int mca_spml_ucx_get_nb(void *src_addr, size_t size, void *dst_addr, int src, vo ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; - ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva); + ucx_mkey = mca_spml_ucx_get_mkey(src, src_addr, &rva, &mca_spml_ucx); status = ucp_get_nbi(mca_spml_ucx.ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey); @@ -581,7 +587,7 @@ int mca_spml_ucx_put(void* dst_addr, size_t size, void* src_addr, int dst) ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; - ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva); + ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva, &mca_spml_ucx); status = ucp_put(mca_spml_ucx.ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey); @@ -594,7 +600,7 @@ int mca_spml_ucx_put_nb(void* dst_addr, size_t size, void* src_addr, int dst, vo ucs_status_t status; spml_ucx_mkey_t *ucx_mkey; - ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva); + ucx_mkey = mca_spml_ucx_get_mkey(dst, dst_addr, &rva, &mca_spml_ucx); status = ucp_put_nbi(mca_spml_ucx.ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey); diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h index b57850414b..4aeed1481f 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.h +++ b/oshmem/mca/spml/ucx/spml_ucx.h @@ -58,6 +58,8 @@ struct ucp_peer { }; typedef struct ucp_peer ucp_peer_t; +typedef spml_ucx_mkey_t * (*mca_spml_ucx_get_mkey_slow_fn_t)(int pe, void *va, void **rva); + struct mca_spml_ucx { mca_spml_base_module_t super; ucp_context_h ucp_context; @@ -68,6 +70,8 @@ struct mca_spml_ucx { int priority; /* component priority */ bool enabled; + + mca_spml_ucx_get_mkey_slow_fn_t get_mkey_slow; }; typedef struct mca_spml_ucx mca_spml_ucx_t; @@ -121,17 +125,16 @@ extern int mca_spml_ucx_quiet(void); extern int spml_ucx_progress(void); -spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(int pe, void *va, void **rva); - static inline spml_ucx_mkey_t * -mca_spml_ucx_get_mkey(int pe, void *va, void **rva) +mca_spml_ucx_get_mkey(int pe, void *va, void **rva, mca_spml_ucx_t* module) { spml_ucx_cached_mkey_t *mkey; - mkey = mca_spml_ucx.ucp_peers[pe].mkeys; + mkey = module->ucp_peers[pe].mkeys; mkey = (spml_ucx_cached_mkey_t *)map_segment_find_va(&mkey->super.super, sizeof(*mkey), va); if (OPAL_UNLIKELY(NULL == mkey)) { - return mca_spml_ucx_get_mkey_slow(pe, va, rva); + assert(module->get_mkey_slow); + return module->get_mkey_slow(pe, va, rva); } *rva = map_segment_va2rva(&mkey->super, va); return &mkey->key;