diff --git a/src/auth.c b/src/auth.c index 4e08c3e9..7122a577 100644 --- a/src/auth.c +++ b/src/auth.c @@ -63,23 +63,23 @@ * again is necessary */ static int ask_userauth(ssh_session session) { - int rc = 0; + int rc = 0; - enter_function(); - do { - rc=ssh_service_request(session,"ssh-userauth"); - if(ssh_is_blocking(session)){ - if(rc==SSH_AGAIN) - ssh_handle_packets(session,-1); - } else { - /* nonblocking */ - ssh_handle_packets(session,0); - rc=ssh_service_request(session,"ssh-userauth"); - break; - } - } while(rc==SSH_AGAIN); - leave_function(); - return rc; + enter_function(); + do { + rc = ssh_service_request(session,"ssh-userauth"); + if (ssh_is_blocking(session)) { + if(rc == SSH_AGAIN) + ssh_handle_packets(session, -2); + } else { + /* nonblocking */ + ssh_handle_packets(session, 0); + rc = ssh_service_request(session, "ssh-userauth"); + break; + } + } while (rc == SSH_AGAIN); + leave_function(); + return rc; } /** diff --git a/src/auth1.c b/src/auth1.c index 06f05497..8b96f8ca 100644 --- a/src/auth1.c +++ b/src/auth1.c @@ -38,7 +38,7 @@ static int wait_auth1_status(ssh_session session) { enter_function(); /* wait for a packet */ while(session->auth_state == SSH_AUTH_STATE_NONE) - if (ssh_handle_packets(session,-1) != SSH_OK) + if (ssh_handle_packets(session, -2) != SSH_OK) break; ssh_log(session,SSH_LOG_PROTOCOL,"Auth state : %d",session->auth_state); leave_function(); diff --git a/src/channels.c b/src/channels.c index ddccb97c..97d02ff0 100644 --- a/src/channels.c +++ b/src/channels.c @@ -296,7 +296,7 @@ static int channel_open(ssh_channel channel, const char *type_c, int window, /* Todo: fix this into a correct loop */ /* wait until channel is opened by server */ while(channel->state == SSH_CHANNEL_STATE_NOT_OPEN){ - ssh_handle_packets(session,-1); + ssh_handle_packets(session, -2); } if(channel->state == SSH_CHANNEL_STATE_OPEN) err=SSH_OK; @@ -1451,7 +1451,7 @@ static int channel_request(ssh_channel channel, const char *request, return SSH_OK; } while(channel->request_state == SSH_CHANNEL_REQ_STATE_PENDING){ - ssh_handle_packets(session,-1); + ssh_handle_packets(session, -2); if(session->session_state == SSH_SESSION_STATE_ERROR) { channel->request_state = SSH_CHANNEL_REQ_STATE_ERROR; break; @@ -1799,7 +1799,7 @@ static ssh_channel ssh_channel_accept(ssh_session session, int channeltype, for (t = timeout_ms; t >= 0; t -= 50) { - ssh_handle_packets(session,50); + ssh_handle_packets(session, 50); if (session->ssh_message_list) { iterator = ssh_list_get_iterator(session->ssh_message_list); @@ -1956,7 +1956,7 @@ static int global_request(ssh_session session, const char *request, return SSH_OK; } while(session->global_req_state == SSH_CHANNEL_REQ_STATE_PENDING){ - rc=ssh_handle_packets(session,-1); + rc=ssh_handle_packets(session, -2); if(rc==SSH_ERROR){ session->global_req_state = SSH_CHANNEL_REQ_STATE_ERROR; break; @@ -2372,7 +2372,7 @@ int channel_read_buffer(ssh_channel channel, ssh_buffer buffer, uint32_t count, leave_function(); return 0; } - ssh_handle_packets(channel->session, -1); + ssh_handle_packets(channel->session, -2); } while (r == 0); } while(total < count){ @@ -2480,7 +2480,7 @@ int ssh_channel_read(ssh_channel channel, void *dest, uint32_t count, int is_std break; } - ssh_handle_packets(session,-1); + ssh_handle_packets(session, -2); } len = buffer_get_rest_len(stdbuf); @@ -2585,7 +2585,7 @@ int ssh_channel_poll(ssh_channel channel, int is_stderr){ } if (buffer_get_rest_len(stdbuf) == 0 && channel->remote_eof == 0) { - if (ssh_handle_packets(channel->session,0)==SSH_ERROR) { + if (ssh_handle_packets(channel->session, 0)==SSH_ERROR) { leave_function(); return SSH_ERROR; } @@ -2640,7 +2640,7 @@ int ssh_channel_get_exit_status(ssh_channel channel) { while ((channel->remote_eof == 0 || channel->exit_status == -1) && channel->session->alive) { /* Parse every incoming packet */ - if (ssh_handle_packets(channel->session,-1) != SSH_OK) { + if (ssh_handle_packets(channel->session, -2) != SSH_OK) { return -1; } /* XXX We should actually wait for a close packet and not a close @@ -2675,7 +2675,7 @@ static int channel_protocol_select(ssh_channel *rchans, ssh_channel *wchans, chan = rchans[i]; while (ssh_channel_is_open(chan) && ssh_socket_data_available(chan->session->socket)) { - ssh_handle_packets(chan->session,-1); + ssh_handle_packets(chan->session, -2); } if ((chan->stdout_buffer && buffer_get_rest_len(chan->stdout_buffer) > 0) || diff --git a/src/kex.c b/src/kex.c index 592ca092..21f2cad6 100644 --- a/src/kex.c +++ b/src/kex.c @@ -826,14 +826,14 @@ int ssh_get_kex1(ssh_session session) { ssh_log(session, SSH_LOG_PROTOCOL, "Waiting for a SSH_SMSG_PUBLIC_KEY"); /* Here the callback is called */ while(session->session_state==SSH_SESSION_STATE_INITIAL_KEX){ - ssh_handle_packets(session,-1); + ssh_handle_packets(session, -2); } if(session->session_state==SSH_SESSION_STATE_ERROR) goto error; ssh_log(session, SSH_LOG_PROTOCOL, "Waiting for a SSH_SMSG_SUCCESS"); /* Waiting for SSH_SMSG_SUCCESS */ while(session->session_state==SSH_SESSION_STATE_KEXINIT_RECEIVED){ - ssh_handle_packets(session,-1); + ssh_handle_packets(session, -2); } if(session->session_state==SSH_SESSION_STATE_ERROR) goto error; diff --git a/src/messages.c b/src/messages.c index bbd9d1f3..30a0fd77 100644 --- a/src/messages.c +++ b/src/messages.c @@ -183,7 +183,7 @@ ssh_message ssh_message_get(ssh_session session) { session->ssh_message_list = ssh_list_new(); } do { - if (ssh_handle_packets(session,-1) == SSH_ERROR) { + if (ssh_handle_packets(session, -2) == SSH_ERROR) { leave_function(); return NULL; } diff --git a/src/server.c b/src/server.c index 9877ed86..6ee16e0f 100644 --- a/src/server.c +++ b/src/server.c @@ -479,7 +479,7 @@ int ssh_handle_key_exchange(ssh_session session) { * loop until SSH_SESSION_STATE_BANNER_RECEIVED or * SSH_SESSION_STATE_ERROR */ - ssh_handle_packets(session,-1); + ssh_handle_packets(session, -2); ssh_log(session,SSH_LOG_PACKET, "ssh_handle_key_exchange: Actual state : %d", session->session_state); } diff --git a/src/session.c b/src/session.c index ab12e0db..5aa6cbb0 100644 --- a/src/session.c +++ b/src/session.c @@ -404,6 +404,17 @@ void ssh_set_fd_except(ssh_session session) { ssh_socket_set_except(session->socket); } +static int ssh_make_milliseconds(long sec, long usec) { + int res = usec ? (usec / 1000) : 0; + res += (sec * 1000); + if (res == 0) { + res = 10 * 1000; /* use a reasonable default value in case + * SSH_OPTIONS_TIMEOUT is not set in options. */ + } + + return res; +} + /** * @internal * @@ -415,38 +426,56 @@ void ssh_set_fd_except(ssh_session session) { * @param[in] session The session handle to use. * * @param[in] timeout Set an upper limit on the time for which this function - * will block, in milliseconds. Specifying a negative value - * means an infinite timeout. This parameter is passed to - * the poll() function. + * will block, in milliseconds. Specifying -1 + * means an infinite timeout. + * Specifying -2 means to use the timeout specified in + * options. 0 means poll will return immediately. This + * parameter is passed to the poll() function. * * @return SSH_OK on success, SSH_ERROR otherwise. */ int ssh_handle_packets(ssh_session session, int timeout) { - ssh_poll_handle spoll_in,spoll_out; - ssh_poll_ctx ctx; - int rc; - if(session==NULL || session->socket==NULL) - return SSH_ERROR; - enter_function(); - spoll_in=ssh_socket_get_poll_handle_in(session->socket); - spoll_out=ssh_socket_get_poll_handle_out(session->socket); - if(session->server) - ssh_poll_add_events(spoll_in, POLLIN); - ctx=ssh_poll_get_ctx(spoll_in); - if(ctx==NULL){ - ctx=ssh_poll_get_default_ctx(session); - ssh_poll_ctx_add(ctx,spoll_in); - if(spoll_in != spoll_out) - ssh_poll_ctx_add(ctx,spoll_out); - } - rc = ssh_poll_ctx_dopoll(ctx,timeout); - if(rc == SSH_ERROR) - session->session_state = SSH_SESSION_STATE_ERROR; - leave_function(); - if (session->session_state != SSH_SESSION_STATE_ERROR) - return SSH_OK; - else - return SSH_ERROR; + ssh_poll_handle spoll_in,spoll_out; + ssh_poll_ctx ctx; + int tm = timeout; + int rc; + + if (session == NULL || session->socket == NULL) { + return SSH_ERROR; + } + enter_function(); + + spoll_in = ssh_socket_get_poll_handle_in(session->socket); + spoll_out = ssh_socket_get_poll_handle_out(session->socket); + if (session->server) { + ssh_poll_add_events(spoll_in, POLLIN); + } + ctx = ssh_poll_get_ctx(spoll_in); + + if (!ctx) { + ctx = ssh_poll_get_default_ctx(session); + ssh_poll_ctx_add(ctx, spoll_in); + if (spoll_in != spoll_out) { + ssh_poll_ctx_add(ctx, spoll_out); + } + } + + if (timeout == -2) { + tm = ssh_make_milliseconds(session->timeout, session->timeout_usec); + } + rc = ssh_poll_ctx_dopoll(ctx, tm); + + if (rc == SSH_ERROR) { + session->session_state = SSH_SESSION_STATE_ERROR; + } + + leave_function(); + + if (session->session_state == SSH_SESSION_STATE_ERROR) { + return SSH_ERROR; + } + + return SSH_OK; } /**