diff --git a/tests/server/test_server/default_cb.c b/tests/server/test_server/default_cb.c index 1e60f0cb..3f3a3cd7 100644 --- a/tests/server/test_server/default_cb.c +++ b/tests/server/test_server/default_cb.c @@ -688,6 +688,58 @@ static int process_stderr(socket_t fd, int revents, void *userdata) return n; } +/* The caller is responsible to set the userdata to be provided to the callback + * The caller is responsible to free the allocated structure + * */ +struct ssh_server_callbacks_struct *get_default_server_cb(void) +{ + + struct ssh_server_callbacks_struct *cb; + + cb = (struct ssh_server_callbacks_struct *)calloc(1, + sizeof(struct ssh_server_callbacks_struct)); + + if (cb == NULL) { + fprintf(stderr, "Out of memory\n"); + goto end; + } + + cb->auth_password_function = auth_password_cb; + cb->channel_open_request_session_function = channel_new_session_cb; +#if WITH_GSSAPI + cb->auth_gssapi_mic_function = auth_gssapi_mic_cb; +#endif + +end: + return cb; +} + +/* The caller is responsible to set the userdata to be provided to the callback + * The caller is responsible to free the allocated structure + * */ +struct ssh_channel_callbacks_struct *get_default_channel_cb(void) +{ + struct ssh_channel_callbacks_struct *cb; + + cb = (struct ssh_channel_callbacks_struct *)calloc(1, + sizeof(struct ssh_channel_callbacks_struct)); + if (cb == NULL) { + fprintf(stderr, "Out of memory\n"); + goto end; + } + + cb->channel_pty_request_function = channel_pty_request_cb; + cb->channel_pty_window_change_function = channel_pty_resize_cb; + cb->channel_shell_request_function = channel_shell_request_cb; + cb->channel_env_request_function = channel_env_request_cb; + cb->channel_subsystem_request_function = channel_subsystem_request_cb; + cb->channel_exec_request_function = channel_exec_request_cb; + cb->channel_data_function = channel_data_cb; + +end: + return cb; +}; + void default_handle_session_cb(ssh_event event, ssh_session session, struct server_state_st *state) @@ -724,31 +776,40 @@ void default_handle_session_cb(ssh_event event, .password = SSHD_DEFAULT_PASSWORD }; - struct ssh_channel_callbacks_struct default_channel_cb = { - .userdata = &cdata, - .channel_pty_request_function = channel_pty_request_cb, - .channel_pty_window_change_function = channel_pty_resize_cb, - .channel_shell_request_function = channel_shell_request_cb, - .channel_env_request_function = channel_env_request_cb, - .channel_subsystem_request_function = channel_subsystem_request_cb, - .channel_exec_request_function = channel_exec_request_cb, - .channel_data_function = channel_data_cb - }; - - struct ssh_server_callbacks_struct default_server_cb = { - .userdata = &sdata, - .auth_password_function = auth_password_cb, - .channel_open_request_session_function = channel_new_session_cb, -#if WITH_GSSAPI - .auth_gssapi_mic_function = auth_gssapi_mic_cb -#endif - }; + struct ssh_channel_callbacks_struct *channel_cb = NULL; + struct ssh_server_callbacks_struct *server_cb = NULL; if (state == NULL) { fprintf(stderr, "NULL server state provided\n"); goto end; } + /* If callbacks were provided use them. Otherwise, use default callbacks */ + if (state->server_cb != NULL) { + /* This is a macro, it does not return a value */ + ssh_callbacks_init(state->server_cb); + + rc = ssh_set_server_callbacks(session, state->server_cb); + if (rc) { + goto end; + } + } else { + server_cb = get_default_server_cb(); + if (server_cb == NULL) { + goto end; + } + + server_cb->userdata = &sdata; + + /* This is a macro, it does not return a value */ + ssh_callbacks_init(server_cb); + + rc = ssh_set_server_callbacks(session, server_cb); + if (rc) { + goto end; + } + } + sdata.server_state = (void *)state; cdata.server_state = (void *)state; @@ -764,17 +825,6 @@ void default_handle_session_cb(ssh_event event, sdata.password = state->expected_password; } - /* If callbacks were provided use them. Otherwise, use default callbacks */ - if (state->server_cb != NULL) { - /* TODO check return values */ - ssh_callbacks_init(state->server_cb); - ssh_set_server_callbacks(session, state->server_cb); - } else { - /* TODO check return values */ - ssh_callbacks_init(&default_server_cb); - ssh_set_server_callbacks(session, &default_server_cb); - } - if (ssh_handle_key_exchange(session) != SSH_OK) { fprintf(stderr, "%s\n", ssh_get_error(session)); return; @@ -807,10 +857,24 @@ void default_handle_session_cb(ssh_event event, /* TODO check return values */ if (state->channel_cb != NULL) { ssh_callbacks_init(state->channel_cb); - ssh_set_channel_callbacks(sdata.channel, state->channel_cb); + + rc = ssh_set_channel_callbacks(sdata.channel, state->channel_cb); + if (rc) { + goto end; + } } else { - ssh_callbacks_init(&default_channel_cb); - ssh_set_channel_callbacks(sdata.channel, &default_channel_cb); + channel_cb = get_default_channel_cb(); + if (channel_cb == NULL) { + goto end; + } + + channel_cb->userdata = &cdata; + + ssh_callbacks_init(channel_cb); + rc = ssh_set_channel_callbacks(sdata.channel, channel_cb); + if (rc) { + goto end; + } } do { @@ -879,5 +943,11 @@ end: #ifdef WITH_PCAP cleanup_pcap(&sdata); #endif + if (channel_cb != NULL) { + free(channel_cb); + } + if (server_cb != NULL) { + free(server_cb); + } return; } diff --git a/tests/server/test_server/default_cb.h b/tests/server/test_server/default_cb.h index 0db81559..487794c0 100644 --- a/tests/server/test_server/default_cb.h +++ b/tests/server/test_server/default_cb.h @@ -162,7 +162,15 @@ int channel_write_wontblock_cb(ssh_session session, ssh_channel channel_new_session_cb(ssh_session session, void *userdata); +/* The caller is responsible to set the userdata to be provided to the callback + * The caller is responsible to free the allocated structure + * */ struct ssh_server_callbacks_struct *get_default_server_cb(void); +/* The caller is responsible to set the userdata to be provided to the callback + * The caller is responsible to free the allocated structure + * */ +struct ssh_channel_callbacks_struct *get_default_channel_cb(void); + void default_handle_session_cb(ssh_event event, ssh_session session, struct server_state_st *state);