diff --git a/ompi/mca/osc/ucx/osc_ucx_comm.c b/ompi/mca/osc/ucx/osc_ucx_comm.c index 9211f20e79..ec760d4fda 100644 --- a/ompi/mca/osc/ucx/osc_ucx_comm.c +++ b/ompi/mca/osc/ucx/osc_ucx_comm.c @@ -17,6 +17,13 @@ #include "osc_ucx.h" #include "osc_ucx_request.h" + +#define CHECK_VALID_RKEY(_module, _target, _count) \ + if (!((_module)->win_info_array[_target]).rkey_init && ((_count) > 0)) { \ + OSC_UCX_VERBOSE(1, "window with non-zero length does not have an rkey"); \ + return OMPI_ERROR; \ + } + typedef struct ucx_iovec { void *addr; size_t len; @@ -380,6 +387,12 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data } } + CHECK_VALID_RKEY(module, target, target_count); + + if (!target_count) { + return OMPI_SUCCESS; + } + rkey = (module->win_info_array[target]).rkey; ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent); @@ -434,6 +447,12 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, } } + CHECK_VALID_RKEY(module, target, target_count); + + if (!target_count) { + return OMPI_SUCCESS; + } + rkey = (module->win_info_array[target]).rkey; ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent); @@ -860,6 +879,8 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, } } + CHECK_VALID_RKEY(module, target, target_count); + rkey = (module->win_info_array[target]).rkey; OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req); @@ -919,6 +940,8 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count, } } + CHECK_VALID_RKEY(module, target, target_count); + rkey = (module->win_info_array[target]).rkey; OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req); diff --git a/ompi/mca/osc/ucx/osc_ucx_component.c b/ompi/mca/osc/ucx/osc_ucx_component.c index 955857d974..149106830c 100644 --- a/ompi/mca/osc/ucx/osc_ucx_component.c +++ b/ompi/mca/osc/ucx/osc_ucx_component.c @@ -17,6 +17,10 @@ #include "osc_ucx.h" #include "osc_ucx_request.h" +#define memcpy_off(_dst, _src, _len, _off) \ + memcpy(((char*)(_dst)) + (_off), _src, _len); \ + (_off) += (_len); + static int component_open(void); static int component_register(void); static int component_init(bool enable_progress_threads, bool enable_mpi_threads); @@ -275,6 +279,8 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in int disps[comm_size]; int rkey_sizes[comm_size]; uint64_t zero = 0; + size_t info_offset; + uint64_t size_u64; /* the osc/sm component is the exclusive provider for support for * shared memory windows */ @@ -518,22 +524,27 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in goto error; } - my_info_len = 2 * sizeof(uint64_t) + rkey_buffer_size + state_rkey_buffer_size; + size_u64 = (uint64_t)size; + my_info_len = 3 * sizeof(uint64_t) + rkey_buffer_size + state_rkey_buffer_size; my_info = malloc(my_info_len); if (my_info == NULL) { ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE; goto error; } + info_offset = 0; + if (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE) { - memcpy(my_info, base, sizeof(uint64_t)); + memcpy_off(my_info, base, sizeof(uint64_t), info_offset); } else { - memcpy(my_info, &zero, sizeof(uint64_t)); + memcpy_off(my_info, &zero, sizeof(uint64_t), info_offset); } - memcpy((void *)((char *)my_info + sizeof(uint64_t)), &state_base, sizeof(uint64_t)); - memcpy((void *)((char *)my_info + 2 * sizeof(uint64_t)), rkey_buffer, rkey_buffer_size); - memcpy((void *)((char *)my_info + 2 * sizeof(uint64_t) + rkey_buffer_size), - state_rkey_buffer, state_rkey_buffer_size); + memcpy_off(my_info, &state_base, sizeof(uint64_t), info_offset); + memcpy_off(my_info, &size_u64, sizeof(uint64_t), info_offset); + memcpy_off(my_info, rkey_buffer, rkey_buffer_size, info_offset); + memcpy_off(my_info, state_rkey_buffer, state_rkey_buffer_size, info_offset); + + assert(my_info_len == info_offset); ret = allgather_len_and_info(my_info, (int)my_info_len, &recv_buf, disps, module->comm); if (ret != OMPI_SUCCESS) { @@ -549,25 +560,32 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in for (i = 0; i < comm_size; i++) { ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, i); + uint64_t dest_size; assert(ep != NULL); - memcpy(&(module->win_info_array[i]).addr, &recv_buf[disps[i]], sizeof(uint64_t)); - memcpy(&(module->state_info_array[i]).addr, &recv_buf[disps[i] + sizeof(uint64_t)], - sizeof(uint64_t)); + info_offset = disps[i]; + + memcpy(&(module->win_info_array[i]).addr, &recv_buf[info_offset], sizeof(uint64_t)); + info_offset += sizeof(uint64_t); + memcpy(&(module->state_info_array[i]).addr, &recv_buf[info_offset], sizeof(uint64_t)); + info_offset += sizeof(uint64_t); + memcpy(&dest_size, &recv_buf[info_offset], sizeof(uint64_t)); + info_offset += sizeof(uint64_t); (module->win_info_array[i]).rkey_init = false; - if (size > 0 && (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE)) { - status = ucp_ep_rkey_unpack(ep, &(recv_buf[disps[i] + 2 * sizeof(uint64_t)]), + if (dest_size > 0 && (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE)) { + status = ucp_ep_rkey_unpack(ep, &recv_buf[info_offset], &((module->win_info_array[i]).rkey)); if (status != UCS_OK) { OSC_UCX_VERBOSE(1, "ucp_ep_rkey_unpack failed: %d", status); ret = OMPI_ERROR; goto error; } + info_offset += rkey_sizes[i]; (module->win_info_array[i]).rkey_init = true; } - status = ucp_ep_rkey_unpack(ep, &(recv_buf[disps[i] + 2 * sizeof(uint64_t) + rkey_sizes[i]]), + status = ucp_ep_rkey_unpack(ep, &recv_buf[info_offset], &((module->state_info_array[i]).rkey)); if (status != UCS_OK) { OSC_UCX_VERBOSE(1, "ucp_ep_rkey_unpack failed: %d", status);