From e0985fc0b41b5a5006e3f74f42a7a5b702dce098 Mon Sep 17 00:00:00 2001 From: omsheladia Date: Sat, 2 Apr 2022 01:57:46 +0530 Subject: [PATCH] client: Add ssh_session_set_disconnect_message() Fix #98 by adding 'ssh_session_set_disconnect_message' Whenever the ssh session disconnects a "Bye Bye" message was set and displayed. Now the peer has a choice to set a customised message to be sent after the session is disconnected. The default "Bye Bye" will be set if this function is not called or not called correctly. The testcases in tests/server/torture_server can also demonstrate how this function works. Signed-off-by: Om Sheladia Reviewed-by: Jakub Jelen Reviewed-by: Andreas Schneider --- include/libssh/libssh.h | 1 + include/libssh/session.h | 1 + src/client.c | 44 +++++++++++++++++++++- src/session.c | 1 + tests/server/torture_server.c | 69 +++++++++++++++++++++++++++++++++++ 5 files changed, 115 insertions(+), 1 deletion(-) diff --git a/include/libssh/libssh.h b/include/libssh/libssh.h index 6a034db9..7857a77b 100644 --- a/include/libssh/libssh.h +++ b/include/libssh/libssh.h @@ -841,6 +841,7 @@ LIBSSH_API int ssh_buffer_add_data(ssh_buffer buffer, const void *data, uint32_t LIBSSH_API uint32_t ssh_buffer_get_data(ssh_buffer buffer, void *data, uint32_t requestedlen); LIBSSH_API void *ssh_buffer_get(ssh_buffer buffer); LIBSSH_API uint32_t ssh_buffer_get_len(ssh_buffer buffer); +LIBSSH_API int ssh_session_set_disconnect_message(ssh_session session, const char *message); #ifndef LIBSSH_LEGACY_0_4 #include "libssh/legacy.h" diff --git a/include/libssh/session.h b/include/libssh/session.h index 55a10e48..0a6fb080 100644 --- a/include/libssh/session.h +++ b/include/libssh/session.h @@ -136,6 +136,7 @@ struct ssh_session_struct { the server */ char *discon_msg; /* disconnect message from the remote host */ + char *disconnect_message; /* disconnect message to be set */ ssh_buffer in_buffer; PACKET in_packet; ssh_buffer out_buffer; diff --git a/src/client.c b/src/client.c index 5ed893b5..e958a523 100644 --- a/src/client.c +++ b/src/client.c @@ -690,6 +690,39 @@ int ssh_get_openssh_version(ssh_session session) return session->openssh; } +/** + * @brief Add disconnect message when ssh_session is disconnected + * To add a disconnect message to give peer a better hint. + * @param session The SSH session to use. + * @param message The message to send after the session is disconnected. + * If no message is passed then a default message i.e + * "Bye Bye" will be sent. + */ +int +ssh_session_set_disconnect_message(ssh_session session, const char *message) +{ + if (session == NULL) { + return SSH_ERROR; + } + + if (message == NULL || strlen(message) == 0) { + SAFE_FREE(session->disconnect_message); //To free any message set earlier. + session->disconnect_message = strdup("Bye Bye") ; + if (session->disconnect_message == NULL) { + ssh_set_error_oom(session); + return SSH_ERROR; + } + return SSH_OK; + } + SAFE_FREE(session->disconnect_message); //To free any message set earlier. + session->disconnect_message = strdup(message); + if (session->disconnect_message == NULL) { + ssh_set_error_oom(session); + return SSH_ERROR; + } + return SSH_OK; +} + /** * @brief Disconnect from a session (client or server). @@ -712,12 +745,20 @@ ssh_disconnect(ssh_session session) return; } + if (session->disconnect_message == NULL) { + session->disconnect_message = strdup("Bye Bye") ; + if (session->disconnect_message == NULL) { + ssh_set_error_oom(session); + goto error; + } + } + if (session->socket != NULL && ssh_socket_is_open(session->socket)) { rc = ssh_buffer_pack(session->out_buffer, "bdss", SSH2_MSG_DISCONNECT, SSH2_DISCONNECT_BY_APPLICATION, - "Bye Bye", + session->disconnect_message, ""); /* language tag */ if (rc != SSH_OK) { ssh_set_error_oom(session); @@ -772,6 +813,7 @@ error: session->auth.supported_methods = 0; SAFE_FREE(session->serverbanner); SAFE_FREE(session->clientbanner); + SAFE_FREE(session->disconnect_message); if (session->ssh_message_list) { ssh_message msg = NULL; diff --git a/src/session.c b/src/session.c index 7eacc925..61b9720a 100644 --- a/src/session.c +++ b/src/session.c @@ -299,6 +299,7 @@ void ssh_free(ssh_session session) SAFE_FREE(session->serverbanner); SAFE_FREE(session->clientbanner); SAFE_FREE(session->banner); + SAFE_FREE(session->disconnect_message); SAFE_FREE(session->opts.bindaddr); SAFE_FREE(session->opts.custombanner); diff --git a/tests/server/torture_server.c b/tests/server/torture_server.c index fecf86c9..c21f3612 100644 --- a/tests/server/torture_server.c +++ b/tests/server/torture_server.c @@ -370,6 +370,66 @@ static void torture_server_unknown_global_request(void **state) ssh_channel_close(channel); } +static void torture_server_set_disconnect_message(void **state) +{ + struct test_server_st *tss = *state; + struct torture_state *s = NULL; + ssh_session session; + int rc; + const char *message = "Goodbye"; + + assert_non_null(tss); + + s = tss->state; + assert_non_null(s); + + session = s->ssh.session; + assert_non_null(session); + + rc = ssh_session_set_disconnect_message(session,message); + assert_ssh_return_code(session, rc); + assert_string_equal(session->disconnect_message,message); +} + +static void torture_null_server_set_disconnect_message(void **state) +{ + struct test_server_st *tss = *state; + struct torture_state *s = NULL; + ssh_session session; + int rc; + + assert_non_null(tss); + + s = tss->state; + assert_non_null(s); + + session = s->ssh.session; + assert_non_null(session); + + rc = ssh_session_set_disconnect_message(NULL,"Goodbye"); + assert_int_equal(rc, SSH_ERROR); +} + +static void torture_server_set_null_disconnect_message(void **state) +{ + struct test_server_st *tss = *state; + struct torture_state *s = NULL; + ssh_session session; + int rc; + + assert_non_null(tss); + + s = tss->state; + assert_non_null(s); + + session = s->ssh.session; + assert_non_null(session); + + rc = ssh_session_set_disconnect_message(session,NULL); + assert_int_equal(rc, SSH_OK); + assert_string_equal(session->disconnect_message,"Bye Bye"); +} + int torture_run_tests(void) { int rc; struct CMUnitTest tests[] = { @@ -388,6 +448,15 @@ int torture_run_tests(void) { cmocka_unit_test_setup_teardown(torture_server_unknown_global_request, session_setup, session_teardown), + cmocka_unit_test_setup_teardown(torture_server_set_disconnect_message, + session_setup, + session_teardown), + cmocka_unit_test_setup_teardown(torture_null_server_set_disconnect_message, + session_setup, + session_teardown), + cmocka_unit_test_setup_teardown(torture_server_set_null_disconnect_message, + session_setup, + session_teardown), }; ssh_init();