1
1

connector: Add checks if file descriptor is a socket

Fixes T104

Signed-off-by: David Wedderwille <davidwe@posteo.de>
Этот коммит содержится в:
David Wedderwille 2018-09-25 11:58:32 +02:00 коммит произвёл Andreas Schneider
родитель 1e5e09563a
Коммит 9adc2d36eb

Просмотреть файл

@ -26,6 +26,10 @@
#include "libssh/callbacks.h"
#include "libssh/session.h"
#include <stdlib.h>
#include <errno.h>
#include <stdbool.h>
#include <sys/stat.h>
#define CHUNKSIZE 4096
#ifdef _WIN32
@ -40,6 +44,9 @@
# undef unlink
# define unlink _unlink
# endif /* HAVE_IO_H */
#else
# include <sys/types.h>
# include <sys/socket.h>
#endif
struct ssh_connector_struct {
@ -51,6 +58,8 @@ struct ssh_connector_struct {
socket_t in_fd;
socket_t out_fd;
bool fd_is_socket;
ssh_poll_handle in_poll;
ssh_poll_handle out_poll;
@ -76,6 +85,13 @@ static int ssh_connector_channel_write_wontblock_cb(ssh_session session,
ssh_channel channel,
size_t bytes,
void *userdata);
static ssize_t ssh_connector_fd_read(ssh_connector connector,
void *buffer,
uint32_t len);
static ssize_t ssh_connector_fd_write(ssh_connector connector,
const void *buffer,
uint32_t len);
static bool ssh_connector_fd_is_socket(socket_t socket);
ssh_connector ssh_connector_new(ssh_session session)
{
@ -91,6 +107,8 @@ ssh_connector ssh_connector_new(ssh_session session)
connector->in_fd = SSH_INVALID_SOCKET;
connector->out_fd = SSH_INVALID_SOCKET;
connector->fd_is_socket = false;
ssh_callbacks_init(&connector->in_channel_cb);
ssh_callbacks_init(&connector->out_channel_cb);
@ -167,12 +185,14 @@ int ssh_connector_set_out_channel(ssh_connector connector,
void ssh_connector_set_in_fd(ssh_connector connector, socket_t fd)
{
connector->in_fd = fd;
connector->fd_is_socket = ssh_connector_fd_is_socket(fd);
connector->in_channel = NULL;
}
void ssh_connector_set_out_fd(ssh_connector connector, socket_t fd)
{
connector->out_fd = fd;
connector->fd_is_socket = ssh_connector_fd_is_socket(fd);
connector->out_channel = NULL;
}
@ -223,9 +243,9 @@ static void ssh_connector_reset_pollevents(ssh_connector connector)
static void ssh_connector_fd_in_cb(ssh_connector connector)
{
unsigned char buffer[CHUNKSIZE];
int r;
int toread = CHUNKSIZE;
int w;
uint32_t toread = CHUNKSIZE;
ssize_t r;
ssize_t w;
int total = 0;
int rc;
@ -239,7 +259,7 @@ static void ssh_connector_fd_in_cb(ssh_connector connector)
toread = MIN(size, CHUNKSIZE);
}
r = read(connector->in_fd, buffer, toread);
r = ssh_connector_fd_read(connector, buffer, toread);
if (r < 0) {
ssh_connector_except(connector, connector->in_fd);
return;
@ -277,7 +297,7 @@ static void ssh_connector_fd_in_cb(ssh_connector connector)
* bytes
*/
while (total != r) {
w = write(connector->out_fd, buffer + total, r - total);
w = ssh_connector_fd_write(connector, buffer + total, r - total);
if (w < 0){
ssh_connector_except(connector, connector->out_fd);
return;
@ -319,7 +339,7 @@ static void ssh_connector_fd_out_cb(ssh_connector connector){
} else if(r>0) {
/* loop around write in case the write blocks even for CHUNKSIZE bytes */
while (total != r){
w = write(connector->out_fd, buffer + total, r - total);
w = ssh_connector_fd_write(connector, buffer + total, r - total);
if (w < 0){
ssh_connector_except(connector, connector->out_fd);
return;
@ -451,7 +471,7 @@ static int ssh_connector_channel_data_cb(ssh_session session,
ssh_connector_except_channel(connector, connector->out_channel);
}
} else if (connector->out_fd != SSH_INVALID_SOCKET) {
w = write(connector->out_fd, data, len);
w = ssh_connector_fd_write(connector, data, len);
if (w < 0)
ssh_connector_except(connector, connector->out_fd);
} else {
@ -634,3 +654,96 @@ int ssh_connector_remove_event(ssh_connector connector) {
return SSH_OK;
}
/**
* @internal
*
* @brief Check the file descriptor to check if it is a Windows socket handle.
*
*/
static bool ssh_connector_fd_is_socket(socket_t s)
{
#ifdef _WIN32
struct sockaddr_storage ss;
int len = sizeof(struct sockaddr_storage);
int rc;
rc = getsockname(s, (struct sockaddr *)&ss, &len);
if (rc == 0) {
return true;
}
SSH_LOG(SSH_LOG_TRACE,
"Error %i in getsockname() for fd %d",
WSAGetLastError(),
s);
return false;
#else
struct stat sb;
int rc;
rc = fstat(s, &sb);
if (rc != 0) {
SSH_LOG(SSH_LOG_TRACE,
"error %i in fstat() for fd %d",
errno,
s);
return false;
}
/* The descriptor is a socket */
if (S_ISSOCK(sb.st_mode)) {
return true;
}
return false;
#endif /* _WIN32 */
}
/**
* @internal
*
* @brief read len bytes from socket into buffer
*
*/
static ssize_t ssh_connector_fd_read(ssh_connector connector,
void *buffer,
uint32_t len)
{
ssize_t nread = -1;
if (connector->fd_is_socket) {
nread = recv(connector->in_fd,buffer, len, 0);
} else {
nread = read(connector->in_fd,buffer, len);
}
return nread;
}
/**
* @internal
*
* @brief brief writes len bytes from buffer to socket
*
*/
static ssize_t ssh_connector_fd_write(ssh_connector connector,
const void *buffer,
uint32_t len)
{
ssize_t bwritten = -1;
int flags = 0;
#ifdef MSG_NOSIGNAL
flags |= MSG_NOSIGNAL;
#endif
if (connector->fd_is_socket) {
bwritten = send(connector->out_fd,buffer, len, flags);
} else {
bwritten = write(connector->out_fd, buffer, len);
}
return bwritten;
}