diff --git a/include/libssh/server.h b/include/libssh/server.h index ec09f2f8..c2132de1 100644 --- a/include/libssh/server.h +++ b/include/libssh/server.h @@ -45,7 +45,8 @@ enum ssh_bind_options_e { SSH_BIND_OPTIONS_BANNER, SSH_BIND_OPTIONS_LOG_VERBOSITY, SSH_BIND_OPTIONS_LOG_VERBOSITY_STR, - SSH_BIND_OPTIONS_ECDSAKEY + SSH_BIND_OPTIONS_ECDSAKEY, + SSH_BIND_OPTIONS_IMPORT_KEY }; typedef struct ssh_bind_struct* ssh_bind; diff --git a/src/bind.c b/src/bind.c index 6b0fb238..fa5f8d57 100644 --- a/src/bind.c +++ b/src/bind.c @@ -235,9 +235,11 @@ int ssh_bind_listen(ssh_bind sshbind) { return -1; } - rc = ssh_bind_import_keys(sshbind); - if (rc != SSH_OK) { - return SSH_ERROR; + if (sshbind->rsa == NULL && sshbind->dsa == NULL && sshbind->ecdsa == NULL) { + rc = ssh_bind_import_keys(sshbind); + if (rc != SSH_OK) { + return SSH_ERROR; + } } if (sshbind->bindfd == SSH_INVALID_SOCKET) { @@ -430,9 +432,13 @@ int ssh_bind_accept_fd(ssh_bind sshbind, ssh_session session, socket_t fd){ * where keys can be imported) on this ssh_bind and are instead * only using ssh_bind_accept_fd to manage sockets ourselves. */ - rc = ssh_bind_import_keys(sshbind); - if (rc != SSH_OK) { - return SSH_ERROR; + if (sshbind->rsa == NULL && + sshbind->dsa == NULL && + sshbind->ecdsa == NULL) { + rc = ssh_bind_import_keys(sshbind); + if (rc != SSH_OK) { + return SSH_ERROR; + } } #ifdef HAVE_ECC diff --git a/src/options.c b/src/options.c index 3470a792..68c11053 100644 --- a/src/options.c +++ b/src/options.c @@ -1391,6 +1391,9 @@ static int ssh_bind_set_key(ssh_bind sshbind, char **key_loc, * - SSH_BIND_OPTIONS_BANNER: * Set the server banner sent to clients (const char *). * + * - SSH_BIND_OPTIONS_IMPORT_KEY: + * Set the Private Key for the server directly (ssh_key) + * * @param value The value to set. This is a generic pointer and the * datatype which should be used is described at the * corresponding value of type above. @@ -1469,6 +1472,48 @@ int ssh_bind_options_set(ssh_bind sshbind, enum ssh_bind_options_e type, *bind_key_loc = key; } break; + case SSH_BIND_OPTIONS_IMPORT_KEY: + if (value == NULL) { + ssh_set_error_invalid(sshbind); + return -1; + } else { + int key_type; + ssh_key *bind_key_loc = NULL; + ssh_key key = (ssh_key)value; + + key_type = ssh_key_type(key); + switch (key_type) { + case SSH_KEYTYPE_DSS: + bind_key_loc = &sshbind->dsa; + break; + case SSH_KEYTYPE_ECDSA: +#ifdef HAVE_ECC + bind_key_loc = &sshbind->ecdsa; +#else + ssh_set_error(sshbind, + SSH_FATAL, + "ECDSA key used and libssh compiled " + "without ECDSA support"); +#endif + break; + case SSH_KEYTYPE_RSA: + case SSH_KEYTYPE_RSA1: + bind_key_loc = &sshbind->rsa; + break; + case SSH_KEYTYPE_ED25519: + bind_key_loc = &sshbind->ed25519; + break; + default: + ssh_set_error(sshbind, + SSH_FATAL, + "Unsupported key type %d", key_type); + } + if (bind_key_loc == NULL) + return -1; + ssh_key_free(*bind_key_loc); + *bind_key_loc = key; + } + break; case SSH_BIND_OPTIONS_BINDADDR: if (value == NULL) { ssh_set_error_invalid(sshbind); diff --git a/tests/unittests/torture_options.c b/tests/unittests/torture_options.c index 4ec3f4cb..820e607d 100644 --- a/tests/unittests/torture_options.c +++ b/tests/unittests/torture_options.c @@ -199,6 +199,53 @@ static void torture_options_proxycommand(void **state) { assert_null(session->opts.ProxyCommand); } + +/* sshbind options */ +static int sshbind_setup(void **state) +{ + ssh_bind bind = ssh_bind_new(); + *state = bind; + return 0; +} + +static int sshbind_teardown(void **state) +{ + ssh_bind_free(*state); + return 0; +} + +static void torture_bind_options_import_key(void **state) +{ + ssh_bind bind = *state; + int rc; + ssh_key key = ssh_key_new(); + const char *base64_key; + + /* set null */ + rc = ssh_bind_options_set(bind, SSH_BIND_OPTIONS_IMPORT_KEY, NULL); + assert_int_equal(rc, -1); + /* set invalid key */ + rc = ssh_bind_options_set(bind, SSH_BIND_OPTIONS_IMPORT_KEY, key); + assert_int_equal(rc, -1); + + /* set rsa key */ + base64_key = torture_get_testkey(SSH_KEYTYPE_RSA, 0, 0); + ssh_pki_import_privkey_base64(base64_key, NULL, NULL, NULL, &key); + rc = ssh_bind_options_set(bind, SSH_BIND_OPTIONS_IMPORT_KEY, key); + assert_int_equal(rc, 0); + /* set dsa key */ + base64_key = torture_get_testkey(SSH_KEYTYPE_DSS, 0, 0); + ssh_pki_import_privkey_base64(base64_key, NULL, NULL, NULL, &key); + rc = ssh_bind_options_set(bind, SSH_BIND_OPTIONS_IMPORT_KEY, key); + assert_int_equal(rc, 0); + /* set ecdsa key */ + base64_key = torture_get_testkey(SSH_KEYTYPE_ECDSA, 512, 0); + ssh_pki_import_privkey_base64(base64_key, NULL, NULL, NULL, &key); + rc = ssh_bind_options_set(bind, SSH_BIND_OPTIONS_IMPORT_KEY, key); + assert_int_equal(rc, 0); +} + + int torture_run_tests(void) { int rc; struct CMUnitTest tests[] = { @@ -214,9 +261,14 @@ int torture_run_tests(void) { cmocka_unit_test_setup_teardown(torture_options_proxycommand, setup, teardown), }; + struct CMUnitTest sshbind_tests[] = { + cmocka_unit_test_setup_teardown(torture_bind_options_import_key, sshbind_setup, sshbind_teardown), + }; + ssh_init(); torture_filter_tests(tests); rc = cmocka_run_group_tests(tests, NULL, NULL); + rc += cmocka_run_group_tests(sshbind_tests, NULL, NULL); ssh_finalize(); return rc; }