diff --git a/oshmem/include/pshmemx.h b/oshmem/include/pshmemx.h index 5a0f7f5a95..0b4ffcbd20 100644 --- a/oshmem/include/pshmemx.h +++ b/oshmem/include/pshmemx.h @@ -16,6 +16,11 @@ extern "C" { #endif +/* + * Symmetric heap routines + */ +OSHMEM_DECLSPEC void* pshmemx_malloc_with_hint(size_t size, long hint); + /* * Legacy API diff --git a/oshmem/include/shmemx.h b/oshmem/include/shmemx.h index d99ca11533..f7e7de6829 100644 --- a/oshmem/include/shmemx.h +++ b/oshmem/include/shmemx.h @@ -18,11 +18,29 @@ extern "C" { #endif +enum { + SHMEM_HINT_NONE = 0, + SHMEM_HINT_LOW_LAT_MEM = 1 << 0, + SHMEM_HINT_HIGH_BW_MEM = 1 << 1, + SHMEM_HINT_NEAR_NIC_MEM = 1 << 2, + SHMEM_HINT_DEVICE_GPU_MEM = 1 << 3, + SHMEM_HINT_DEVICE_NIC_MEM = 1 << 4, + + SHMEM_HINT_PSYNC = 1 << 16, + SHMEM_HINT_PWORK = 1 << 17, + SHMEM_HINT_ATOMICS = 1 << 18 +}; + /* * All OpenSHMEM extension APIs that are not part of this specification must be defined in the shmemx.h include * file. These extensions shall use the shmemx_ prefix for all routine, variable, and constant names. */ +/* + * Symmetric heap routines + */ +OSHMEM_DECLSPEC void* shmemx_malloc_with_hint(size_t size, long hint); + /* * Elemental put routines */ diff --git a/oshmem/mca/memheap/base/base.h b/oshmem/mca/memheap/base/base.h index 7178685f0a..a91a03ae0c 100644 --- a/oshmem/mca/memheap/base/base.h +++ b/oshmem/mca/memheap/base/base.h @@ -41,14 +41,17 @@ OSHMEM_DECLSPEC int mca_memheap_base_select(void); extern int mca_memheap_base_already_opened; extern int mca_memheap_base_key_exchange; -#define MCA_MEMHEAP_MAX_SEGMENTS 4 -#define HEAP_SEG_INDEX 0 -#define SYMB_SEG_INDEX 1 -#define MCA_MEMHEAP_SEG_COUNT (SYMB_SEG_INDEX+1) +#define MCA_MEMHEAP_MAX_SEGMENTS 8 +#define HEAP_SEG_INDEX 0 #define MEMHEAP_SEG_INVALID 0xFFFF +typedef struct mca_memheap_base_config { + long device_nic_mem_seg_size; /* Used for SHMEM_HINT_DEVICE_NIC_MEM */ +} mca_memheap_base_config_t; + + typedef struct mca_memheap_map { map_segment_t mem_segs[MCA_MEMHEAP_MAX_SEGMENTS]; /* TODO: change into pointer array */ int n_segments; @@ -56,8 +59,9 @@ typedef struct mca_memheap_map { } mca_memheap_map_t; extern mca_memheap_map_t mca_memheap_base_map; +extern mca_memheap_base_config_t mca_memheap_base_config; -int mca_memheap_base_alloc_init(mca_memheap_map_t *, size_t); +int mca_memheap_base_alloc_init(mca_memheap_map_t *, size_t, long); void mca_memheap_base_alloc_exit(mca_memheap_map_t *); int mca_memheap_base_static_init(mca_memheap_map_t *); void mca_memheap_base_static_exit(mca_memheap_map_t *); @@ -173,10 +177,12 @@ static inline int memheap_is_va_in_segment(void *va, int segno) static inline int memheap_find_segnum(void *va) { - if (OPAL_LIKELY(memheap_is_va_in_segment(va, SYMB_SEG_INDEX))) { - return SYMB_SEG_INDEX; - } else if (memheap_is_va_in_segment(va, HEAP_SEG_INDEX)) { - return HEAP_SEG_INDEX; + int i; + + for (i = 0; i < mca_memheap_base_map.n_segments; i++) { + if (memheap_is_va_in_segment(va, i)) { + return i; + } } return MEMHEAP_SEG_INVALID; } @@ -193,18 +199,17 @@ static inline void *map_segment_va2rva(mkey_segment_t *seg, void *va) return memheap_va2rva(va, seg->super.va_base, seg->rva_base); } -static inline map_base_segment_t *map_segment_find_va(map_base_segment_t *segs, size_t elem_size, void *va) +static inline map_base_segment_t *map_segment_find_va(map_base_segment_t *segs, + size_t elem_size, void *va) { map_base_segment_t *rseg; + int i; - rseg = (map_base_segment_t *)((char *)segs + elem_size * HEAP_SEG_INDEX); - if (OPAL_LIKELY(map_segment_is_va_in(rseg, va))) { - return rseg; - } - - rseg = (map_base_segment_t *)((char *)segs + elem_size * SYMB_SEG_INDEX); - if (OPAL_LIKELY(map_segment_is_va_in(rseg, va))) { - return rseg; + for (i = 0; i < MCA_MEMHEAP_MAX_SEGMENTS; i++) { + rseg = (map_base_segment_t *)((char *)segs + elem_size * i); + if (OPAL_LIKELY(map_segment_is_va_in(rseg, va))) { + return rseg; + } } return NULL; @@ -214,21 +219,14 @@ void mkey_segment_init(mkey_segment_t *seg, sshmem_mkey_t *mkey, uint32_t segno) static inline map_segment_t *memheap_find_va(void* va) { - map_segment_t *s; + map_segment_t *s = NULL; + int i; - /* most probably there will be only two segments: heap and global data */ - if (OPAL_LIKELY(memheap_is_va_in_segment(va, SYMB_SEG_INDEX))) { - s = &memheap_map->mem_segs[SYMB_SEG_INDEX]; - } else if (memheap_is_va_in_segment(va, HEAP_SEG_INDEX)) { - s = &memheap_map->mem_segs[HEAP_SEG_INDEX]; - } else if (memheap_map->n_segments - 2 > 0) { - s = bsearch(va, - &memheap_map->mem_segs[SYMB_SEG_INDEX+1], - memheap_map->n_segments - 2, - sizeof(*s), - mca_memheap_seg_cmp); - } else { - s = NULL; + for (i = 0; i < memheap_map->n_segments; i++) { + if (memheap_is_va_in_segment(va, i)) { + s = &memheap_map->mem_segs[i]; + break; + } } #if MEMHEAP_BASE_DEBUG == 1 diff --git a/oshmem/mca/memheap/base/memheap_base_alloc.c b/oshmem/mca/memheap/base/memheap_base_alloc.c index 341eec97a9..b83499f250 100644 --- a/oshmem/mca/memheap/base/memheap_base_alloc.c +++ b/oshmem/mca/memheap/base/memheap_base_alloc.c @@ -19,17 +19,21 @@ #include "oshmem/mca/memheap/base/base.h" -int mca_memheap_base_alloc_init(mca_memheap_map_t *map, size_t size) +int mca_memheap_base_alloc_init(mca_memheap_map_t *map, size_t size, long hint) { int ret = OSHMEM_SUCCESS; char * seg_filename = NULL; assert(map); - assert(HEAP_SEG_INDEX == map->n_segments); + if (hint == 0) { + assert(HEAP_SEG_INDEX == map->n_segments); + } else { + assert(HEAP_SEG_INDEX < map->n_segments); + } map_segment_t *s = &map->mem_segs[map->n_segments]; seg_filename = oshmem_get_unique_file_name(oshmem_my_proc_id()); - ret = mca_sshmem_segment_create(s, seg_filename, size); + ret = mca_sshmem_segment_create(s, seg_filename, size, hint); if (OSHMEM_SUCCESS == ret) { map->n_segments++; @@ -45,12 +49,34 @@ int mca_memheap_base_alloc_init(mca_memheap_map_t *map, size_t size) void mca_memheap_base_alloc_exit(mca_memheap_map_t *map) { - if (map) { - map_segment_t *s = &map->mem_segs[HEAP_SEG_INDEX]; + int i; - assert(s); + if (!map) { + return; + } - mca_sshmem_segment_detach(s, NULL); - mca_sshmem_unlink(s); + for (i = 0; i < map->n_segments; ++i) { + map_segment_t *s = &map->mem_segs[i]; + if (s->type != MAP_SEGMENT_STATIC) { + mca_sshmem_segment_detach(s, NULL); + mca_sshmem_unlink(s); + } } } + +int mca_memheap_alloc_with_hint(size_t size, long hint, void** ptr) +{ + int i; + + for (i = 0; i < mca_memheap_base_map.n_segments; i++) { + map_segment_t *s = &mca_memheap_base_map.mem_segs[i]; + if (s->allocator && (hint && s->alloc_hints)) { + /* Do not fall back to default allocator since it will break the + * symmetry between PEs + */ + return s->allocator->realloc(s, size, NULL, ptr); + } + } + + return MCA_MEMHEAP_CALL(alloc(size, ptr)); +} diff --git a/oshmem/mca/memheap/base/memheap_base_frame.c b/oshmem/mca/memheap/base/memheap_base_frame.c index 6f4d3c75b2..23ebf0860d 100644 --- a/oshmem/mca/memheap/base/memheap_base_frame.c +++ b/oshmem/mca/memheap/base/memheap_base_frame.c @@ -52,6 +52,12 @@ static int mca_memheap_base_register(mca_base_register_flag_t flags) MCA_BASE_VAR_SCOPE_READONLY, &mca_memheap_base_key_exchange); + mca_base_var_register("oshmem", "memheap", "base", "device_nic_mem_seg_size", + "Size of memory block used for allocations with hint SHMEM_HINT_DEVICE_NIC_MEM", + MCA_BASE_VAR_TYPE_LONG, NULL, 0, + MCA_BASE_VAR_FLAG_SETTABLE, OPAL_INFO_LVL_3, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_memheap_base_config.device_nic_mem_seg_size); return OSHMEM_SUCCESS; } diff --git a/oshmem/mca/memheap/base/memheap_base_mkey.c b/oshmem/mca/memheap/base/memheap_base_mkey.c index fea00694ba..0245abdd1d 100644 --- a/oshmem/mca/memheap/base/memheap_base_mkey.c +++ b/oshmem/mca/memheap/base/memheap_base_mkey.c @@ -749,7 +749,7 @@ void mkey_segment_init(mkey_segment_t *seg, sshmem_mkey_t *mkey, uint32_t segno) { map_segment_t *s; - if (segno >= MCA_MEMHEAP_SEG_COUNT) { + if (segno >= MCA_MEMHEAP_MAX_SEGMENTS) { return; } diff --git a/oshmem/mca/memheap/base/memheap_base_select.c b/oshmem/mca/memheap/base/memheap_base_select.c index e0c1c3a638..f3bd740837 100644 --- a/oshmem/mca/memheap/base/memheap_base_select.c +++ b/oshmem/mca/memheap/base/memheap_base_select.c @@ -21,6 +21,14 @@ #include "oshmem/util/oshmem_util.h" #include "oshmem/mca/memheap/memheap.h" #include "oshmem/mca/memheap/base/base.h" +#include "orte/mca/errmgr/errmgr.h" +#include "oshmem/include/shmemx.h" +#include "oshmem/mca/sshmem/base/base.h" + + +mca_memheap_base_config_t mca_memheap_base_config = { + .device_nic_mem_seg_size = 0 +}; mca_memheap_base_module_t mca_memheap = {0}; @@ -94,7 +102,7 @@ static memheap_context_t* _memheap_create(void) { int rc = OSHMEM_SUCCESS; static memheap_context_t context; - size_t user_size; + size_t user_size, size; user_size = _memheap_size(); if (user_size < MEMHEAP_BASE_MIN_SIZE) { @@ -105,7 +113,18 @@ static memheap_context_t* _memheap_create(void) /* Inititialize symmetric area */ if (OSHMEM_SUCCESS == rc) { rc = mca_memheap_base_alloc_init(&mca_memheap_base_map, - user_size + MEMHEAP_BASE_PRIVATE_SIZE); + user_size + MEMHEAP_BASE_PRIVATE_SIZE, 0); + } + + /* Initialize atomic symmetric area */ + size = mca_memheap_base_config.device_nic_mem_seg_size; + if ((OSHMEM_SUCCESS == rc) && (size > 0)) { + rc = mca_memheap_base_alloc_init(&mca_memheap_base_map, size, + SHMEM_HINT_DEVICE_NIC_MEM); + if (rc == OSHMEM_ERR_NOT_IMPLEMENTED) { + /* do not treat NOT_IMPLEMENTED as error */ + rc = OSHMEM_SUCCESS; + } } /* Inititialize static/global variables area */ diff --git a/oshmem/mca/memheap/base/memheap_base_static.c b/oshmem/mca/memheap/base/memheap_base_static.c index edbb11aa31..4e97253a9e 100644 --- a/oshmem/mca/memheap/base/memheap_base_static.c +++ b/oshmem/mca/memheap/base/memheap_base_static.c @@ -49,7 +49,7 @@ int mca_memheap_base_static_init(mca_memheap_map_t *map) int ret = OSHMEM_SUCCESS; assert(map); - assert(SYMB_SEG_INDEX <= map->n_segments); + assert(HEAP_SEG_INDEX < map->n_segments); ret = _load_segments(); diff --git a/oshmem/mca/memheap/memheap.h b/oshmem/mca/memheap/memheap.h index 7cad1e9e3f..07c4e2f2f0 100644 --- a/oshmem/mca/memheap/memheap.h +++ b/oshmem/mca/memheap/memheap.h @@ -138,6 +138,8 @@ typedef struct mca_memheap_base_module_t mca_memheap_base_module_t; OSHMEM_DECLSPEC extern mca_memheap_base_module_t mca_memheap; +int mca_memheap_alloc_with_hint(size_t size, long hint, void**); + static inline int mca_memheap_base_mkey_is_shm(sshmem_mkey_t *mkey) { return (0 == mkey->len) && (MAP_SEGMENT_SHM_INVALID != (int)mkey->u.key); diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index fbc30d4e9f..5d36235399 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -36,6 +36,7 @@ #include "oshmem/runtime/runtime.h" #include "oshmem/mca/spml/ucx/spml_ucx_component.h" +#include "oshmem/mca/sshmem/ucx/sshmem_ucx.h" /* Turn ON/OFF debug output from build (default 0) */ #ifndef SPML_UCX_PUT_DEBUG @@ -267,7 +268,7 @@ int mca_spml_ucx_add_procs(ompi_proc_t** procs, size_t nprocs) OSHMEM_PROC_DATA(procs[i])->num_transports = 1; OSHMEM_PROC_DATA(procs[i])->transport_ids = spml_ucx_transport_ids; - for (j = 0; j < MCA_MEMHEAP_SEG_COUNT; j++) { + for (j = 0; j < MCA_MEMHEAP_MAX_SEGMENTS; j++) { mca_spml_ucx_ctx_default.ucp_peers[i].mkeys[j].key.rkey = NULL; } @@ -438,7 +439,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr, } } else { - ucx_mkey->mem_h = (ucp_mem_h)mem_seg->context; + mca_sshmem_ucx_segment_context_t *ctx = mem_seg->context; + ucx_mkey->mem_h = ctx->ucp_memh; } status = ucp_rkey_pack(mca_spml_ucx.ucp_context, ucx_mkey->mem_h, @@ -589,17 +591,19 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx goto error2; } - for (j = 0; j < MCA_MEMHEAP_SEG_COUNT; j++) { + for (j = 0; j < memheap_map->n_segments; j++) { mkey = &memheap_map->mem_segs[j].mkeys_cache[i][0]; ucx_mkey = &ucx_ctx->ucp_peers[i].mkeys[j].key; - err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[i].ucp_conn, - mkey->u.data, - &ucx_mkey->rkey); - if (UCS_OK != err) { - SPML_UCX_ERROR("failed to unpack rkey"); - goto error2; + if (mkey->u.data) { + err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[i].ucp_conn, + mkey->u.data, + &ucx_mkey->rkey); + if (UCS_OK != err) { + SPML_UCX_ERROR("failed to unpack rkey"); + goto error2; + } + mca_spml_ucx_cache_mkey(ucx_ctx, mkey, j, i); } - mca_spml_ucx_cache_mkey(ucx_ctx, mkey, j, i); } } @@ -747,6 +751,8 @@ int mca_spml_ucx_fence(shmem_ctx_t ctx) ucs_status_t err; mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; + opal_atomic_wmb(); + err = ucp_worker_fence(ucx_ctx->ucp_worker); if (UCS_OK != err) { SPML_UCX_ERROR("fence failed: %s", ucs_status_string(err)); @@ -761,6 +767,8 @@ int mca_spml_ucx_quiet(shmem_ctx_t ctx) int ret; mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; + opal_atomic_wmb(); + ret = opal_common_ucx_worker_flush(ucx_ctx->ucp_worker); if (OMPI_SUCCESS != ret) { oshmem_shmem_abort(-1); diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h index 5f05c5c87f..071b222462 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.h +++ b/oshmem/mca/spml/ucx/spml_ucx.h @@ -61,7 +61,7 @@ typedef struct spml_ucx_cached_mkey spml_ucx_cached_mkey_t; struct ucp_peer { ucp_ep_h ucp_conn; - spml_ucx_cached_mkey_t mkeys[MCA_MEMHEAP_SEG_COUNT]; + spml_ucx_cached_mkey_t mkeys[MCA_MEMHEAP_MAX_SEGMENTS]; }; typedef struct ucp_peer ucp_peer_t; diff --git a/oshmem/mca/spml/ucx/spml_ucx_component.c b/oshmem/mca/spml/ucx/spml_ucx_component.c index 273c41a09c..6f842991a9 100644 --- a/oshmem/mca/spml/ucx/spml_ucx_component.c +++ b/oshmem/mca/spml/ucx/spml_ucx_component.c @@ -314,7 +314,7 @@ static void _ctx_cleanup(mca_spml_ucx_ctx_t *ctx) del_procs = malloc(sizeof(*del_procs) * nprocs); for (i = 0; i < nprocs; ++i) { - for (j = 0; j < MCA_MEMHEAP_SEG_COUNT; j++) { + for (j = 0; j < memheap_map->n_segments; j++) { if (ctx->ucp_peers[i].mkeys[j].key.rkey != NULL) { ucp_rkey_destroy(ctx->ucp_peers[i].mkeys[j].key.rkey); } diff --git a/oshmem/mca/sshmem/base/base.h b/oshmem/mca/sshmem/base/base.h index 94c5e28e5e..44dfc39e82 100644 --- a/oshmem/mca/sshmem/base/base.h +++ b/oshmem/mca/sshmem/base/base.h @@ -31,7 +31,7 @@ extern char* mca_sshmem_base_backing_file_dir; OSHMEM_DECLSPEC int mca_sshmem_segment_create(map_segment_t *ds_buf, const char *file_name, - size_t size); + size_t size, long hint); OSHMEM_DECLSPEC void * mca_sshmem_segment_attach(map_segment_t *ds_buf, sshmem_mkey_t *mkey); diff --git a/oshmem/mca/sshmem/base/sshmem_base_wrappers.c b/oshmem/mca/sshmem/base/sshmem_base_wrappers.c index d8bc64fe25..0bc60bfcce 100644 --- a/oshmem/mca/sshmem/base/sshmem_base_wrappers.c +++ b/oshmem/mca/sshmem/base/sshmem_base_wrappers.c @@ -20,13 +20,13 @@ int mca_sshmem_segment_create(map_segment_t *ds_buf, const char *file_name, - size_t size) + size_t size, long hint) { if (!mca_sshmem_base_selected) { return OSHMEM_ERROR; } - return mca_sshmem_base_module->segment_create(ds_buf, file_name, size); + return mca_sshmem_base_module->segment_create(ds_buf, file_name, size, hint); } void * diff --git a/oshmem/mca/sshmem/mmap/sshmem_mmap_module.c b/oshmem/mca/sshmem/mmap/sshmem_mmap_module.c index 1afaaf400b..3414c2e9f2 100644 --- a/oshmem/mca/sshmem/mmap/sshmem_mmap_module.c +++ b/oshmem/mca/sshmem/mmap/sshmem_mmap_module.c @@ -63,7 +63,7 @@ module_init(void); static int segment_create(map_segment_t *ds_buf, const char *file_name, - size_t size); + size_t size, long hint); static void * segment_attach(map_segment_t *ds_buf, sshmem_mkey_t *mkey); @@ -112,7 +112,7 @@ module_finalize(void) static int segment_create(map_segment_t *ds_buf, const char *file_name, - size_t size) + size_t size, long hint) { int rc = OSHMEM_SUCCESS; void *addr = NULL; diff --git a/oshmem/mca/sshmem/sshmem.h b/oshmem/mca/sshmem/sshmem.h index a2b570aab8..8ba1057492 100644 --- a/oshmem/mca/sshmem/sshmem.h +++ b/oshmem/mca/sshmem/sshmem.h @@ -83,14 +83,19 @@ typedef int * @param file_name file_name unique string identifier that must be a valid, * writable path (IN). * + * @param address address to attach the segment at, or 0 allocate + * any available address in the process. + * * @param size size of the shared memory segment. * + * @param hint hint of the shared memory segment. + * * @return OSHMEM_SUCCESS on success. */ typedef int (*mca_sshmem_base_module_segment_create_fn_t)(map_segment_t *ds_buf, const char *file_name, - size_t size); + size_t size, long hint); /** * attach to an existing shared memory segment initialized by segment_create. diff --git a/oshmem/mca/sshmem/sshmem_types.h b/oshmem/mca/sshmem/sshmem_types.h index ccdf8995b5..4e1d937901 100644 --- a/oshmem/mca/sshmem/sshmem_types.h +++ b/oshmem/mca/sshmem/sshmem_types.h @@ -107,6 +107,8 @@ typedef struct mkey_segment { void *rva_base; /* base va on remote pe */ } mkey_segment_t; +typedef struct segment_allocator segment_allocator_t; + typedef struct map_segment { map_base_segment_t super; sshmem_mkey_t **mkeys_cache; /* includes remote segment bases in va_base */ @@ -115,10 +117,17 @@ typedef struct map_segment { int seg_id; size_t seg_size; /* length of the segment */ segment_type_t type; /* type of the segment */ + long alloc_hints; /* allocation hints this segment supports */ void *context; /* allocator can use this field to store its own private data */ + segment_allocator_t *allocator; /* segment-specific allocator */ } map_segment_t; +struct segment_allocator { + int (*realloc)(map_segment_t*, size_t newsize, void *, void **); + int (*free)(map_segment_t*, void*); +}; + END_C_DECLS #endif /* MCA_SSHMEM_TYPES_H */ diff --git a/oshmem/mca/sshmem/sysv/sshmem_sysv_module.c b/oshmem/mca/sshmem/sysv/sshmem_sysv_module.c index 56fd4df00d..bde8ed5ac7 100644 --- a/oshmem/mca/sshmem/sysv/sshmem_sysv_module.c +++ b/oshmem/mca/sshmem/sysv/sshmem_sysv_module.c @@ -61,7 +61,7 @@ module_init(void); static int segment_create(map_segment_t *ds_buf, const char *file_name, - size_t size); + size_t size, long hint); static void * segment_attach(map_segment_t *ds_buf, sshmem_mkey_t *mkey); @@ -110,7 +110,7 @@ module_finalize(void) static int segment_create(map_segment_t *ds_buf, const char *file_name, - size_t size) + size_t size, long hint) { int rc = OSHMEM_SUCCESS; void *addr = NULL; diff --git a/oshmem/mca/sshmem/ucx/Makefile.am b/oshmem/mca/sshmem/ucx/Makefile.am index bf3a08b547..ce37cd0e90 100644 --- a/oshmem/mca/sshmem/ucx/Makefile.am +++ b/oshmem/mca/sshmem/ucx/Makefile.am @@ -15,7 +15,8 @@ AM_CPPFLAGS = $(sshmem_ucx_CPPFLAGS) sources = \ sshmem_ucx.h \ sshmem_ucx_component.c \ - sshmem_ucx_module.c + sshmem_ucx_module.c \ + sshmem_ucx_shadow.c # Make the output library in this directory, and name it either # mca__.la (for DSO builds) or libmca__.la diff --git a/oshmem/mca/sshmem/ucx/configure.m4 b/oshmem/mca/sshmem/ucx/configure.m4 index aafa4f4e02..7448b2dadf 100644 --- a/oshmem/mca/sshmem/ucx/configure.m4 +++ b/oshmem/mca/sshmem/ucx/configure.m4 @@ -22,6 +22,40 @@ AC_DEFUN([MCA_oshmem_sshmem_ucx_CONFIG],[ [$1], [$2]) + # Check for UCX device memory allocation support + save_LDFLAGS="$LDFLAGS" + save_LIBS="$LIBS" + save_CPPFLAGS="$CPPFLAGS" + + alloc_dm_LDFLAGS=" -L$ompi_check_ucx_libdir/ucx" + alloc_dm_LIBS=" -luct_ib" + CPPFLAGS+=" $sshmem_ucx_CPPFLAGS" + LDFLAGS+=" $sshmem_ucx_LDFLAGS $alloc_dm_LDFLAGS" + LIBS+=" $sshmem_ucx_LIBS $alloc_dm_LIBS" + + AC_LANG_PUSH([C]) + AC_LINK_IFELSE([AC_LANG_PROGRAM( + [[ + #include + #include + ]], + [[ + uct_md_h md = ucp_context_find_tl_md((ucp_context_h)NULL, ""); + (void)uct_ib_md_alloc_device_mem(md, NULL, NULL, 0, "", NULL); + uct_ib_md_release_device_mem(NULL); + ]])], + [ + AC_MSG_NOTICE([UCX device memory allocation is supported]) + AC_DEFINE([HAVE_UCX_DEVICE_MEM], [1], [Support for device memory allocation]) + sshmem_ucx_LIBS+=" $alloc_dm_LIBS" + sshmem_ucx_LDFLAGS+=" $alloc_dm_LDFLAGS" + ], + [AC_MSG_NOTICE([UCX device memory allocation is not supported])]) + AC_LANG_POP([C]) + + CPPFLAGS="$save_CPPFLAGS" + LDFLAGS="$save_LDFLAGS" + LIBS="$save_LIBS" # substitute in the things needed to build ucx AC_SUBST([sshmem_ucx_CFLAGS]) diff --git a/oshmem/mca/sshmem/ucx/sshmem_ucx.h b/oshmem/mca/sshmem/ucx/sshmem_ucx.h index 0b625fcc46..f171fe641b 100644 --- a/oshmem/mca/sshmem/ucx/sshmem_ucx.h +++ b/oshmem/mca/sshmem/ucx/sshmem_ucx.h @@ -15,8 +15,12 @@ #include "oshmem/mca/sshmem/sshmem.h" +#include + BEGIN_C_DECLS +typedef struct sshmem_ucx_shadow_allocator sshmem_ucx_shadow_allocator_t; + /** * globally exported variable to hold the ucx component. */ @@ -30,11 +34,26 @@ typedef struct mca_sshmem_ucx_component_t { OSHMEM_MODULE_DECLSPEC extern mca_sshmem_ucx_component_t mca_sshmem_ucx_component; +typedef struct mca_sshmem_ucx_segment_context { + void *dev_mem; + sshmem_ucx_shadow_allocator_t *shadow_allocator; + ucp_mem_h ucp_memh; +} mca_sshmem_ucx_segment_context_t; + typedef struct mca_sshmem_ucx_module_t { mca_sshmem_base_module_t super; } mca_sshmem_ucx_module_t; extern mca_sshmem_ucx_module_t mca_sshmem_ucx_module; +sshmem_ucx_shadow_allocator_t *sshmem_ucx_shadow_create(unsigned count); +void sshmem_ucx_shadow_destroy(sshmem_ucx_shadow_allocator_t *allocator); +int sshmem_ucx_shadow_alloc(sshmem_ucx_shadow_allocator_t *allocator, + unsigned count, unsigned *index); +int sshmem_ucx_shadow_free(sshmem_ucx_shadow_allocator_t *allocator, + unsigned index); +size_t sshmem_ucx_shadow_size(sshmem_ucx_shadow_allocator_t *allocator, + unsigned index); + END_C_DECLS #endif /* MCA_SHMEM_UCX_EXPORT_H */ diff --git a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c index 90ee1704dc..244eb7a169 100644 --- a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c +++ b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c @@ -19,12 +19,24 @@ #include "oshmem/proc/proc.h" #include "oshmem/mca/sshmem/sshmem.h" +#include "oshmem/include/shmemx.h" #include "oshmem/mca/sshmem/base/base.h" #include "oshmem/util/oshmem_util.h" #include "oshmem/mca/spml/ucx/spml_ucx.h" #include "sshmem_ucx.h" +//#include + +#if HAVE_UCX_DEVICE_MEM +#include +#include +#endif + +#define ALLOC_ELEM_SIZE sizeof(uint64_t) +#define min(a,b) ((a) < (b) ? (a) : (b)) +#define max(a,b) ((a) > (b) ? (a) : (b)) + /* ////////////////////////////////////////////////////////////////////////// */ /*local functions */ /* local functions */ @@ -34,7 +46,7 @@ module_init(void); static int segment_create(map_segment_t *ds_buf, const char *file_name, - size_t size); + size_t size, long hint); static void * segment_attach(map_segment_t *ds_buf, sshmem_mkey_t *mkey); @@ -48,6 +60,11 @@ segment_unlink(map_segment_t *ds_buf); static int module_finalize(void); +static int sshmem_ucx_memheap_realloc(map_segment_t *s, size_t size, + void* old_ptr, void** new_ptr); + +static int sshmem_ucx_memheap_free(map_segment_t *s, void* ptr); + /* * ucx shmem module */ @@ -80,13 +97,18 @@ module_finalize(void) /* ////////////////////////////////////////////////////////////////////////// */ +static segment_allocator_t sshmem_ucx_allocator = { + .realloc = sshmem_ucx_memheap_realloc, + .free = sshmem_ucx_memheap_free +}; + static int -segment_create(map_segment_t *ds_buf, - const char *file_name, - size_t size) +segment_create_internal(map_segment_t *ds_buf, void *address, size_t size, + unsigned flags, long hint, void *dev_mem) { + mca_sshmem_ucx_segment_context_t *ctx; int rc = OSHMEM_SUCCESS; - mca_spml_ucx_t *spml = (mca_spml_ucx_t *)mca_spml.self; + mca_spml_ucx_t *spml = (mca_spml_ucx_t*)mca_spml.self; ucp_mem_map_params_t mem_map_params; ucp_mem_h mem_h; ucs_status_t status; @@ -100,25 +122,51 @@ segment_create(map_segment_t *ds_buf, UCP_MEM_MAP_PARAM_FIELD_LENGTH | UCP_MEM_MAP_PARAM_FIELD_FLAGS; - mem_map_params.address = (void *)mca_sshmem_base_start_address; + mem_map_params.address = address; mem_map_params.length = size; - mem_map_params.flags = UCP_MEM_MAP_ALLOCATE|UCP_MEM_MAP_FIXED; - - if (spml->heap_reg_nb) { - mem_map_params.flags |= UCP_MEM_MAP_NONBLOCK; - } + mem_map_params.flags = flags; status = ucp_mem_map(spml->ucp_context, &mem_map_params, &mem_h); if (UCS_OK != status) { + SSHMEM_ERROR("ucp_mem_map() failed: %s\n", ucs_status_string(status)); rc = OSHMEM_ERROR; goto out; } - ds_buf->super.va_base = mem_map_params.address; + if (!(flags & UCP_MEM_MAP_FIXED)) { + /* Memory was allocated at an arbitrary address; obtain it */ + ucp_mem_attr_t mem_attr; + mem_attr.field_mask = UCP_MEM_ATTR_FIELD_ADDRESS; + status = ucp_mem_query(mem_h, &mem_attr); + if (status != UCS_OK) { + SSHMEM_ERROR("ucp_mem_query() failed: %s\n", ucs_status_string(status)); + ucp_mem_unmap(spml->ucp_context, mem_h); + rc = OSHMEM_ERROR; + goto out; + } + + ds_buf->super.va_base = mem_attr.address; + } else { + ds_buf->super.va_base = mem_map_params.address; + } + + ctx = calloc(1, sizeof(*ctx)); + if (!ctx) { + ucp_mem_unmap(spml->ucp_context, mem_h); + rc = OSHMEM_ERR_OUT_OF_RESOURCE; + goto out; + } + ds_buf->seg_size = size; ds_buf->super.va_end = (void*)((uintptr_t)ds_buf->super.va_base + ds_buf->seg_size); - ds_buf->context = mem_h; + ds_buf->context = ctx; ds_buf->type = MAP_SEGMENT_ALLOC_UCX; + ds_buf->alloc_hints = hint; + ctx->ucp_memh = mem_h; + ctx->dev_mem = dev_mem; + if (hint) { + ds_buf->allocator = &sshmem_ucx_allocator; + } out: OPAL_OUTPUT_VERBOSE( @@ -133,6 +181,84 @@ out: return rc; } +#if HAVE_UCX_DEVICE_MEM +static uct_ib_device_mem_h alloc_device_mem(mca_spml_ucx_t *spml, size_t size, + void **address_p) +{ + uct_ib_device_mem_h dev_mem = NULL; + ucs_status_t status; + uct_md_h uct_md; + void *address; + size_t length; + int ret; + + uct_md = ucp_context_find_tl_md(spml->ucp_context, "mlx5"); + if (uct_md == NULL) { + SSHMEM_VERBOSE(1, "ucp_context_find_tl_md() returned NULL\n"); + return NULL; + } + + /* If found a matching memory domain, allocate device memory on it */ + length = size; + address = NULL; + status = uct_ib_md_alloc_device_mem(uct_md, &length, &address, + UCT_MD_MEM_ACCESS_ALL, "sshmem_seg", + &dev_mem); + if (status != UCS_OK) { + /* If could not allocate device memory - fallback to mmap (since some + * PEs in the job may succeed and while others failed */ + SSHMEM_VERBOSE(1, "uct_ib_md_alloc_dm() failed: %s\n", + ucs_status_string(status)); + return NULL; + } + + SSHMEM_VERBOSE(3, "uct_ib_md_alloc_dm() returned address %p\n", address); + *address_p = address; + return dev_mem; +} +#endif + +static int +segment_create(map_segment_t *ds_buf, + const char *file_name, + size_t size, long hint) +{ + mca_spml_ucx_t *spml = (mca_spml_ucx_t*)mca_spml.self; + unsigned flags; + int ret; + +#if HAVE_UCX_DEVICE_MEM + if (hint & SHMEM_HINT_DEVICE_NIC_MEM) { + if (size > UINT_MAX) { + return OSHMEM_ERR_BAD_PARAM; + } + + void *dev_mem_address; + uct_ib_device_mem_h dev_mem = alloc_device_mem(spml, size, + &dev_mem_address); + if (dev_mem != NULL) { + ret = segment_create_internal(ds_buf, dev_mem_address, size, 0, + hint, dev_mem); + if (ret == OSHMEM_SUCCESS) { + return OSHMEM_SUCCESS; + } else if (dev_mem != NULL) { + uct_ib_md_release_device_mem(dev_mem); + /* fallback to regular allocation */ + } + } + } +#endif + + flags = UCP_MEM_MAP_ALLOCATE | (spml->heap_reg_nb ? UCP_MEM_MAP_NONBLOCK : 0); + if (hint) { + return segment_create_internal(ds_buf, NULL, size, flags, hint, NULL); + } else { + return segment_create_internal(ds_buf, mca_sshmem_base_start_address, + size, flags | UCP_MEM_MAP_FIXED, hint, + NULL); + } +} + static void * segment_attach(map_segment_t *ds_buf, sshmem_mkey_t *mkey) { @@ -169,10 +295,22 @@ static int segment_unlink(map_segment_t *ds_buf) { mca_spml_ucx_t *spml = (mca_spml_ucx_t *)mca_spml.self; + mca_sshmem_ucx_segment_context_t *ctx = ds_buf->context; - assert(ds_buf); + if (ctx->shadow_allocator) { + sshmem_ucx_shadow_destroy(ctx->shadow_allocator); + } - ucp_mem_unmap(spml->ucp_context, (ucp_mem_h)ds_buf->context); + ucp_mem_unmap(spml->ucp_context, ctx->ucp_memh); + +#if HAVE_UCX_DEVICE_MEM + if (ctx->dev_mem) { + uct_ib_md_release_device_mem(ctx->dev_mem); + } +#endif + + ds_buf->context = NULL; + free(ctx); OPAL_OUTPUT_VERBOSE( (70, oshmem_sshmem_base_framework.framework_output, @@ -189,3 +327,79 @@ segment_unlink(map_segment_t *ds_buf) return OSHMEM_SUCCESS; } +static void *sshmem_ucx_memheap_index2ptr(map_segment_t *s, unsigned index) +{ + return (char*)s->super.va_base + (index * ALLOC_ELEM_SIZE); +} + +static unsigned sshmem_ucx_memheap_ptr2index(map_segment_t *s, void *ptr) +{ + return ((char*)ptr - (char*)s->super.va_base) / ALLOC_ELEM_SIZE; +} + +void sshmem_ucx_memheap_wordcopy(void *dst, void *src, size_t size) +{ + const size_t count = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); + uint64_t *dst64 = (uint64_t*)dst; + uint64_t *src64 = (uint64_t*)src; + size_t i; + + for (i = 0; i < count; ++i) { + *(dst64++) = *(src64++); + } + opal_atomic_wmb(); +} + +static int sshmem_ucx_memheap_realloc(map_segment_t *s, size_t size, + void* old_ptr, void** new_ptr) +{ + mca_sshmem_ucx_segment_context_t *ctx = s->context; + unsigned alloc_count, index; + int res; + + if (size > s->seg_size) { + return OSHMEM_ERR_OUT_OF_RESOURCE; + } + + /* create allocator on demand */ + if (!ctx->shadow_allocator) { + ctx->shadow_allocator = sshmem_ucx_shadow_create(s->seg_size); + if (!ctx->shadow_allocator) { + return OSHMEM_ERR_OUT_OF_RESOURCE; + } + } + + /* Allocate new element. Zero-size allocation should still return a unique + * pointer, so allocate 1 byte */ + alloc_count = max((size + ALLOC_ELEM_SIZE - 1) / ALLOC_ELEM_SIZE, 1); + res = sshmem_ucx_shadow_alloc(ctx->shadow_allocator, alloc_count, &index); + if (res != OSHMEM_SUCCESS) { + return res; + } + + *new_ptr = sshmem_ucx_memheap_index2ptr(s, index); + + /* Copy to new segment and release old*/ + if (old_ptr) { + unsigned old_index = sshmem_ucx_memheap_ptr2index(s, old_ptr); + unsigned old_alloc_count = sshmem_ucx_shadow_size(ctx->shadow_allocator, + old_index); + sshmem_ucx_memheap_wordcopy(*new_ptr, old_ptr, + min(size, old_alloc_count * ALLOC_ELEM_SIZE)); + sshmem_ucx_shadow_free(ctx->shadow_allocator, old_index); + } + + return OSHMEM_SUCCESS; +} + +static int sshmem_ucx_memheap_free(map_segment_t *s, void* ptr) +{ + mca_sshmem_ucx_segment_context_t *ctx = s->context; + + if (!ptr) { + return OSHMEM_SUCCESS; + } + + return sshmem_ucx_shadow_free(ctx->shadow_allocator, + sshmem_ucx_memheap_ptr2index(s, ptr)); +} diff --git a/oshmem/mca/sshmem/ucx/sshmem_ucx_shadow.c b/oshmem/mca/sshmem/ucx/sshmem_ucx_shadow.c new file mode 100644 index 0000000000..92fa2bb0cf --- /dev/null +++ b/oshmem/mca/sshmem/ucx/sshmem_ucx_shadow.c @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2019 Mellanox Technologies, Inc. + * All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "oshmem_config.h" + +#include "oshmem/mca/sshmem/sshmem.h" +#include "oshmem/include/shmemx.h" +#include "oshmem/mca/sshmem/base/base.h" + +#include "sshmem_ucx.h" + +#define SSHMEM_UCX_SHADOW_ELEM_FLAG_FREE 0x1 + +typedef struct sshmem_ucx_shadow_alloc_elem { + unsigned flags; + unsigned block_size; +} sshmem_ucx_shadow_alloc_elem_t; + +struct sshmem_ucx_shadow_allocator { + size_t num_elems; + sshmem_ucx_shadow_alloc_elem_t elems[]; +}; + +static int sshmem_ucx_shadow_is_free(sshmem_ucx_shadow_alloc_elem_t *elem) +{ + return elem->flags & SSHMEM_UCX_SHADOW_ELEM_FLAG_FREE; +} + +static void sshmem_ucx_shadow_set_elem(sshmem_ucx_shadow_alloc_elem_t *elem, + unsigned flags, unsigned block_size) +{ + elem->flags = flags; + elem->block_size = block_size; +} + +sshmem_ucx_shadow_allocator_t *sshmem_ucx_shadow_create(unsigned count) +{ + sshmem_ucx_shadow_allocator_t *allocator; + + allocator = calloc(1, sizeof(*allocator) + + count * sizeof(*allocator->elems)); + if (allocator) { + /* initialization: set initial element to the whole buffer */ + sshmem_ucx_shadow_set_elem(&allocator->elems[0], + SSHMEM_UCX_SHADOW_ELEM_FLAG_FREE, count); + allocator->num_elems = count; + } + + return allocator; +} + +void sshmem_ucx_shadow_destroy(sshmem_ucx_shadow_allocator_t *allocator) +{ + free(allocator); /* no leak check. TODO add leak warnings/debug */ +} + +int sshmem_ucx_shadow_alloc(sshmem_ucx_shadow_allocator_t *allocator, + unsigned count, unsigned *index) +{ + sshmem_ucx_shadow_alloc_elem_t *end = &allocator->elems[allocator->num_elems]; + sshmem_ucx_shadow_alloc_elem_t *elem; + + assert(count > 0); + + for (elem = &allocator->elems[0]; elem < end; elem += elem->block_size) { + if (sshmem_ucx_shadow_is_free(elem) && (elem->block_size >= count)) { + /* found suitable free element */ + if (elem->block_size > count) { + /* create new 'free' element for tail of current buffer */ + sshmem_ucx_shadow_set_elem(elem + count, + SSHMEM_UCX_SHADOW_ELEM_FLAG_FREE, + elem->block_size - count); + } + + /* set the size and flags of the allocated element */ + sshmem_ucx_shadow_set_elem(elem, 0, count); + *index = elem - &allocator->elems[0]; + return OSHMEM_SUCCESS; + } + } + + return OSHMEM_ERR_OUT_OF_RESOURCE; +} + +static void sshmem_ucx_shadow_merge_blocks(sshmem_ucx_shadow_allocator_t *allocator) +{ + sshmem_ucx_shadow_alloc_elem_t *elem = &allocator->elems[0]; + sshmem_ucx_shadow_alloc_elem_t *end = &allocator->elems[allocator->num_elems]; + sshmem_ucx_shadow_alloc_elem_t *next_elem; + + while ( (next_elem = (elem + elem->block_size)) < end) { + if (sshmem_ucx_shadow_is_free(elem) && sshmem_ucx_shadow_is_free(next_elem)) { + /* current & next elements are free, should be merged */ + elem->block_size += next_elem->block_size; + /* clean element which is merged */ + sshmem_ucx_shadow_set_elem(next_elem, 0, 0); + } else { + elem = next_elem; + } + } +} + +int sshmem_ucx_shadow_free(sshmem_ucx_shadow_allocator_t *allocator, + unsigned index) +{ + sshmem_ucx_shadow_alloc_elem_t *elem = &allocator->elems[index]; + + elem->flags |= SSHMEM_UCX_SHADOW_ELEM_FLAG_FREE; + sshmem_ucx_shadow_merge_blocks(allocator); + return OSHMEM_SUCCESS; +} + +size_t sshmem_ucx_shadow_size(sshmem_ucx_shadow_allocator_t *allocator, + unsigned index) +{ + sshmem_ucx_shadow_alloc_elem_t *elem = &allocator->elems[index]; + + assert(!sshmem_ucx_shadow_is_free(elem)); + return elem->block_size; +} diff --git a/oshmem/shmem/c/profile/defines.h b/oshmem/shmem/c/profile/defines.h index 22936efcfa..fa30d78377 100644 --- a/oshmem/shmem/c/profile/defines.h +++ b/oshmem/shmem/c/profile/defines.h @@ -58,6 +58,8 @@ #define shrealloc pshrealloc /* shmem-compat.h */ #define shfree pshfree /* shmem-compat.h */ +#define shmemx_malloc_with_hint pshmemx_malloc_with_hint + /* * Remote pointer operations */ diff --git a/oshmem/shmem/c/shmem_alloc.c b/oshmem/shmem/c/shmem_alloc.c index 3f7a579a20..92592ce8ca 100644 --- a/oshmem/shmem/c/shmem_alloc.c +++ b/oshmem/shmem/c/shmem_alloc.c @@ -11,6 +11,7 @@ #include "oshmem/constants.h" #include "oshmem/include/shmem.h" +#include "oshmem/include/shmemx.h" #include "oshmem/shmem/shmem_api_logger.h" @@ -19,9 +20,11 @@ #if OSHMEM_PROFILING #include "oshmem/include/pshmem.h" -#pragma weak shmem_malloc = pshmem_malloc -#pragma weak shmem_calloc = pshmem_calloc -#pragma weak shmalloc = pshmalloc +#include "oshmem/include/pshmemx.h" +#pragma weak shmem_malloc = pshmem_malloc +#pragma weak shmem_calloc = pshmem_calloc +#pragma weak shmalloc = pshmalloc +#pragma weak shmemx_malloc_with_hint = pshmemx_malloc_with_hint #include "oshmem/shmem/c/profile/defines.h" #endif @@ -72,3 +75,33 @@ static inline void* _shmalloc(size_t size) #endif return pBuff; } + +void* shmemx_malloc_with_hint(size_t size, long hint) +{ + int rc; + void* pBuff = NULL; + + if (!hint) { + return _shmalloc(size); + } + + RUNTIME_CHECK_INIT(); + RUNTIME_CHECK_WITH_MEMHEAP_SIZE(size); + + SHMEM_MUTEX_LOCK(shmem_internal_mutex_alloc); + + rc = mca_memheap_alloc_with_hint(size, hint, &pBuff); + + SHMEM_MUTEX_UNLOCK(shmem_internal_mutex_alloc); + + if (OSHMEM_SUCCESS != rc) { + SHMEM_API_VERBOSE(10, + "Allocation with shmalloc(size=%lu) failed.", + (unsigned long)size); + return NULL ; + } +#if OSHMEM_SPEC_COMPAT == 1 + shmem_barrier_all(); +#endif + return pBuff; +} diff --git a/oshmem/shmem/c/shmem_free.c b/oshmem/shmem/c/shmem_free.c index f5c5ce0cae..91619a7224 100644 --- a/oshmem/shmem/c/shmem_free.c +++ b/oshmem/shmem/c/shmem_free.c @@ -18,6 +18,7 @@ #include "oshmem/runtime/runtime.h" #include "oshmem/mca/memheap/memheap.h" +#include "oshmem/mca/memheap/base/base.h" #if OSHMEM_PROFILING #include "oshmem/include/pshmem.h" @@ -41,6 +42,7 @@ void shfree(void* ptr) static inline void _shfree(void* ptr) { int rc; + map_segment_t *s; RUNTIME_CHECK_INIT(); if (NULL == ptr) { @@ -55,7 +57,17 @@ static inline void _shfree(void* ptr) SHMEM_MUTEX_LOCK(shmem_internal_mutex_alloc); - rc = MCA_MEMHEAP_CALL(free(ptr)); + if (ptr) { + s = memheap_find_va(ptr); + } else { + s = NULL; + } + + if (s && s->allocator) { + rc = s->allocator->free(s, ptr); + } else { + rc = MCA_MEMHEAP_CALL(free(ptr)); + } SHMEM_MUTEX_UNLOCK(shmem_internal_mutex_alloc); diff --git a/oshmem/shmem/c/shmem_realloc.c b/oshmem/shmem/c/shmem_realloc.c index 0a45cf9fe3..7aab27735f 100644 --- a/oshmem/shmem/c/shmem_realloc.c +++ b/oshmem/shmem/c/shmem_realloc.c @@ -18,6 +18,7 @@ #include "oshmem/shmem/shmem_api_logger.h" #include "oshmem/mca/memheap/memheap.h" +#include "oshmem/mca/memheap/base/base.h" #if OSHMEM_PROFILING #include "oshmem/include/pshmem.h" @@ -42,12 +43,23 @@ static inline void* _shrealloc(void *ptr, size_t size) { int rc; void* pBuff = NULL; + map_segment_t *s; RUNTIME_CHECK_INIT(); SHMEM_MUTEX_LOCK(shmem_internal_mutex_alloc); - rc = MCA_MEMHEAP_CALL(realloc(size, ptr, &pBuff)); + if (ptr) { + s = memheap_find_va(ptr); + } else { + s = NULL; + } + + if (s && s->allocator) { + rc = s->allocator->realloc(s, size, ptr, &pBuff); + } else { + rc = MCA_MEMHEAP_CALL(realloc(size, ptr, &pBuff)); + } SHMEM_MUTEX_UNLOCK(shmem_internal_mutex_alloc);