diff --git a/opal/mca/btl/usnic/btl_usnic_cagent.c b/opal/mca/btl/usnic/btl_usnic_cagent.c index 7550a02424..db4d4369df 100644 --- a/opal/mca/btl/usnic/btl_usnic_cagent.c +++ b/opal/mca/btl/usnic/btl_usnic_cagent.c @@ -71,6 +71,7 @@ typedef struct { uint8_t *buffer; opal_event_t event; bool active; + opal_btl_usnic_module_t *module; } agent_udp_port_listener_t; OBJ_CLASS_DECLARATION(agent_udp_port_listener_t); @@ -323,6 +324,40 @@ static void agent_thread_noop(int fd, short flags, void *context) /* Intentionally a no op */ } +/* + * Check to ensure that we expected to receive a ping from this sender + * on the interface in which it was received (i.e., did the usnic + * module corresponding to the received interface choose to pair + * itself with the sender's interface). If not, discard it. + * + * Note that there may be a race condition here. We may get a ping + * before we've setup endpoints on the module in question. It's no + * problem -- if we don't find it, we'll drop the PING and let the + * sender try again later. + */ +static bool agent_thread_is_ping_expected(opal_btl_usnic_module_t *module, + uint32_t src_ipv4_addr) +{ + bool found = false; + opal_list_item_t *item; + + opal_mutex_lock(&module->all_endpoints_lock); + if (module->all_endpoints_constructed) { + OPAL_LIST_FOREACH(item, &module->all_endpoints, opal_list_item_t) { + opal_btl_usnic_endpoint_t *ep; + ep = container_of(item, opal_btl_usnic_endpoint_t, + endpoint_endpoint_li); + if (src_ipv4_addr == ep->endpoint_remote_addr.ipv4_addr) { + found = true; + break; + } + } + } + opal_mutex_unlock(&module->all_endpoints_lock); + + return found; +} + /* * Handle an incoming PING message (send an ACK) */ @@ -365,6 +400,19 @@ static void agent_thread_handle_ping(agent_udp_port_listener_t *listener, return; } + /* Finally, check that the ping is from an interface that the + module expects */ + if (!agent_thread_is_ping_expected(listener->module, + src_addr_in->sin_addr.s_addr)) { + opal_output_verbose(20, USNIC_OUT, + "usNIC connectivity got bad ping (from unexpected address: listener %s not paired with peer interface %s, discarded)", + listener->ipv4_addr_str, + real_ipv4_addr_str); + return; + } + + /* Ok, this is a good ping. Send the ACK back */ + opal_output_verbose(20, USNIC_OUT, "usNIC connectivity got PING (size=%ld) from %s; sending ACK", numbytes, msg_ipv4_addr_str); @@ -545,6 +593,7 @@ static void agent_thread_cmd_listen(agent_ipc_listener_t *ipc_listener) /* Will not return */ } + udp_listener->module = cmd.module; udp_listener->mtu = cmd.mtu; udp_listener->ipv4_addr = cmd.ipv4_addr; udp_listener->cidrmask = cmd.cidrmask; diff --git a/opal/mca/btl/usnic/btl_usnic_cclient.c b/opal/mca/btl/usnic/btl_usnic_cclient.c index ca8271dbf4..6d8df9f5fb 100644 --- a/opal/mca/btl/usnic/btl_usnic_cclient.c +++ b/opal/mca/btl/usnic/btl_usnic_cclient.c @@ -163,6 +163,7 @@ int opal_btl_usnic_connectivity_listen(opal_btl_usnic_module_t *module) /* Send the LISTEN command parameters */ opal_btl_usnic_connectivity_cmd_listen_t cmd = { + .module = module, .ipv4_addr = module->local_addr.ipv4_addr, .cidrmask = module->local_addr.cidrmask, .mtu = module->local_addr.mtu diff --git a/opal/mca/btl/usnic/btl_usnic_component.c b/opal/mca/btl/usnic/btl_usnic_component.c index 1c91308231..01cdb5aa9a 100644 --- a/opal/mca/btl/usnic/btl_usnic_component.c +++ b/opal/mca/btl/usnic/btl_usnic_component.c @@ -1555,6 +1555,7 @@ void opal_btl_usnic_component_debug(void) /* the all_endpoints list uses a different list item member */ opal_output(0, " all_endpoints:\n"); + opal_mutex_lock(&module->all_endpoints_lock); item = opal_list_get_first(&module->all_endpoints); while (item != opal_list_get_end(&module->all_endpoints)) { endpoint = container_of(item, mca_btl_base_endpoint_t, @@ -1562,6 +1563,7 @@ void opal_btl_usnic_component_debug(void) item = opal_list_get_next(item); dump_endpoint(endpoint); } + opal_mutex_unlock(&module->all_endpoints_lock); opal_output(0, " pending_resend_segs:\n"); OPAL_LIST_FOREACH(sseg, &module->pending_resend_segs, diff --git a/opal/mca/btl/usnic/btl_usnic_connectivity.h b/opal/mca/btl/usnic/btl_usnic_connectivity.h index 898fafc09f..1bf134a2ee 100644 --- a/opal/mca/btl/usnic/btl_usnic_connectivity.h +++ b/opal/mca/btl/usnic/btl_usnic_connectivity.h @@ -115,6 +115,7 @@ enum { * socket from the cclient to the cagent. */ typedef struct { + void *module; uint32_t ipv4_addr; uint32_t cidrmask; uint32_t mtu; diff --git a/opal/mca/btl/usnic/btl_usnic_endpoint.c b/opal/mca/btl/usnic/btl_usnic_endpoint.c index 7f2151f0bd..da80f122b7 100644 --- a/opal/mca/btl/usnic/btl_usnic_endpoint.c +++ b/opal/mca/btl/usnic/btl_usnic_endpoint.c @@ -118,8 +118,10 @@ static void endpoint_destruct(mca_btl_base_endpoint_t* endpoint) /* Remove this endpoint from module->all_endpoints list, then destruct the list_item_t */ + opal_mutex_lock(&endpoint->endpoint_module->all_endpoints_lock); opal_list_remove_item(&endpoint->endpoint_module->all_endpoints, &endpoint->endpoint_endpoint_li); + opal_mutex_unlock(&endpoint->endpoint_module->all_endpoints_lock); OBJ_DESTRUCT(&(endpoint->endpoint_endpoint_li)); if (endpoint->endpoint_hotel.rooms != NULL) { diff --git a/opal/mca/btl/usnic/btl_usnic_module.c b/opal/mca/btl/usnic/btl_usnic_module.c index 23f981c864..268f5c500a 100644 --- a/opal/mca/btl/usnic/btl_usnic_module.c +++ b/opal/mca/btl/usnic/btl_usnic_module.c @@ -1134,7 +1134,10 @@ static int usnic_finalize(struct mca_btl_base_module_t* btl) /* Note that usnic_del_procs will have been called for *all* procs by this point, so the module->all_endpoints list will be empty. Destruct it. */ + opal_mutex_lock(&module->all_endpoints_lock); OBJ_DESTRUCT(&(module->all_endpoints)); + module->all_endpoints_constructed = false; + opal_mutex_unlock(&module->all_endpoints_lock); /* _flush_endpoint should have emptied this list */ assert(opal_list_is_empty(&(module->pending_resend_segs))); @@ -2168,7 +2171,10 @@ int opal_btl_usnic_module_init(opal_btl_usnic_module_t *module) /* No more errors anticipated - initialize everything else */ /* list of all endpoints */ + opal_mutex_lock(&module->all_endpoints_lock); OBJ_CONSTRUCT(&(module->all_endpoints), opal_list_t); + module->all_endpoints_constructed = true; + opal_mutex_unlock(&module->all_endpoints_lock); /* Pending send segs list */ OBJ_CONSTRUCT(&module->pending_resend_segs, opal_list_t); diff --git a/opal/mca/btl/usnic/btl_usnic_module.h b/opal/mca/btl/usnic/btl_usnic_module.h index 3c624fb7aa..9ac19aa995 100644 --- a/opal/mca/btl/usnic/btl_usnic_module.h +++ b/opal/mca/btl/usnic/btl_usnic_module.h @@ -140,8 +140,16 @@ typedef struct opal_btl_usnic_module_t { /** local address information */ struct opal_btl_usnic_addr_t local_addr; - /** list of all endpoints */ + /** list of all endpoints. Note that the main application thread + reads and writes to this list, and the connectivity agent + reads from it. So all access to the list (but not the items + in the list) must be protected by a lock. Also, have a flag + that indicates that the list has been constructed. Probably + overkill, but you can't be too safe with multi-threaded + programming in non-performance-critical code paths... */ opal_list_t all_endpoints; + opal_mutex_t all_endpoints_lock; + bool all_endpoints_constructed; /** array of procs used by this module (can't use a list because a proc can be used by multiple modules) */ diff --git a/opal/mca/btl/usnic/btl_usnic_proc.c b/opal/mca/btl/usnic/btl_usnic_proc.c index 21ca2437e8..d666284ec1 100644 --- a/opal/mca/btl/usnic/btl_usnic_proc.c +++ b/opal/mca/btl/usnic/btl_usnic_proc.c @@ -138,6 +138,7 @@ opal_btl_usnic_proc_lookup_endpoint(opal_btl_usnic_module_t *receiver, MSGDEBUG1_OUT("lookup_endpoint: recvmodule=%p sendhash=0x%" PRIx64, (void *)receiver, sender_hashed_rte_name); + opal_mutex_lock(&receiver->all_endpoints_lock); for (item = opal_list_get_first(&receiver->all_endpoints); item != opal_list_get_end(&receiver->all_endpoints); item = opal_list_get_next(item)) { @@ -152,9 +153,11 @@ opal_btl_usnic_proc_lookup_endpoint(opal_btl_usnic_module_t *receiver, if (proc->proc_opal->proc_name == sender_proc_name) { MSGDEBUG1_OUT("lookup_endpoint: matched endpoint=%p", (void *)endpoint); + opal_mutex_unlock(&receiver->all_endpoints_lock); return endpoint; } } + opal_mutex_unlock(&receiver->all_endpoints_lock); /* Didn't find it */ return NULL; @@ -702,8 +705,10 @@ opal_btl_usnic_create_endpoint(opal_btl_usnic_module_t *module, OBJ_RETAIN(proc); /* also add endpoint to module's list of endpoints */ + opal_mutex_lock(&module->all_endpoints_lock); opal_list_append(&(module->all_endpoints), &(endpoint->endpoint_endpoint_li)); + opal_mutex_unlock(&module->all_endpoints_lock); *endpoint_o = endpoint; return OPAL_SUCCESS; diff --git a/opal/mca/btl/usnic/btl_usnic_stats.c b/opal/mca/btl/usnic/btl_usnic_stats.c index 65e9b3e1ce..c049636952 100644 --- a/opal/mca/btl/usnic/btl_usnic_stats.c +++ b/opal/mca/btl/usnic/btl_usnic_stats.c @@ -133,6 +133,7 @@ void opal_btl_usnic_print_stats( rd_min = su_min = WINDOW_SIZE * 2; rd_max = su_max = 0; + opal_mutex_lock(&module->all_endpoints_lock); item = opal_list_get_first(&module->all_endpoints); while (item != opal_list_get_end(&(module->all_endpoints))) { endpoint = container_of(item, mca_btl_base_endpoint_t, @@ -156,6 +157,7 @@ void opal_btl_usnic_print_stats( if (recv_depth > rd_max) rd_max = recv_depth; if (recv_depth < rd_min) rd_min = recv_depth; } + opal_mutex_unlock(&module->all_endpoints_lock); snprintf(tmp, sizeof(tmp), "PML S:%1ld, Win!A/R:%4ld/%4ld %4ld/%4ld", module->stats.pml_module_sends, su_min, su_max,