diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index b07637d8af..9a268e4cd4 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -774,6 +774,30 @@ int mca_spml_ucx_get_nb(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_ return ucx_status_to_oshmem_nb(status); } +int mca_spml_ucx_get_nb_wprogress(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_addr, int src, void **handle) +{ + unsigned int i; + void *rva; + ucs_status_t status; + spml_ucx_mkey_t *ucx_mkey; + mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; + + ucx_mkey = mca_spml_ucx_get_mkey(ctx, src, src_addr, &rva, &mca_spml_ucx); + status = ucp_get_nbi(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size, + (uint64_t)rva, ucx_mkey->rkey); + + if (++ucx_ctx->nb_progress_cnt > mca_spml_ucx.nb_get_progress_thresh) { + for (i = 0; i < mca_spml_ucx.nb_ucp_worker_progress; i++) { + if (!ucp_worker_progress(ucx_ctx->ucp_worker)) { + ucx_ctx->nb_progress_cnt = 0; + break; + } + } + } + + return ucx_status_to_oshmem_nb(status); +} + int mca_spml_ucx_put(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_addr, int dst) { void *rva; @@ -822,7 +846,33 @@ int mca_spml_ucx_put_nb(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_ return ucx_status_to_oshmem_nb(status); } +int mca_spml_ucx_put_nb_wprogress(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_addr, int dst, void **handle) +{ + unsigned int i; + void *rva; + ucs_status_t status; + spml_ucx_mkey_t *ucx_mkey; + mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; + ucx_mkey = mca_spml_ucx_get_mkey(ctx, dst, dst_addr, &rva, &mca_spml_ucx); + status = ucp_put_nbi(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size, + (uint64_t)rva, ucx_mkey->rkey); + + if (OPAL_LIKELY(status >= 0)) { + mca_spml_ucx_remote_op_posted(ucx_ctx, dst); + } + + if (++ucx_ctx->nb_progress_cnt > mca_spml_ucx.nb_put_progress_thresh) { + for (i = 0; i < mca_spml_ucx.nb_ucp_worker_progress; i++) { + if (!ucp_worker_progress(ucx_ctx->ucp_worker)) { + ucx_ctx->nb_progress_cnt = 0; + break; + } + } + } + + return ucx_status_to_oshmem_nb(status); +} int mca_spml_ucx_fence(shmem_ctx_t ctx) { @@ -880,6 +930,8 @@ int mca_spml_ucx_quiet(shmem_ctx_t ctx) } } + ucx_ctx->nb_progress_cnt = 0; + return OSHMEM_SUCCESS; } diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h index b81a10b136..457dbe3b82 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.h +++ b/oshmem/mca/spml/ucx/spml_ucx.h @@ -71,6 +71,7 @@ struct mca_spml_ucx_ctx { ucp_peer_t *ucp_peers; long options; opal_bitmap_t put_op_bitmap; + unsigned long nb_progress_cnt; int *put_proc_indexes; unsigned put_proc_count; }; @@ -108,6 +109,10 @@ struct mca_spml_ucx { pthread_spinlock_t async_lock; int aux_refcnt; bool synchronized_quiet; + unsigned long nb_progress_thresh_global; + unsigned long nb_put_progress_thresh; + unsigned long nb_get_progress_thresh; + unsigned long nb_ucp_worker_progress; }; typedef struct mca_spml_ucx mca_spml_ucx_t; @@ -122,6 +127,7 @@ extern int mca_spml_ucx_get(shmem_ctx_t ctx, size_t size, void* src_addr, int src); + extern int mca_spml_ucx_get_nb(shmem_ctx_t ctx, void* dst_addr, size_t size, @@ -129,6 +135,13 @@ extern int mca_spml_ucx_get_nb(shmem_ctx_t ctx, int src, void **handle); +extern int mca_spml_ucx_get_nb_wprogress(shmem_ctx_t ctx, + void* dst_addr, + size_t size, + void* src_addr, + int src, + void **handle); + extern int mca_spml_ucx_put(shmem_ctx_t ctx, void* dst_addr, size_t size, @@ -142,6 +155,13 @@ extern int mca_spml_ucx_put_nb(shmem_ctx_t ctx, int dst, void **handle); +extern int mca_spml_ucx_put_nb_wprogress(shmem_ctx_t ctx, + void* dst_addr, + size_t size, + void* src_addr, + int dst, + void **handle); + extern int mca_spml_ucx_recv(void* buf, size_t size, int src); extern int mca_spml_ucx_send(void* buf, size_t size, diff --git a/oshmem/mca/spml/ucx/spml_ucx_component.c b/oshmem/mca/spml/ucx/spml_ucx_component.c index 0bfdc1d61e..bbe18d39c5 100644 --- a/oshmem/mca/spml/ucx/spml_ucx_component.c +++ b/oshmem/mca/spml/ucx/spml_ucx_component.c @@ -60,6 +60,20 @@ mca_spml_base_component_2_0_0_t mca_spml_ucx_component = { .spmlm_finalize = mca_spml_ucx_component_fini }; +static inline void mca_spml_ucx_param_register_ulong(const char* param_name, + unsigned long default_value, + const char *help_msg, + unsigned long *storage) +{ + *storage = default_value; + (void) mca_base_component_var_register(&mca_spml_ucx_component.spmlm_version, + param_name, + help_msg, + MCA_BASE_VAR_TYPE_UNSIGNED_LONG, NULL, 0, 0, + OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, + storage); +} static inline void mca_spml_ucx_param_register_int(const char* param_name, int default_value, @@ -132,6 +146,22 @@ static int mca_spml_ucx_component_register(void) "Use synchronized quiet on shmem_quiet or shmem_barrier_all operations", &mca_spml_ucx.synchronized_quiet); + mca_spml_ucx_param_register_ulong("nb_progress_thresh_global", 0, + "Number of nb_put or nb_get operations before ucx progress is triggered. Disabled by default (0)", + &mca_spml_ucx.nb_progress_thresh_global); + + mca_spml_ucx_param_register_ulong("nb_put_progress_thresh", mca_spml_ucx.nb_progress_thresh_global, + "Number of nb_put operations before ucx progress is triggered. Disabled by default (0), setting this value will override nb_progress_thresh_global", + &mca_spml_ucx.nb_put_progress_thresh); + + mca_spml_ucx_param_register_ulong("nb_get_progress_thresh", mca_spml_ucx.nb_progress_thresh_global, + "Number of nb_get operations before ucx progress is triggered. Disabled by default (0), setting this value will override nb_progress_thresh_global ", + &mca_spml_ucx.nb_get_progress_thresh); + + mca_spml_ucx_param_register_ulong("nb_ucp_worker_progress", 32, + "Maximum number of ucx worker progress calls if triggered during nb_put or nb_get", + &mca_spml_ucx.nb_ucp_worker_progress); + opal_common_ucx_mca_var_register(&mca_spml_ucx_component.spmlm_version); return OSHMEM_SUCCESS; @@ -294,6 +324,13 @@ static int spml_ucx_init(void) mca_spml_ucx.aux_ctx = NULL; mca_spml_ucx.aux_refcnt = 0; + if (mca_spml_ucx.nb_put_progress_thresh) { + mca_spml_ucx.super.spml_put_nb = &mca_spml_ucx_put_nb_wprogress; + } + if (mca_spml_ucx.nb_get_progress_thresh) { + mca_spml_ucx.super.spml_get_nb = &mca_spml_ucx_get_nb_wprogress; + } + oshmem_ctx_default = (shmem_ctx_t) &mca_spml_ucx_ctx_default; return OSHMEM_SUCCESS;