Skip to content

Commit

Permalink
Move to tokio Mutex around libssh session.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tehforsch committed Sep 26, 2024
1 parent ee6d461 commit 16be56e
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 103 deletions.
200 changes: 106 additions & 94 deletions rust/src/nasl/builtin/ssh/libssh/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,63 +79,65 @@ impl Ssh {
///
/// nasl return An integer to identify the ssh session. Zero on error.
#[nasl_function(named(socket, port, keytype, csciphers, scciphers, timeout))]
fn nasl_ssh_connect(
async fn nasl_ssh_connect(
&mut self,
socket: Option<Socket>,
port: Option<u16>,
keytype: Option<&str>,
csciphers: Option<&str>,
scciphers: Option<&str>,
timeout: Option<u64>,
ctx: &Context,
ctx: &Context<'_>,
) -> Result<SessionId> {
let port = port.filter(|_| socket.is_none());
let ip_str: String = match ctx.target() {
x if !x.is_empty() => x.to_string(),
_ => "127.0.0.1".to_string(),
};

let session_id = self.add_new_session(|session| {
session.set_option(SshOption::LogLevel(get_log_level()))?;
session.set_option(SshOption::Hostname(ip_str.to_owned()))?;
session.set_option(SshOption::KnownHosts(Some("/dev/null".to_owned())))?;
if let Some(timeout) = timeout {
session.set_option(SshOption::Timeout(Duration::from_secs(timeout as u64)))?;
}
if let Some(keytype) = keytype {
session.set_option(SshOption::HostKeys(keytype.to_owned()))?;
}
if let Some(csciphers) = csciphers {
session.set_option(SshOption::CiphersCS(csciphers.to_owned()))?;
}
if let Some(scciphers) = scciphers {
session.set_option(SshOption::CiphersSC(scciphers.to_owned()))?;
}
if let Some(port) = port {
session.set_option(SshOption::Port(port))?;
}
let session_id = self
.add_new_session(|session| {
session.set_option(SshOption::LogLevel(get_log_level()))?;
session.set_option(SshOption::Hostname(ip_str.to_owned()))?;
session.set_option(SshOption::KnownHosts(Some("/dev/null".to_owned())))?;
if let Some(timeout) = timeout {
session.set_option(SshOption::Timeout(Duration::from_secs(timeout as u64)))?;
}
if let Some(keytype) = keytype {
session.set_option(SshOption::HostKeys(keytype.to_owned()))?;
}
if let Some(csciphers) = csciphers {
session.set_option(SshOption::CiphersCS(csciphers.to_owned()))?;
}
if let Some(scciphers) = scciphers {
session.set_option(SshOption::CiphersSC(scciphers.to_owned()))?;
}
if let Some(port) = port {
session.set_option(SshOption::Port(port))?;
}

if let Some(socket) = socket {
// This is a fake raw socket.
// TODO: implement openvas_get_socket_from_connection()
let my_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
if let Some(socket) = socket {
// This is a fake raw socket.
// TODO: implement openvas_get_socket_from_connection()
let my_sock = UdpSocket::bind("127.0.0.1:0").unwrap();
debug!(
ip_str = ip_str,
sock_fd = my_sock.as_raw_fd(),
nasl_sock = socket,
"Setting SSH fd for socket",
);
session.set_option(SshOption::Socket(my_sock.as_raw_fd()))?;
}
debug!(
ip_str = ip_str,
sock_fd = my_sock.as_raw_fd(),
nasl_sock = socket,
"Setting SSH fd for socket",
port = port,
socket = socket,
"Connecting to SSH server",
);
session.set_option(SshOption::Socket(my_sock.as_raw_fd()))?;
}
debug!(
ip_str = ip_str,
port = port,
socket = socket,
"Connecting to SSH server",
);
session.connect()?;
Ok(())
})?;
session.connect()?;
Ok(())
})
.await?;
Ok(session_id)
}

Expand All @@ -146,10 +148,10 @@ impl Ssh {
/// channels they are closed as well and their ids will be marked as
/// invalid.
#[nasl_function]
fn nasl_ssh_disconnect(&mut self, session_id: SessionId) -> Result<()> {
async fn nasl_ssh_disconnect(&mut self, session_id: SessionId) -> Result<()> {
if session_id != 0 {
{
let mut session = self.get_by_id(session_id)?;
let mut session = self.get_by_id(session_id).await?;
session.disconnect()?;
}
self.remove(session_id)?;
Expand All @@ -159,8 +161,10 @@ impl Ssh {

/// Given a socket, return the corresponding session id if available.
#[nasl_function]
fn nasl_ssh_session_id_from_sock(&self, socket: Socket) -> Result<Option<SessionId>> {
Ok(self.find_id(|session| session.get_socket() == socket)?)
async fn nasl_ssh_session_id_from_sock(&self, socket: Socket) -> Result<Option<SessionId>> {
Ok(self
.find_id(|session| session.get_socket() == socket)
.await?)
}

/// Given a session id, return the corresponding socket
Expand All @@ -173,8 +177,8 @@ impl Ssh {
///
/// return An integer representing the socket or -1 on error.
#[nasl_function]
fn nasl_ssh_get_sock(&self, session_id: SessionId) -> Result<Socket> {
let session = self.get_by_id(session_id)?;
async fn nasl_ssh_get_sock(&self, session_id: SessionId) -> Result<Socket> {
let session = self.get_by_id(session_id).await?;
Ok(session.get_socket())
}
/// Set the login name for the authentication.
Expand All @@ -191,8 +195,8 @@ impl Ssh {
/// for an established connection, the "login" parameter is silently
/// ignored on all further calls.
#[nasl_function(named(login))]
fn nasl_ssh_set_login(&self, session_id: SessionId, login: Option<&str>) -> Result<()> {
let mut session = self.get_by_id(session_id)?;
async fn nasl_ssh_set_login(&self, session_id: SessionId, login: Option<&str>) -> Result<()> {
let mut session = self.get_by_id(session_id).await?;
Ok(session.set_opt_user(login)?)
}

Expand Down Expand Up @@ -245,7 +249,7 @@ impl Ssh {
///
/// return An integer as status value; 0 indicates success.
#[nasl_function(named(login, password, privatekey, passphrase))]
fn nasl_ssh_userauth(
async fn nasl_ssh_userauth(
&self,
session_id: SessionId,
login: Option<&str>,
Expand All @@ -260,7 +264,7 @@ impl Ssh {
session_id
)));
}
let mut session = self.get_by_id(session_id)?;
let mut session = self.get_by_id(session_id).await?;
session.ensure_user_set(login)?;

let methods: AuthMethods = session.get_authmethods_cached()?;
Expand Down Expand Up @@ -387,14 +391,14 @@ impl Ssh {
/// If the named parameters @a stdout and @a stderr are not given, the
/// function acts exactly as if only @a stdout has been set to 1.
#[nasl_function(named(cmd, stdout, stderr))]
fn nasl_ssh_request_exec(
async fn nasl_ssh_request_exec(
&self,
session_id: SessionId,
cmd: &str,
stdout: Option<bool>,
stderr: Option<bool>,
) -> Result<Option<String>> {
let session = self.get_by_id(session_id)?;
let session = self.get_by_id(session_id).await?;
if cmd.is_empty() {
return Ok(None);
}
Expand All @@ -413,8 +417,12 @@ impl Ssh {

/// Open a new ssh shell.
#[nasl_function(named(pty))]
fn nasl_ssh_shell_open(&self, session_id: SessionId, pty: Option<bool>) -> Result<SessionId> {
let mut session = self.get_by_id(session_id)?;
async fn nasl_ssh_shell_open(
&self,
session_id: SessionId,
pty: Option<bool>,
) -> Result<SessionId> {
let mut session = self.get_by_id(session_id).await?;
let pty = pty.unwrap_or(true);
session.open_shell(pty)?;
Ok(session.id())
Expand All @@ -425,12 +433,12 @@ impl Ssh {
/// there are no more bytes left to read. Otherwise use non_blocking
/// read mode.
#[nasl_function]
fn nasl_ssh_shell_read(
async fn nasl_ssh_shell_read(
&self,
session_id: SessionId,
timeout: Option<Maybe<u64>>,
) -> Result<String> {
let session = self.get_by_id(session_id)?;
let session = self.get_by_id(session_id).await?;
let timeout = Duration::from_secs(timeout.and_then(Maybe::as_option).unwrap_or(0));
let channel = session.get_channel()?;
channel.ensure_open()?;
Expand All @@ -444,8 +452,8 @@ impl Ssh {

/// Write the string `cmd` to an ssh shell.
#[nasl_function]
fn nasl_ssh_shell_write(&self, session_id: SessionId, cmd: StringOrData) -> Result<i32> {
let session = self.get_by_id(session_id)?;
async fn nasl_ssh_shell_write(&self, session_id: SessionId, cmd: StringOrData) -> Result<i32> {
let session = self.get_by_id(session_id).await?;
let channel = session.get_channel()?;
channel.ensure_open()?;

Expand All @@ -458,8 +466,8 @@ impl Ssh {

/// Close an ssh shell.
#[nasl_function]
fn nasl_ssh_shell_close(&self, session_id: SessionId) -> Result<()> {
let mut session = self.get_by_id(session_id)?;
async fn nasl_ssh_shell_close(&self, session_id: SessionId) -> Result<()> {
let mut session = self.get_by_id(session_id).await?;
session.close();
Ok(())
}
Expand All @@ -472,12 +480,12 @@ impl Ssh {
/// The first time this function is called for a session id, the named
/// argument "login" is also expected.
#[nasl_function(named(login))]
fn nasl_ssh_login_interactive(
async fn nasl_ssh_login_interactive(
&self,
session_id: SessionId,
login: Option<&str>,
) -> Result<Option<String>> {
let mut session = self.get_by_id(session_id)?;
let mut session = self.get_by_id(session_id).await?;
session.ensure_user_set(login)?;
let methods = session.get_authmethods_cached()?;
debug!("Available methods:\n{:?}", methods);
Expand Down Expand Up @@ -522,8 +530,12 @@ impl Ssh {
/// The function finishes the authentication process started by
/// ssh_login_interactive.
#[nasl_function(named(password))]
fn nasl_ssh_login_interactive_pass(&self, session_id: SessionId, password: &str) -> Result<()> {
let session = self.get_by_id(session_id)?;
async fn nasl_ssh_login_interactive_pass(
&self,
session_id: SessionId,
password: &str,
) -> Result<()> {
let session = self.get_by_id(session_id).await?;
let info = session.userauth_keyboard_interactive_info()?;
debug!(
name = info.name,
Expand Down Expand Up @@ -565,8 +577,8 @@ impl Ssh {
/// The function returns a string with the issue banner. This is
/// usually displayed before authentication.
#[nasl_function]
fn nasl_ssh_get_issue_banner(&self, session_id: SessionId) -> Result<Option<String>> {
let mut session = self.get_by_id(session_id)?;
async fn nasl_ssh_get_issue_banner(&self, session_id: SessionId) -> Result<Option<String>> {
let mut session = self.get_by_id(session_id).await?;
session.ensure_user_set(None)?;
session.get_authmethods_cached()?;
Ok(session.get_issue_banner().ok())
Expand All @@ -575,8 +587,8 @@ impl Ssh {
/// The function returns a string with the server banner. This is
/// usually the first data sent by the server.
#[nasl_function]
fn nasl_ssh_get_server_banner(&self, session_id: SessionId) -> Result<Option<String>> {
let session = self.get_by_id(session_id)?;
async fn nasl_ssh_get_server_banner(&self, session_id: SessionId) -> Result<Option<String>> {
let session = self.get_by_id(session_id).await?;
// TODO: Check with openvas-nasl why the outputs doesn't match
Ok(session.get_server_banner().ok())
}
Expand All @@ -586,8 +598,8 @@ impl Ssh {
/// SSH_MSG_USERAUTH_FAILURE protocol element; however, it has been
/// screened and put into a definitive order.
#[nasl_function]
fn nasl_ssh_get_auth_methods(&self, session_id: SessionId) -> Result<Option<String>> {
let mut session = self.get_by_id(session_id)?;
async fn nasl_ssh_get_auth_methods(&self, session_id: SessionId) -> Result<Option<String>> {
let mut session = self.get_by_id(session_id).await?;
session.ensure_user_set(None)?;
let authmethods = session.get_authmethods_cached()?;

Expand Down Expand Up @@ -616,8 +628,8 @@ impl Ssh {

/// Return the MD5 host key.
#[nasl_function]
fn nasl_ssh_get_host_key(&self, session_id: SessionId) -> Result<Option<String>> {
let session = self.get_by_id(session_id)?;
async fn nasl_ssh_get_host_key(&self, session_id: SessionId) -> Result<Option<String>> {
let session = self.get_by_id(session_id).await?;
let key = session.get_server_public_key()?;
match key.get_public_key_hash_hexa(libssh_rs::PublicKeyHashType::Md5) {
Ok(hash) => Ok(Some(hash)),
Expand All @@ -627,8 +639,8 @@ impl Ssh {

/// Check if the SFTP subsystem is enabled on the remote SSH server.
#[nasl_function]
fn nasl_sftp_enabled_check(&self, session_id: SessionId) -> Result<i32> {
let session = self.get_by_id(session_id)?;
async fn nasl_sftp_enabled_check(&self, session_id: SessionId) -> Result<i32> {
let session = self.get_by_id(session_id).await?;
match session.sftp() {
Ok(_) => Ok(0),
Err(e) => {
Expand All @@ -640,8 +652,8 @@ impl Ssh {

/// Execute the NETCONF subsystem on the the ssh channel
#[nasl_function]
fn nasl_ssh_execute_netconf_subsystem(&self, session_id: SessionId) -> Result<SessionId> {
let mut session = self.get_by_id(session_id)?;
async fn nasl_ssh_execute_netconf_subsystem(&self, session_id: SessionId) -> Result<SessionId> {
let mut session = self.get_by_id(session_id).await?;
let channel = session.new_channel()?;
channel.open_session()?;
channel.request_subsystem("netconf")?;
Expand All @@ -654,31 +666,31 @@ impl IntoFunctionSet for Ssh {
type State = Ssh;
fn into_function_set(self) -> StoredFunctionSet<Self::State> {
let mut set = StoredFunctionSet::new(self);
set.sync_stateful_mut("ssh_connect", Ssh::nasl_ssh_connect);
set.sync_stateful_mut("ssh_disconnect", Ssh::nasl_ssh_disconnect);
set.sync_stateful(
set.async_stateful_mut("ssh_connect", Ssh::nasl_ssh_connect);
set.async_stateful_mut("ssh_disconnect", Ssh::nasl_ssh_disconnect);
set.async_stateful(
"ssh_session_id_from_sock",
Ssh::nasl_ssh_session_id_from_sock,
);
set.sync_stateful("ssh_get_sock", Ssh::nasl_ssh_get_sock);
set.sync_stateful("ssh_set_login", Ssh::nasl_ssh_set_login);
set.sync_stateful("ssh_userauth", Ssh::nasl_ssh_userauth);
set.sync_stateful("ssh_request_exec", Ssh::nasl_ssh_request_exec);
set.sync_stateful("ssh_shell_open", Ssh::nasl_ssh_shell_open);
set.sync_stateful("ssh_shell_read", Ssh::nasl_ssh_shell_read);
set.sync_stateful("ssh_shell_write", Ssh::nasl_ssh_shell_write);
set.sync_stateful("ssh_shell_close", Ssh::nasl_ssh_shell_close);
set.sync_stateful("ssh_login_interactive", Ssh::nasl_ssh_login_interactive);
set.sync_stateful(
set.async_stateful("ssh_get_sock", Ssh::nasl_ssh_get_sock);
set.async_stateful("ssh_set_login", Ssh::nasl_ssh_set_login);
set.async_stateful("ssh_userauth", Ssh::nasl_ssh_userauth);
set.async_stateful("ssh_request_exec", Ssh::nasl_ssh_request_exec);
set.async_stateful("ssh_shell_open", Ssh::nasl_ssh_shell_open);
set.async_stateful("ssh_shell_read", Ssh::nasl_ssh_shell_read);
set.async_stateful("ssh_shell_write", Ssh::nasl_ssh_shell_write);
set.async_stateful("ssh_shell_close", Ssh::nasl_ssh_shell_close);
set.async_stateful("ssh_login_interactive", Ssh::nasl_ssh_login_interactive);
set.async_stateful(
"ssh_login_interactive_pass",
Ssh::nasl_ssh_login_interactive_pass,
);
set.sync_stateful("ssh_get_issue_banner", Ssh::nasl_ssh_get_issue_banner);
set.sync_stateful("ssh_get_server_banner", Ssh::nasl_ssh_get_server_banner);
set.sync_stateful("ssh_get_auth_methods", Ssh::nasl_ssh_get_auth_methods);
set.sync_stateful("ssh_get_host_key", Ssh::nasl_ssh_get_host_key);
set.sync_stateful("sftp_enabled_check", Ssh::nasl_sftp_enabled_check);
set.sync_stateful(
set.async_stateful("ssh_get_issue_banner", Ssh::nasl_ssh_get_issue_banner);
set.async_stateful("ssh_get_server_banner", Ssh::nasl_ssh_get_server_banner);
set.async_stateful("ssh_get_auth_methods", Ssh::nasl_ssh_get_auth_methods);
set.async_stateful("ssh_get_host_key", Ssh::nasl_ssh_get_host_key);
set.async_stateful("sftp_enabled_check", Ssh::nasl_sftp_enabled_check);
set.async_stateful(
"ssh_execute_netconf_subsystem",
Ssh::nasl_ssh_execute_netconf_subsystem,
);
Expand Down
2 changes: 1 addition & 1 deletion rust/src/nasl/builtin/ssh/libssh/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use libssh_rs::{AuthMethods, AuthStatus, InteractiveAuthInfo, Session, Sftp, SshKey, SshOption};
use std::sync::MutexGuard;
use std::{os::fd::AsRawFd, time::Duration};
use tokio::sync::MutexGuard;
use tracing::{debug, info};

use crate::nasl::builtin::ssh::SessionId;
Expand Down
Loading

0 comments on commit 16be56e

Please sign in to comment.