Skip to content

Commit

Permalink
Make ssh impl work with new mutable state fns.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tehforsch committed Sep 11, 2024
1 parent c4625c5 commit bae1120
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 107 deletions.
5 changes: 4 additions & 1 deletion rust/crates/nasl-function-proc-macro/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ impl<'a> ArgsStruct<'a> {
let fn_args_names = self.get_fn_args_names();
let call_expr = match self.receiver_type {
ReceiverType::None => quote! { #mangled_ident(#fn_args_names) },
ReceiverType::RefSelf => quote! { self.#mangled_ident(#fn_args_names) },
ReceiverType::RefSelf | ReceiverType::RefMutSelf => {
quote! { self.#mangled_ident(#fn_args_names) }
}
};
let await_ = match asyncness {
Some(_) => quote! { .await },
Expand Down Expand Up @@ -174,6 +176,7 @@ impl<'a> ArgsStruct<'a> {
let self_arg = match self.receiver_type {
ReceiverType::None => quote! {},
ReceiverType::RefSelf => quote! {&self,},
ReceiverType::RefMutSelf => quote! {&mut self,},
};
let inputs = quote! {
#self_arg
Expand Down
4 changes: 0 additions & 4 deletions rust/crates/nasl-function-proc-macro/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ pub enum ErrorKind {
TooManyAttributes,
OnlyNormalArgumentsAllowed,
MovedReceiverType,
MutableRefReceiverType,
TypedRefReceiverType,
}

Expand All @@ -36,9 +35,6 @@ impl Error {
ErrorKind::MovedReceiverType => {
"Receiver argument is of type `self`. Currently, only `&self` receiver types are supported."
}
ErrorKind::MutableRefReceiverType => {
"Receiver argument is of type `&mut self`. Currently, only `&self` receiver types are supported."
}
ErrorKind::TypedRefReceiverType => {
"Specific type specified in receiver argument. Currently, only `&self` is supported."
}
Expand Down
8 changes: 4 additions & 4 deletions rust/crates/nasl-function-proc-macro/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ impl ReceiverType {
if rec.reference.is_none() {
return make_err(ErrorKind::MovedReceiverType);
}
// `&mut self`
else if rec.mutability.is_some() {
return make_err(ErrorKind::MutableRefReceiverType);
}
// e.g. `self: Box<Self>`
else if rec.colon_token.is_some() {
return make_err(ErrorKind::TypedRefReceiverType);
}
// `&mut self`
else if rec.mutability.is_some() {
ReceiverType::RefMutSelf
} else {
ReceiverType::RefSelf
}
Expand Down
1 change: 1 addition & 0 deletions rust/crates/nasl-function-proc-macro/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub struct ArgsStruct<'a> {
pub enum ReceiverType {
None,
RefSelf,
RefMutSelf,
}

pub struct Arg<'a> {
Expand Down
73 changes: 42 additions & 31 deletions rust/src/nasl/builtin/ssh/libssh/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use tracing::debug;
use super::SessionId;
use crate::nasl::prelude::*;
use crate::nasl::utils::function::{Maybe, StringOrData};
use crate::nasl::utils::{IntoFunctionSet, StoredFunctionSet};

pub type Socket = c_int;

Expand Down Expand Up @@ -79,7 +80,7 @@ impl Ssh {
/// nasl return An integer to identify the ssh session. Zero on error.
#[nasl_function(named(socket, port, keytype, csciphers, scciphers, timeout))]
fn nasl_ssh_connect(
&self,
&mut self,
socket: Option<Socket>,
port: Option<u16>,
keytype: Option<&str>,
Expand Down Expand Up @@ -145,10 +146,12 @@ impl Ssh {
/// channels they are closed as well and their ids will be marked as
/// invalid.
#[nasl_function]
fn nasl_ssh_disconnect(&self, session_id: SessionId) -> Result<()> {
fn nasl_ssh_disconnect(&mut self, session_id: SessionId) -> Result<()> {
if session_id != 0 {
let mut session = self.get_by_id(session_id)?;
session.disconnect()?;
{
let mut session = self.get_by_id(session_id)?;
session.disconnect()?;
}
self.remove(session_id)?;
}
Ok(())
Expand All @@ -157,9 +160,7 @@ impl Ssh {
/// Given a socket, return the corresponding session id if available.
#[nasl_function]
fn nasl_ssh_session_id_from_sock(&self, socket: Socket) -> Result<Option<SessionId>> {
Ok(self
.find(|session| session.get_socket() == socket)?
.map(|session| session.id()))
Ok(self.find_id(|session| session.get_socket() == socket)?)
}

/// Given a session id, return the corresponding socket
Expand Down Expand Up @@ -649,28 +650,38 @@ impl Ssh {
}
}

function_set! {
Ssh,
sync_stateful,
(
Ssh::nasl_ssh_connect,
Ssh::nasl_ssh_disconnect,
Ssh::nasl_ssh_session_id_from_sock,
Ssh::nasl_ssh_get_sock,
Ssh::nasl_ssh_set_login,
Ssh::nasl_ssh_userauth,
Ssh::nasl_ssh_request_exec,
Ssh::nasl_ssh_shell_open,
Ssh::nasl_ssh_shell_read,
Ssh::nasl_ssh_shell_write,
Ssh::nasl_ssh_shell_close,
Ssh::nasl_ssh_login_interactive,
Ssh::nasl_ssh_login_interactive_pass,
Ssh::nasl_ssh_get_issue_banner,
Ssh::nasl_ssh_get_server_banner,
Ssh::nasl_ssh_get_auth_methods,
Ssh::nasl_ssh_get_host_key,
Ssh::nasl_sftp_enabled_check,
Ssh::nasl_ssh_execute_netconf_subsystem,
)
impl IntoFunctionSet for Ssh {
type State = Ssh;
fn into_function_set(self) -> StoredFunctionSet<Self::State> {
let mut set = StoredFunctionSet::new(self);
set.sync_stateful_mut("ssh_connect", Ssh::nasl_ssh_connect);
set.sync_stateful_mut("ssh_disconnect", Ssh::nasl_ssh_disconnect);
set.sync_stateful(
"ssh_session_id_from_sock",
Ssh::nasl_ssh_session_id_from_sock,
);
set.sync_stateful("ssh_get_sock", Ssh::nasl_ssh_get_sock);
set.sync_stateful("ssh_set_login", Ssh::nasl_ssh_set_login);
set.sync_stateful("ssh_userauth", Ssh::nasl_ssh_userauth);
set.sync_stateful("ssh_request_exec", Ssh::nasl_ssh_request_exec);
set.sync_stateful("ssh_shell_open", Ssh::nasl_ssh_shell_open);
set.sync_stateful("ssh_shell_read", Ssh::nasl_ssh_shell_read);
set.sync_stateful("ssh_shell_write", Ssh::nasl_ssh_shell_write);
set.sync_stateful("ssh_shell_close", Ssh::nasl_ssh_shell_close);
set.sync_stateful("ssh_login_interactive", Ssh::nasl_ssh_login_interactive);
set.sync_stateful(
"ssh_login_interactive_pass",
Ssh::nasl_ssh_login_interactive_pass,
);
set.sync_stateful("ssh_get_issue_banner", Ssh::nasl_ssh_get_issue_banner);
set.sync_stateful("ssh_get_server_banner", Ssh::nasl_ssh_get_server_banner);
set.sync_stateful("ssh_get_auth_methods", Ssh::nasl_ssh_get_auth_methods);
set.sync_stateful("ssh_get_host_key", Ssh::nasl_ssh_get_host_key);
set.sync_stateful("sftp_enabled_check", Ssh::nasl_sftp_enabled_check);
set.sync_stateful(
"ssh_execute_netconf_subsystem",
Ssh::nasl_ssh_execute_netconf_subsystem,
);
set
}
}
25 changes: 5 additions & 20 deletions rust/src/nasl/builtin/ssh/libssh/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,35 +37,20 @@ impl SshSession {
}

pub struct BorrowedSession<'a> {
guard: MutexGuard<'a, Vec<SshSession>>,
index: usize,
guard: MutexGuard<'a, SshSession>,
}

impl<'a> BorrowedSession<'a> {
pub fn new(guard: MutexGuard<'a, Vec<SshSession>>, id: SessionId) -> Result<Self> {
let index = guard
.iter()
.enumerate()
.find(|(_, session)| session.id == id)
.ok_or_else(|| SshError::InvalidSessionId(id))?
.0;
Ok(Self { guard, index })
}

pub fn from_index(guard: MutexGuard<'a, Vec<SshSession>>, index: usize) -> Self {
Self { guard, index }
}

pub fn take_guard(self) -> MutexGuard<'a, Vec<SshSession>> {
self.guard
pub fn new(guard: MutexGuard<'a, SshSession>) -> Self {
Self { guard }
}

fn borrow(&self) -> &SshSession {
&self.guard[self.index]
&self.guard
}

fn borrow_mut(&mut self) -> &mut SshSession {
&mut self.guard[self.index]
&mut self.guard
}

fn session(&self) -> &Session {
Expand Down
78 changes: 31 additions & 47 deletions rust/src/nasl/builtin/ssh/libssh/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

//! Defines functions and structures for handling sessions

use std::collections::HashSet;
use std::collections::{HashMap, HashSet};

use std::sync::{Arc, Mutex, MutexGuard};
use std::sync::Mutex;

use crate::nasl::builtin::ssh::SessionId;

Expand All @@ -15,17 +15,20 @@ use super::session::{BorrowedSession, SshSession};

#[derive(Default)]
pub struct Ssh {
sessions: Arc<Mutex<Vec<SshSession>>>,
// Unfortunately, we need a Mutex around the SshSession here.
// This is because it contains a libssh::Channel, which is not `Send`.
sessions: HashMap<SessionId, Mutex<SshSession>>,
}

impl Ssh {
fn lock(&self) -> Result<MutexGuard<Vec<SshSession>>> {
self.sessions.lock().map_err(|_| SshError::PoisonedLock)
}

pub fn get_by_id(&self, id: SessionId) -> Result<BorrowedSession> {
let guard = self.lock()?;
BorrowedSession::new(guard, id)
Ok(BorrowedSession::new(
self.sessions
.get(&id)
.ok_or_else(|| SshError::InvalidSessionId(id))?
.lock()
.map_err(|_| SshError::PoisonedLock)?,
))
}

/// Return the next available session ID
Expand All @@ -34,71 +37,52 @@ impl Ssh {
// hand out is an arbitrary high number, this is only to help
// debugging.
const MIN_VAL: SessionId = 9000;
let taken_ids = self
.lock()?
.iter()
.map(|session| session.id)
.collect::<HashSet<SessionId>>();
let taken_ids: HashSet<_> = self.sessions.keys().collect();
if taken_ids.is_empty() {
Ok(MIN_VAL)
} else {
let max_val = taken_ids.iter().max().unwrap() + 1;
let max_val = **taken_ids.iter().max().unwrap() + 1;
Ok((MIN_VAL..=max_val)
.find(|id| !taken_ids.contains(id))
.unwrap())
}
}

pub fn remove(&self, session_id: SessionId) -> Result<()> {
let mut guard = self.lock()?;
if let Some((index, _)) = guard
.iter()
.enumerate()
.find(|(_, session)| session.id == session_id)
{
guard.remove(index);
}
pub fn remove(&mut self, session_id: SessionId) -> Result<()> {
self.sessions.remove(&session_id);
Ok(())
}

pub fn find<'a>(
pub fn find_id<'a>(
&'a self,
f: impl for<'b> Fn(&BorrowedSession<'b>) -> bool,
) -> Result<Option<BorrowedSession<'a>>> {
// This is a pretty ugly implementation but the borrow checker
// (somewhat rightfully) makes this quite hard to do normally.
let mut guard = self.lock()?;
let len = guard.len();
for i in 0..len {
let session = BorrowedSession::from_index(guard, i);
) -> Result<Option<SessionId>> {
for id in self.sessions.keys() {
let session = self.get_by_id(*id)?;
if f(&session) {
return Ok(Some(session));
return Ok(Some(session.id()));
}
guard = session.take_guard();
}
Ok(None)
}

/// Create a new session, but only add it to the list of active sessions
/// if the given closure which modifies the session returns Ok(...).
pub fn add_new_session(
&self,
&mut self,
f: impl Fn(&mut BorrowedSession) -> Result<()>,
) -> Result<SessionId> {
let id = self.next_session_id()?;
let mut guard = self.lock()?;
let index = guard.len();
let session = SshSession::new(id)?;
guard.push(session);
let mut session = BorrowedSession::from_index(guard, index);
let result = f(&mut session);
match result {
Ok(()) => Ok(id),
Err(e) => {
session.disconnect()?;
session.take_guard().pop();
Err(e)
let session = Mutex::new(SshSession::new(id)?);
{
let mut borrowed_session =
BorrowedSession::new(session.lock().map_err(|_| SshError::PoisonedLock)?);
if let Err(e) = f(&mut borrowed_session) {
borrowed_session.disconnect()?;
return Err(e);
}
}
self.sessions.insert(id, session);
Ok(id)
}
}

0 comments on commit bae1120

Please sign in to comment.