diff --git a/tests/unittests/CMakeLists.txt b/tests/unittests/CMakeLists.txt index 97234fd0..5c3b723f 100644 --- a/tests/unittests/CMakeLists.txt +++ b/tests/unittests/CMakeLists.txt @@ -15,6 +15,7 @@ set(LIBSSH_UNIT_TESTS torture_isipaddr torture_knownhosts_parsing torture_hashes + torture_packet_filter ) set(LIBSSH_THREAD_UNIT_TESTS diff --git a/tests/unittests/torture_packet_filter.c b/tests/unittests/torture_packet_filter.c new file mode 100644 index 00000000..72cbc4cd --- /dev/null +++ b/tests/unittests/torture_packet_filter.c @@ -0,0 +1,502 @@ +/* + * This file is part of the SSH Library + * + * Copyright (c) 2018 by Anderson Toshiyuki Sasaki + * + * The SSH Library is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation; either version 2.1 of the License, or (at your + * option) any later version. + * + * The SSH Library is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public + * License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with the SSH Library; see the file COPYING. If not, write to + * the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, + * MA 02111-1307, USA. + */ + +/* + * This test checks if the messages accepted by the packet filter were intented + * to be accepted. + * + * The process consists in 2 steps: + * - Try the filter with a message type in an arbitrary state + * - If the message is accepted by the filter, check if the message is in the + * set of accepted states. + * + * Only the values selected by the flag (COMPARE_*) are considered. + * */ + +#include "config.h" + +#define LIBSSH_STATIC + +#include "torture.h" +#include "libssh/priv.h" +#include "libssh/libssh.h" +#include "libssh/session.h" +#include "libssh/auth.h" +#include "libssh/ssh2.h" +#include "libssh/packet.h" + +#include "packet.c" + +#define COMPARE_SESSION_STATE 1 +#define COMPARE_ROLE (1 << 1) +#define COMPARE_DH_STATE (1 << 2) +#define COMPARE_AUTH_STATE (1 << 3) +#define COMPARE_GLOBAL_REQ_STATE (1 << 4) +#define COMPARE_CURRENT_METHOD (1 << 5) + +#define SESSION_STATE_COUNT 11 +#define DH_STATE_COUNT 4 +#define AUTH_STATE_COUNT 15 +#define GLOBAL_REQ_STATE_COUNT 5 +#define MESSAGE_COUNT 100 // from 1 to 100 + +#define ROLE_CLIENT 0 +#define ROLE_SERVER 1 + +/* + * This is the list of currently unfiltered message types. + * Only unrecognized types should be in this list. + * */ +static uint8_t unfiltered[] = { + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 22, 23, 24, 25, 26, 27, 28, 29, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + 54, 55, 56, 57, 58, 59, + 62, + 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 83, 84, 85, 86, 87, 88, 89, +}; + +typedef struct global_state_st { + /* If the bit in this flag is zero, the corresponding state is not + * considered, working as a wildcard (meaning any value is accepted) */ + uint32_t flags; + uint8_t role; + enum ssh_session_state_e session; + enum ssh_dh_state_e dh; + enum ssh_auth_state_e auth; + enum ssh_channel_request_state_e global_req; +} global_state; + +static int cmp_state(const void *e1, const void *e2) +{ + global_state *s1 = (global_state *) e1; + global_state *s2 = (global_state *) e2; + + /* Compare role (client == 0 or server == 1)*/ + if (s1->role < s2->role) { + return -1; + } + else if (s1->role > s2->role) { + return 1; + } + + /* Compare session state */ + if (s1->session < s2->session) { + return -1; + } + else if (s1->session > s2->session) { + return 1; + } + + /* Compare DH state */ + if (s1->dh < s2->dh) { + return -1; + } + else if (s1->dh > s2->dh) { + return 1; + } + + /* Compare auth */ + if (s1->auth < s2->auth) { + return -1; + } + else if (s1->auth > s2->auth) { + return 1; + } + + /* Compare global_req */ + if (s1->global_req < s2->global_req) { + return -1; + } + else if (s1->global_req > s2->global_req) { + return 1; + } + + /* If all equal, they are equal */ + return 0; +} + +static int cmp_state_search(const void *key, const void *array_element) +{ + global_state *s1 = (global_state *) key; + global_state *s2 = (global_state *) array_element; + + int result = 0; + + if (s2->flags & COMPARE_ROLE) { + /* Compare role (client == 0 or server == 1)*/ + if (s1->role < s2->role) { + return -1; + } + else if (s1->role > s2->role) { + return 1; + } + } + + if (s2->flags & COMPARE_SESSION_STATE) { + /* Compare session state */ + if (s1->session < s2->session) { + result = -1; + goto end; + } + else if (s1->session > s2->session) { + result = 1; + goto end; + } + } + + if (s2->flags & COMPARE_DH_STATE) { + /* Compare DH state */ + if (s1->dh < s2->dh) { + result = -1; + goto end; + } + else if (s1->dh > s2->dh) { + result = 1; + goto end; + } + } + + if (s2->flags & COMPARE_AUTH_STATE) { + /* Compare auth */ + if (s1->auth < s2->auth) { + result = -1; + goto end; + } + else if (s1->auth > s2->auth) { + result = 1; + goto end; + } + } + + if (s2->flags & COMPARE_GLOBAL_REQ_STATE) { + /* Compare global_req */ + if (s1->global_req < s2->global_req) { + result = -1; + goto end; + } + else if (s1->global_req > s2->global_req) { + result = 1; + goto end; + } + } + +end: + return result; +} + +static int is_state_accepted(global_state *tested, global_state *accepted, + int accepted_len) +{ + global_state *found = NULL; + + found = bsearch(tested, accepted, accepted_len, sizeof(global_state), + cmp_state_search); + + if (found != NULL) { + return 1; + } + + return 0; +} + +static int cmp_uint8(const void *i, const void *j) +{ + uint8_t e1 = *((uint8_t *)i); + uint8_t e2 = *((uint8_t *)j); + + if (e1 < e2) { + return -1; + } + else if (e1 > e2) { + return 1; + } + + return 0; +} + +static int check_unfiltered(uint8_t msg_type) +{ + uint8_t *found; + + found = bsearch(&msg_type, unfiltered, sizeof(unfiltered)/sizeof(uint8_t), + sizeof(uint8_t), cmp_uint8); + + if (found != NULL) { + return 1; + } + + return 0; +} + +static void torture_packet_filter_check_unfiltered(void **state) +{ + ssh_session session; + + int role_c; + int auth_c; + int session_c; + int dh_c; + int global_req_c; + + uint8_t msg_type; + + enum ssh_packet_filter_result_e rc; + int in_unfiltered; + + session = ssh_new(); + + for (msg_type = 1; msg_type <= MESSAGE_COUNT; msg_type++) { + session->in_packet.type = msg_type; + for (role_c = 0; role_c < 2; role_c++) { + session->server = role_c; + for (session_c = 0; session_c < SESSION_STATE_COUNT; session_c++) { + session->session_state = session_c; + for (dh_c = 0; dh_c < DH_STATE_COUNT; dh_c++) { + session->dh_handshake_state = dh_c; + for (auth_c = 0; auth_c < AUTH_STATE_COUNT; auth_c++) { + session->auth.state = auth_c; + for (global_req_c = 0; + global_req_c < GLOBAL_REQ_STATE_COUNT; + global_req_c++) + { + session->global_req_state = global_req_c; + + rc = ssh_packet_incoming_filter(session); + + if (rc == SSH_PACKET_UNKNOWN) { + in_unfiltered = check_unfiltered(msg_type); + + if (!in_unfiltered) { + fprintf(stderr, "Message type %d UNFILTERED " + "in state: role %d, session %d, dh %d, auth %d\n", + msg_type, role_c, session_c, dh_c, auth_c); + } + assert_int_equal(in_unfiltered, 1); + } + else { + in_unfiltered = check_unfiltered(msg_type); + + if (in_unfiltered) { + fprintf(stderr, "Message type %d NOT UNFILTERED " + "in state: role %d, session %d, dh %d, auth %d\n", + msg_type, role_c, session_c, dh_c, auth_c); + } + assert_int_equal(in_unfiltered, 0); + } + } + } + } + } + } + } + ssh_free(session); +} + +static int check_message_in_all_states(global_state accepted[], + int accepted_count, uint8_t msg_type) +{ + ssh_session session; + + int role_c; + int auth_c; + int session_c; + int dh_c; + int global_req_c; + + enum ssh_packet_filter_result_e rc; + int in_accepted; + + global_state key; + + session = ssh_new(); + + /* Sort the accepted array so that the elements can be searched using + * bsearch */ + qsort(accepted, accepted_count, sizeof(global_state), cmp_state); + + session->in_packet.type = msg_type; + + for (role_c = 0; role_c < 2; role_c++) { + session->server = role_c; + key.role = role_c; + for (session_c = 0; session_c < SESSION_STATE_COUNT; session_c++) { + session->session_state = session_c; + key.session = session_c; + for (dh_c = 0; dh_c < DH_STATE_COUNT; dh_c++) { + session->dh_handshake_state = dh_c; + key.dh = dh_c; + for (auth_c = 0; auth_c < AUTH_STATE_COUNT; auth_c++) { + session->auth.state = auth_c; + key.auth = auth_c; + for (global_req_c = 0; + global_req_c < GLOBAL_REQ_STATE_COUNT; + global_req_c++) + { + session->global_req_state = global_req_c; + key.global_req = global_req_c; + + rc = ssh_packet_incoming_filter(session); + + if (rc == SSH_PACKET_ALLOWED) { + in_accepted = is_state_accepted(&key, accepted, + accepted_count); + + if (!in_accepted) { + fprintf(stderr, "Message type %d ALLOWED " + "in state: role %d, session %d, dh %d, auth %d\n", + msg_type, role_c, session_c, dh_c, auth_c); + } + assert_int_equal(in_accepted, 1); + } + else if (rc == SSH_PACKET_DENIED) { + in_accepted = is_state_accepted(&key, accepted, accepted_count); + + if (in_accepted) { + fprintf(stderr, "Message type %d DENIED " + "in state: role %d, session %d, dh %d, auth %d\n", + msg_type, role_c, session_c, dh_c, auth_c); + } + assert_int_equal(in_accepted, 0); + } + else { + fprintf(stderr, "Message type %d UNFILTERED " + "in state: role %d, session %d, dh %d, auth %d\n", + msg_type, role_c, session_c, dh_c, auth_c); + } + } + } + } + } + } + + ssh_free(session); + return 0; +} + +static void torture_packet_filter_check_auth_success(void **state) +{ + int rc; + + global_state accepted[] = { + { + .flags = (COMPARE_SESSION_STATE | + COMPARE_ROLE | + COMPARE_AUTH_STATE | + COMPARE_DH_STATE), + .role = ROLE_CLIENT, + .session = SSH_SESSION_STATE_AUTHENTICATING, + .dh = DH_STATE_FINISHED, + .auth = SSH_AUTH_STATE_PUBKEY_AUTH_SENT, + }, + { + .flags = (COMPARE_SESSION_STATE | + COMPARE_ROLE | + COMPARE_AUTH_STATE | + COMPARE_DH_STATE), + .role = ROLE_CLIENT, + .session = SSH_SESSION_STATE_AUTHENTICATING, + .dh = DH_STATE_FINISHED, + .auth = SSH_AUTH_STATE_PASSWORD_AUTH_SENT, + }, + { + .flags = (COMPARE_SESSION_STATE | + COMPARE_ROLE | + COMPARE_AUTH_STATE | + COMPARE_DH_STATE), + .role = ROLE_CLIENT, + .session = SSH_SESSION_STATE_AUTHENTICATING, + .dh = DH_STATE_FINISHED, + .auth = SSH_AUTH_STATE_GSSAPI_MIC_SENT, + }, + { + .flags = (COMPARE_SESSION_STATE | + COMPARE_ROLE | + COMPARE_AUTH_STATE | + COMPARE_DH_STATE), + .role = ROLE_CLIENT, + .session = SSH_SESSION_STATE_AUTHENTICATING, + .dh = DH_STATE_FINISHED, + .auth = SSH_AUTH_STATE_KBDINT_SENT, + }, + { + .flags = (COMPARE_SESSION_STATE | + COMPARE_ROLE | + COMPARE_AUTH_STATE | + COMPARE_DH_STATE | + COMPARE_CURRENT_METHOD), + .role = ROLE_CLIENT, + .session = SSH_SESSION_STATE_AUTHENTICATING, + .dh = DH_STATE_FINISHED, + .auth = SSH_AUTH_STATE_AUTH_NONE_SENT, + } + }; + + int accepted_count = 5; + + /* Unused */ + (void) state; + + rc = check_message_in_all_states(accepted, accepted_count, + SSH2_MSG_USERAUTH_SUCCESS); + + assert_int_equal(rc, 0); +} + +static void torture_packet_filter_check_channel_open(void **state) +{ + int rc; + + /* The only condition to accept a CHANNEL_OPEN is to be authenticated */ + global_state accepted[] = { + { + .flags = COMPARE_SESSION_STATE, + .session = SSH_SESSION_STATE_AUTHENTICATED, + } + }; + + int accepted_count = 1; + + /* Unused */ + (void) state; + + rc = check_message_in_all_states(accepted, accepted_count, + SSH2_MSG_CHANNEL_OPEN); + + assert_int_equal(rc, 0); +} + +int torture_run_tests(void) +{ + int rc; + struct CMUnitTest tests[] = { + cmocka_unit_test(torture_packet_filter_check_auth_success), + cmocka_unit_test(torture_packet_filter_check_channel_open), + cmocka_unit_test(torture_packet_filter_check_unfiltered), + }; + + ssh_init(); + torture_filter_tests(tests); + rc = cmocka_run_group_tests(tests, NULL, NULL); + ssh_finalize(); + return rc; +}