diff --git a/src/poll.c b/src/poll.c index 2f6d1c4b..aa2a0821 100644 --- a/src/poll.c +++ b/src/poll.c @@ -102,6 +102,7 @@ typedef int (*poll_fn)(ssh_pollfd_t *, nfds_t, int); static poll_fn ssh_poll_emu; #include +#include #ifdef _WIN32 #ifndef STRICT @@ -125,6 +126,27 @@ static poll_fn ssh_poll_emu; #include #endif +static bool bsd_socket_disconnected(int sock_err) +{ + switch (sock_err) { +#ifdef _WIN32 + case WSAECONNABORTED: + case WSAECONNRESET: + case WSAENETRESET: + case WSAESHUTDOWN: +#else + case ECONNABORTED: + case ECONNRESET: + case ENETRESET: + case ESHUTDOWN: +#endif + return true; + default: + return false; + } + + return false; +} /* * This is a poll(2)-emulation using select for systems not providing a native @@ -135,118 +157,112 @@ static poll_fn ssh_poll_emu; * a value as high as 1024 on Linux you'll pay dearly in every single call. * poll() will be orders of magnitude faster. */ -static int bsd_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout) { - fd_set readfds, writefds, exceptfds; - struct timeval tv, *ptv; - socket_t max_fd; - int rc; - nfds_t i; +static int bsd_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout) +{ + fd_set readfds, writefds, exceptfds; + struct timeval tv, *ptv = NULL; + socket_t max_fd; + int rc; + nfds_t i; - if (fds == NULL) { - errno = EFAULT; - return -1; - } - - FD_ZERO (&readfds); - FD_ZERO (&writefds); - FD_ZERO (&exceptfds); - - /* compute fd_sets and find largest descriptor */ - for (rc = -1, max_fd = 0, i = 0; i < nfds; i++) { - if (fds[i].fd == SSH_INVALID_SOCKET) { - continue; - } -#ifndef _WIN32 - if (fds[i].fd >= FD_SETSIZE) { - rc = -1; - break; - } -#endif - - if (fds[i].events & (POLLIN | POLLRDNORM)) { - FD_SET (fds[i].fd, &readfds); - } - if (fds[i].events & (POLLOUT | POLLWRNORM | POLLWRBAND)) { - FD_SET (fds[i].fd, &writefds); - } - if (fds[i].events & (POLLPRI | POLLRDBAND)) { - FD_SET (fds[i].fd, &exceptfds); - } - if (fds[i].fd > max_fd && - (fds[i].events & (POLLIN | POLLOUT | POLLPRI | - POLLRDNORM | POLLRDBAND | - POLLWRNORM | POLLWRBAND))) { - max_fd = fds[i].fd; - rc = 0; - } - } - - if (max_fd == SSH_INVALID_SOCKET || rc == -1) { - errno = EINVAL; - return -1; - } - - if (timeout < 0) { - ptv = NULL; - } else { - ptv = &tv; - if (timeout == 0) { - tv.tv_sec = 0; - tv.tv_usec = 0; - } else { - tv.tv_sec = timeout / 1000; - tv.tv_usec = (timeout % 1000) * 1000; - } - } - - rc = select (max_fd + 1, &readfds, &writefds, &exceptfds, ptv); - if (rc < 0) { - return -1; - } - - for (rc = 0, i = 0; i < nfds; i++) - if (fds[i].fd >= 0) { - fds[i].revents = 0; - - if (FD_ISSET(fds[i].fd, &readfds)) { - int save_errno = errno; - char data[64] = {0}; - int ret; - - /* support for POLLHUP */ - ret = recv(fds[i].fd, data, 64, MSG_PEEK); -#ifdef _WIN32 - if ((ret == -1) && - (errno == WSAESHUTDOWN || errno == WSAECONNRESET || - errno == WSAECONNABORTED || errno == WSAENETRESET)) { -#else - if ((ret == -1) && - (errno == ESHUTDOWN || errno == ECONNRESET || - errno == ECONNABORTED || errno == ENETRESET)) { -#endif - fds[i].revents |= POLLHUP; - } else { - fds[i].revents |= fds[i].events & (POLLIN | POLLRDNORM); - } - - errno = save_errno; - } - if (FD_ISSET(fds[i].fd, &writefds)) { - fds[i].revents |= fds[i].events & (POLLOUT | POLLWRNORM | POLLWRBAND); - } - - if (FD_ISSET(fds[i].fd, &exceptfds)) { - fds[i].revents |= fds[i].events & (POLLPRI | POLLRDBAND); - } - - if (fds[i].revents & ~POLLHUP) { - rc++; - } - } else { - fds[i].revents = POLLNVAL; + if (fds == NULL) { + errno = EFAULT; + return -1; } - return rc; + FD_ZERO(&readfds); + FD_ZERO(&writefds); + FD_ZERO(&exceptfds); + + /* compute fd_sets and find largest descriptor */ + for (rc = -1, max_fd = 0, i = 0; i < nfds; i++) { + if (fds[i].fd == SSH_INVALID_SOCKET) { + continue; + } +#ifndef _WIN32 + if (fds[i].fd >= FD_SETSIZE) { + rc = -1; + break; + } +#endif + + if (fds[i].events & (POLLIN | POLLRDNORM)) { + FD_SET (fds[i].fd, &readfds); + } + if (fds[i].events & (POLLOUT | POLLWRNORM | POLLWRBAND)) { + FD_SET (fds[i].fd, &writefds); + } + if (fds[i].events & (POLLPRI | POLLRDBAND)) { + FD_SET (fds[i].fd, &exceptfds); + } + if (fds[i].fd > max_fd && + (fds[i].events & (POLLIN | POLLOUT | POLLPRI | + POLLRDNORM | POLLRDBAND | + POLLWRNORM | POLLWRBAND))) { + max_fd = fds[i].fd; + rc = 0; + } + } + + if (max_fd == SSH_INVALID_SOCKET || rc == -1) { + errno = EINVAL; + return -1; + } + + if (timeout < 0) { + ptv = NULL; + } else { + ptv = &tv; + if (timeout == 0) { + tv.tv_sec = 0; + tv.tv_usec = 0; + } else { + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout % 1000) * 1000; + } + } + + rc = select(max_fd + 1, &readfds, &writefds, &exceptfds, ptv); + if (rc < 0) { + return -1; + } + + for (rc = 0, i = 0; i < nfds; i++) { + if (fds[i].fd >= 0) { + fds[i].revents = 0; + + if (FD_ISSET(fds[i].fd, &readfds)) { + int save_errno = errno; + char data[64] = {0}; + int ret; + + /* support for POLLHUP */ + ret = recv(fds[i].fd, data, 64, MSG_PEEK); + if ((ret == -1) && bsd_socket_disconnected(errno)) { + fds[i].revents |= POLLHUP; + } else { + fds[i].revents |= fds[i].events & (POLLIN | POLLRDNORM); + } + + errno = save_errno; + } + if (FD_ISSET(fds[i].fd, &writefds)) { + fds[i].revents |= fds[i].events & (POLLOUT | POLLWRNORM | POLLWRBAND); + } + + if (FD_ISSET(fds[i].fd, &exceptfds)) { + fds[i].revents |= fds[i].events & (POLLPRI | POLLRDBAND); + } + + if (fds[i].revents & ~POLLHUP) { + rc++; + } + } else { + fds[i].revents = POLLNVAL; + } + } + + return rc; } void ssh_poll_init(void) {