diff --git a/oshmem/mca/memheap/memheap.h b/oshmem/mca/memheap/memheap.h index 3492812a32..7cad1e9e3f 100644 --- a/oshmem/mca/memheap/memheap.h +++ b/oshmem/mca/memheap/memheap.h @@ -138,13 +138,18 @@ typedef struct mca_memheap_base_module_t mca_memheap_base_module_t; OSHMEM_DECLSPEC extern mca_memheap_base_module_t mca_memheap; +static inline int mca_memheap_base_mkey_is_shm(sshmem_mkey_t *mkey) +{ + return (0 == mkey->len) && (MAP_SEGMENT_SHM_INVALID != (int)mkey->u.key); +} + /** * check if memcpy() can be used to copy data to dst_addr * must be memheap address and segment must be mapped */ static inline int mca_memheap_base_can_local_copy(sshmem_mkey_t *mkey, void *dst_addr) { return mca_memheap.memheap_is_symmetric_addr(dst_addr) && - (0 == mkey->len) && (MAP_SEGMENT_SHM_INVALID != (int)mkey->u.key); + mca_memheap_base_mkey_is_shm(mkey); } diff --git a/oshmem/mca/spml/base/base.h b/oshmem/mca/spml/base/base.h index 4a0eb3e735..58025561ca 100644 --- a/oshmem/mca/spml/base/base.h +++ b/oshmem/mca/spml/base/base.h @@ -73,6 +73,8 @@ OSHMEM_DECLSPEC int mca_spml_base_oob_get_mkeys(int pe, OSHMEM_DECLSPEC void mca_spml_base_rmkey_unpack(sshmem_mkey_t *mkey, uint32_t seg, int pe, int tr_id); OSHMEM_DECLSPEC void mca_spml_base_rmkey_free(sshmem_mkey_t *mkey); +OSHMEM_DECLSPEC void *mca_spml_base_rmkey_ptr(const void *dst_addr, sshmem_mkey_t *mkey, int pe); + OSHMEM_DECLSPEC int mca_spml_base_put_nb(void *dst_addr, size_t size, void *src_addr, diff --git a/oshmem/mca/spml/base/spml_base.c b/oshmem/mca/spml/base/spml_base.c index ce156be4e6..75c0f71bb5 100644 --- a/oshmem/mca/spml/base/spml_base.c +++ b/oshmem/mca/spml/base/spml_base.c @@ -175,6 +175,11 @@ void mca_spml_base_rmkey_free(sshmem_mkey_t *mkey) { } +void *mca_spml_base_rmkey_ptr(const void *dst_addr, sshmem_mkey_t *mkey, int pe) +{ + return NULL; +} + int mca_spml_base_put_nb(void *dst_addr, size_t size, void *src_addr, int dst, void **handle) { diff --git a/oshmem/mca/spml/ikrit/spml_ikrit.c b/oshmem/mca/spml/ikrit/spml_ikrit.c index e9f89ce90b..d90099caf4 100644 --- a/oshmem/mca/spml/ikrit/spml_ikrit.c +++ b/oshmem/mca/spml/ikrit/spml_ikrit.c @@ -171,6 +171,7 @@ mca_spml_ikrit_t mca_spml_ikrit = { mca_spml_ikrit_fence, mca_spml_ikrit_cache_mkeys, mca_spml_base_rmkey_free, + mca_spml_base_rmkey_ptr, mca_spml_base_memuse_hook, (void*)&mca_spml_ikrit diff --git a/oshmem/mca/spml/spml.h b/oshmem/mca/spml/spml.h index 16b372b2dc..d919cc12be 100644 --- a/oshmem/mca/spml/spml.h +++ b/oshmem/mca/spml/spml.h @@ -120,6 +120,17 @@ typedef int (*mca_spml_base_module_wait_fn_t)(void* addr, */ typedef void (*mca_spml_base_module_mkey_unpack_fn_t)(sshmem_mkey_t *, uint32_t segno, int remote_pe, int tr_id); +/** + * If possible, get a pointer to the remote memory described by the mkey + * + * @param dst_addr address of the symmetric variable + * @param mkey remote memory key + * @param pe remote PE + * + * @return pointer to remote memory or NULL + */ +typedef void * (*mca_spml_base_module_mkey_ptr_fn_t)(const void *dst_addr, sshmem_mkey_t *mkey, int pe); + /** * free resources used by deserialized remote mkey * @@ -313,6 +324,7 @@ struct mca_spml_base_module_1_0_0_t { mca_spml_base_module_mkey_unpack_fn_t spml_rmkey_unpack; mca_spml_base_module_mkey_free_fn_t spml_rmkey_free; + mca_spml_base_module_mkey_ptr_fn_t spml_rmkey_ptr; mca_spml_base_module_memuse_hook_fn_t spml_memuse_hook; void *self; diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 1e7ac7f537..c371b56c6f 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -64,6 +64,7 @@ mca_spml_ucx_t mca_spml_ucx = { every spml */ mca_spml_ucx_rmkey_unpack, mca_spml_ucx_rmkey_free, + mca_spml_ucx_rmkey_ptr, mca_spml_ucx_memuse_hook, (void*)&mca_spml_ucx }, @@ -353,6 +354,23 @@ void mca_spml_ucx_rmkey_free(sshmem_mkey_t *mkey) ucp_rkey_destroy(ucx_mkey->rkey); } +void *mca_spml_ucx_rmkey_ptr(const void *dst_addr, sshmem_mkey_t *mkey, int pe) +{ +#if (((UCP_API_MAJOR >= 1) && (UCP_API_MINOR >= 3)) || (UCP_API_MAJOR >= 2)) + void *rva; + ucs_status_t err; + spml_ucx_mkey_t *ucx_mkey = (spml_ucx_mkey_t *)(mkey->spml_context); + + err = ucp_rkey_ptr(ucx_mkey->rkey, (uint64_t)dst_addr, &rva); + if (UCS_OK != err) { + return NULL; + } + return rva; +#else + return NULL; +#endif +} + static void mca_spml_ucx_cache_mkey(sshmem_mkey_t *mkey, uint32_t segno, int dst_pe) { ucp_peer_t *peer; diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h index b524031d3f..b57850414b 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.h +++ b/oshmem/mca/spml/ucx/spml_ucx.h @@ -112,6 +112,7 @@ extern void mca_spml_ucx_memuse_hook(void *addr, size_t length); extern void mca_spml_ucx_rmkey_unpack(sshmem_mkey_t *mkey, uint32_t segno, int pe, int tr_id); extern void mca_spml_ucx_rmkey_free(sshmem_mkey_t *mkey); +extern void *mca_spml_ucx_rmkey_ptr(const void *dst_addr, sshmem_mkey_t *, int pe); extern int mca_spml_ucx_add_procs(ompi_proc_t** procs, size_t nprocs); extern int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs); diff --git a/oshmem/shmem/c/shmem_ptr.c b/oshmem/shmem/c/shmem_ptr.c index 12413b29b9..35a324c221 100644 --- a/oshmem/shmem/c/shmem_ptr.c +++ b/oshmem/shmem/c/shmem_ptr.c @@ -19,6 +19,9 @@ #include "oshmem/shmem/shmem_api_logger.h" #include "oshmem/runtime/runtime.h" +#include "oshmem/mca/memheap/memheap.h" +#include "oshmem/mca/memheap/base/base.h" + #if OSHMEM_PROFILING #include "oshmem/include/pshmem.h" @@ -26,11 +29,43 @@ #include "oshmem/shmem/c/profile/defines.h" #endif -void *shmem_ptr(const void *ptr, int pe) +void *shmem_ptr(const void *dst_addr, int pe) { - SHMEM_API_VERBOSE(10, - "*************** WARNING!!! NOT SUPPORTED FUNCTION **********************\n" - "shmem_ptr() function is available only on systems where ordinary memory loads\n" - "and stores are used to implement OpenSHMEM put and get operations."); - return 0; + ompi_proc_t *proc; + sshmem_mkey_t *mkey; + int i; + void *rva; + + RUNTIME_CHECK_INIT(); + RUNTIME_CHECK_PE(pe); + RUNTIME_CHECK_ADDR(dst_addr); + + /* process can access its own memory */ + if (pe == oshmem_my_proc_id()) { + return (void *)dst_addr; + } + + /* The memory must be on the local node */ + proc = oshmem_proc_group_find(oshmem_group_all, pe); + if (!OPAL_PROC_ON_LOCAL_NODE(proc->super.proc_flags)) { + return NULL; + } + + for (i = 0; i < mca_memheap_base_num_transports(); i++) { + mkey = mca_memheap_base_get_cached_mkey(pe, (void *)dst_addr, i, &rva); + if (!mkey) { + continue; + } + + if (mca_memheap_base_mkey_is_shm(mkey)) { + return rva; + } + + rva = MCA_SPML_CALL(rmkey_ptr(dst_addr, mkey, pe)); + if (rva != NULL) { + return rva; + } + } + + return NULL; }