diff --git a/src/iperf_tcp.c b/src/iperf_tcp.c index c78f4f5..67b00e3 100644 --- a/src/iperf_tcp.c +++ b/src/iperf_tcp.c @@ -369,101 +369,16 @@ iperf_tcp_listen(struct iperf_test *test) int iperf_tcp_connect(struct iperf_test *test) { - struct addrinfo hints, *local_res, *server_res; - char portstr[6]; + struct addrinfo *server_res; int s, opt; socklen_t optlen; int saved_errno; int rcvbuf_actual, sndbuf_actual; - if (test->bind_address) { - memset(&hints, 0, sizeof(hints)); - hints.ai_family = test->settings->domain; - hints.ai_socktype = SOCK_STREAM; - if ((gerror = getaddrinfo(test->bind_address, NULL, &hints, &local_res)) != 0) { - i_errno = IESTREAMCONNECT; - return -1; - } - } - - memset(&hints, 0, sizeof(hints)); - hints.ai_family = test->settings->domain; - hints.ai_socktype = SOCK_STREAM; - snprintf(portstr, sizeof(portstr), "%d", test->server_port); - if ((gerror = getaddrinfo(test->server_hostname, portstr, &hints, &server_res)) != 0) { - if (test->bind_address) - freeaddrinfo(local_res); - i_errno = IESTREAMCONNECT; - return -1; - } - - if ((s = socket(server_res->ai_family, SOCK_STREAM, 0)) < 0) { - if (test->bind_address) - freeaddrinfo(local_res); - freeaddrinfo(server_res); - i_errno = IESTREAMCONNECT; - return -1; - } - - /* - * Various ways to bind the local end of the connection. - * 1. --bind (with or without --cport). - */ - if (test->bind_address) { - struct sockaddr_in *lcladdr; - lcladdr = (struct sockaddr_in *)local_res->ai_addr; - lcladdr->sin_port = htons(test->bind_port); - - if (bind(s, (struct sockaddr *) local_res->ai_addr, local_res->ai_addrlen) < 0) { - saved_errno = errno; - close(s); - freeaddrinfo(local_res); - freeaddrinfo(server_res); - errno = saved_errno; - i_errno = IESTREAMCONNECT; - return -1; - } - freeaddrinfo(local_res); - } - /* --cport, no --bind */ - else if (test->bind_port) { - size_t addrlen; - struct sockaddr_storage lcl; - - /* IPv4 */ - if (server_res->ai_family == AF_INET) { - struct sockaddr_in *lcladdr = (struct sockaddr_in *) &lcl; - lcladdr->sin_family = AF_INET; - lcladdr->sin_port = htons(test->bind_port); - lcladdr->sin_addr.s_addr = INADDR_ANY; - addrlen = sizeof(struct sockaddr_in); - } - /* IPv6 */ - else if (server_res->ai_family == AF_INET6) { - struct sockaddr_in6 *lcladdr = (struct sockaddr_in6 *) &lcl; - lcladdr->sin6_family = AF_INET6; - lcladdr->sin6_port = htons(test->bind_port); - lcladdr->sin6_addr = in6addr_any; - addrlen = sizeof(struct sockaddr_in6); - } - /* Unknown protocol */ - else { - saved_errno = errno; - close(s); - freeaddrinfo(server_res); - errno = saved_errno; - i_errno = IEPROTOCOL; - return -1; - } - - if (bind(s, (struct sockaddr *) &lcl, addrlen) < 0) { - saved_errno = errno; - close(s); - freeaddrinfo(server_res); - errno = saved_errno; - i_errno = IESTREAMCONNECT; - return -1; - } + s = create_socket(test->settings->domain, SOCK_STREAM, test->bind_address, test->bind_dev, test->bind_port, test->server_hostname, test->server_port, &server_res); + if (s < 0) { + i_errno = IESTREAMCONNECT; + return -1; } /* Set socket options */ diff --git a/src/net.c b/src/net.c index 2c3aaf3..aca2a4c 100644 --- a/src/net.c +++ b/src/net.c @@ -119,12 +119,13 @@ timeout_connect(int s, const struct sockaddr *name, socklen_t namelen, * Copyright: http://swtch.com/libtask/COPYRIGHT */ -/* make connection to server */ +/* create a socket */ int -netdial(int domain, int proto, const char *local, const char *bind_dev, int local_port, const char *server, int port, int timeout) +create_socket(int domain, int proto, const char *local, const char *bind_dev, int local_port, const char *server, int port, struct addrinfo **server_res_out) { struct addrinfo hints, *local_res = NULL, *server_res = NULL; int s, saved_errno; + char portstr[6]; if (local) { memset(&hints, 0, sizeof(hints)); @@ -137,8 +138,12 @@ netdial(int domain, int proto, const char *local, const char *bind_dev, int loca memset(&hints, 0, sizeof(hints)); hints.ai_family = domain; hints.ai_socktype = proto; - if ((gerror = getaddrinfo(server, NULL, &hints, &server_res)) != 0) + snprintf(portstr, sizeof(portstr), "%d", port); + if ((gerror = getaddrinfo(server, portstr, &hints, &server_res)) != 0) { + if (local) + freeaddrinfo(local_res); return -1; + } s = socket(server_res->ai_family, proto, 0); if (s < 0) { @@ -204,6 +209,8 @@ netdial(int domain, int proto, const char *local, const char *bind_dev, int loca } /* Unknown protocol */ else { + close(s); + freeaddrinfo(server_res); errno = EAFNOSUPPORT; return -1; } @@ -217,7 +224,22 @@ netdial(int domain, int proto, const char *local, const char *bind_dev, int loca } } - ((struct sockaddr_in *) server_res->ai_addr)->sin_port = htons(port); + *server_res_out = server_res; + return s; +} + +/* make connection to server */ +int +netdial(int domain, int proto, const char *local, const char *bind_dev, int local_port, const char *server, int port, int timeout) +{ + struct addrinfo *server_res = NULL; + int s, saved_errno; + + s = create_socket(domain, proto, local, bind_dev, local_port, server, port, &server_res); + if (s < 0) { + return -1; + } + if (timeout_connect(s, (struct sockaddr *) server_res->ai_addr, server_res->ai_addrlen, timeout) < 0 && errno != EINPROGRESS) { saved_errno = errno; close(s); diff --git a/src/net.h b/src/net.h index 44c0d7e..f0e1b4f 100644 --- a/src/net.h +++ b/src/net.h @@ -28,6 +28,7 @@ #define __NET_H int timeout_connect(int s, const struct sockaddr *name, socklen_t namelen, int timeout); +int create_socket(int domain, int proto, const char *local, const char *bind_dev, int local_port, const char *server, int port, struct addrinfo **server_res_out); int netdial(int domain, int proto, const char *local, const char *bind_dev, int local_port, const char *server, int port, int timeout); int netannounce(int domain, int proto, const char *local, const char *bind_dev, int port); int Nread(int fd, char *buf, size_t count, int prot);