Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: improve error handling #387

Open
wants to merge 1 commit into
base: session-timeout
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ uuid = { version = "1.11.0", features = ["v4", "fast-rng", "serde"] }
xeddsa = "1.0.2"
futures-util = "0.3.31"
futures = "0.3.31"
thiserror = "2.0.3"

[dev-dependencies]
axum-test = "16.4.0"
Expand Down
101 changes: 33 additions & 68 deletions server/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use axum::{extract::State, http::StatusCode, Json};
use eyre::eyre;
use axum::{extract::State, Json};
use uuid::Uuid;
use xeddsa::{xed25519, Verify as _};

Expand Down Expand Up @@ -33,35 +32,21 @@ pub(crate) async fn login(
) -> Result<Json<KeyLoginOutput>, AppError> {
// Check if the user sent the credentials
if args.signature.is_empty() || args.pubkey.is_empty() {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("empty args").into(),
));
return Err(AppError::InvalidArgument("signature or pubkey".into()));
}

let pubkey = TryInto::<[u8; 32]>::try_into(args.pubkey.clone()).map_err(|_| {
AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid pubkey").into(),
)
})?;
let pubkey = TryInto::<[u8; 32]>::try_into(args.pubkey.clone())
.map_err(|_| AppError::InvalidArgument("pubkey".into()))?;
let pubkey = xed25519::PublicKey(pubkey);
let signature = TryInto::<[u8; 64]>::try_into(args.signature).map_err(|_| {
AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid signature").into(),
)
})?;
let signature = TryInto::<[u8; 64]>::try_into(args.signature)
.map_err(|_| AppError::InvalidArgument("signature".into()))?;
pubkey
.verify(args.uuid.as_bytes(), &signature)
.map_err(|_| AppError(StatusCode::UNAUTHORIZED, eyre!("invalid signature").into()))?;
.map_err(|_| AppError::Unauthorized)?;

let mut challenges = state.challenges.write().unwrap();
if !challenges.remove(&args.uuid) {
return Err(AppError(
StatusCode::UNAUTHORIZED,
eyre!("invalid challenge").into(),
));
return Err(AppError::Unauthorized);
}
drop(challenges);

Expand Down Expand Up @@ -97,10 +82,7 @@ pub(crate) async fn create_new_session(
Json(args): Json<CreateNewSessionArgs>,
) -> Result<Json<CreateNewSessionOutput>, AppError> {
if args.message_count == 0 {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid message_count").into(),
));
return Err(AppError::InvalidArgument("message_count".into()));
}

// Create new session object.
Expand Down Expand Up @@ -157,22 +139,17 @@ pub(crate) async fn get_session_info(
let sessions = state.sessions.sessions.read().unwrap();
let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap();

let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;
let user_sessions = sessions_by_pubkey
.get(&user.pubkey)
.ok_or(AppError::SessionNotFound)?;

if !user_sessions.contains(&args.session_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
));
return Err(AppError::SessionNotFound);
}

let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let session = sessions
.get(&args.session_id)
.ok_or(AppError::SessionNotFound)?;

Ok(Json(GetSessionInfoOutput {
num_signers: session.num_signers,
Expand All @@ -195,10 +172,9 @@ pub(crate) async fn send(

// TODO: change to get_mut and modify in-place, if HashMapDelay ever
// adds support to it
let mut session = sessions.remove(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let mut session = sessions
.remove(&args.session_id)
.ok_or(AppError::SessionNotFound)?;

let recipients = if args.recipients.is_empty() {
vec![Vec::new()]
Expand All @@ -221,7 +197,6 @@ pub(crate) async fn send(
}

/// Implement the recv API
// TODO: get identifier from channel rather from arguments
#[tracing::instrument(ret, err(Debug), skip(state, user))]
pub(crate) async fn receive(
State(state): State<SharedState>,
Expand All @@ -235,10 +210,9 @@ pub(crate) async fn receive(
// adds support to it. This will also simplify the code since
// we have to do a workaround in order to not renew the timeout if there
// are no messages. See https://github.com/AgeManning/delay_map/issues/26
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let session = sessions
.get(&args.session_id)
.ok_or(AppError::SessionNotFound)?;

let pubkey = if user.pubkey == session.coordinator_pubkey && args.as_coordinator {
Vec::new()
Expand All @@ -252,10 +226,9 @@ pub(crate) async fn receive(
let msgs = if session.queue.contains_key(&pubkey) {
drop(sessions);
let mut sessions = state.sessions.sessions.write().unwrap();
let mut session = sessions.remove(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let mut session = sessions
.remove(&args.session_id)
.ok_or(AppError::SessionNotFound)?;
let msgs = session.queue.entry(pubkey).or_default().drain(..).collect();
sessions.insert(args.session_id, session);
msgs
Expand All @@ -276,28 +249,20 @@ pub(crate) async fn close_session(
let mut sessions = state.sessions.sessions.write().unwrap();
let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap();

let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;
let user_sessions = sessions_by_pubkey
.get(&user.pubkey)
.ok_or(AppError::SessionNotFound)?;

if !user_sessions.contains(&args.session_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
));
return Err(AppError::SessionNotFound);
}

let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid session ID").into(),
))?;
let session = sessions
.get(&args.session_id)
.ok_or(AppError::SessionNotFound)?;

if session.coordinator_pubkey != user.pubkey {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not the coordinator of the session").into(),
));
return Err(AppError::NotCoordinator);
}

for username in session.pubkeys.clone() {
Expand Down
49 changes: 44 additions & 5 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod types;
mod user;

pub use state::{AppState, SharedState};
use thiserror::Error;
use tower_http::trace::TraceLayer;
pub use types::*;

Expand All @@ -13,7 +14,7 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
Router,
Json, Router,
};

/// Create the axum Router for the server.
Expand Down Expand Up @@ -47,12 +48,50 @@ pub async fn run(args: &Args) -> Result<(), Box<dyn std::error::Error>> {

/// An error. Wraps a StatusCode which is returned by the server when the
/// error happens during a API call, and a generic eyre::Report.
// TODO: create an enum with specific errors
#[derive(Debug)]
pub struct AppError(StatusCode, Box<dyn std::error::Error>);
#[derive(Debug, Error)]
pub(crate) enum AppError {
#[error("invalid or missing argument: {0}")]
InvalidArgument(String),
#[error("client did not provide proper authorization credentials")]
Unauthorized,
#[error("session was not found")]
SessionNotFound,
#[error("user is not the coordinator")]
NotCoordinator,
}

// These make it easier to clients to tell which error happened.
pub const INVALID_ARGUMENT: usize = 1;
pub const UNAUTHORIZED: usize = 2;
pub const SESSION_NOT_FOUND: usize = 3;
pub const NOT_COORDINATOR: usize = 4;

impl AppError {
pub fn error_code(&self) -> usize {
match &self {
AppError::InvalidArgument(_) => INVALID_ARGUMENT,
AppError::Unauthorized => UNAUTHORIZED,
AppError::SessionNotFound => SESSION_NOT_FOUND,
AppError::NotCoordinator => NOT_COORDINATOR,
}
}
}

impl From<AppError> for types::Error {
fn from(err: AppError) -> Self {
types::Error {
code: err.error_code(),
msg: err.to_string(),
}
}
}

impl IntoResponse for AppError {
fn into_response(self) -> Response {
(self.0, format!("{}", self.1)).into_response()
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(Into::<types::Error>::into(self)),
)
.into_response()
}
}
6 changes: 6 additions & 0 deletions server/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ use frost_rerandomized::Randomizer;
use serde::{Deserialize, Serialize};
pub use uuid::Uuid;

#[derive(Debug, Serialize, Deserialize)]
pub struct Error {
pub code: usize,
pub msg: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterArgs {
pub username: String,
Expand Down
29 changes: 5 additions & 24 deletions server/src/user.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
use std::str::FromStr;

use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, StatusCode},
RequestPartsExt,
};
use axum::{async_trait, extract::FromRequestParts, http::request::Parts, RequestPartsExt};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use eyre::eyre;
use sqlx::FromRow;
use uuid::Uuid;

Expand All @@ -19,7 +13,7 @@ use crate::{state::SharedState, AppError};
/// An User
#[derive(Debug, FromRow)]
#[allow(dead_code)]
pub struct User {
pub(crate) struct User {
pub(crate) pubkey: Vec<u8>,
pub(crate) current_token: Uuid,
}
Expand All @@ -43,19 +37,9 @@ impl FromRequestParts<SharedState> for User {
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| {
AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("Bearer token missing").into(),
)
})?;
.map_err(|_| AppError::Unauthorized)?;
// Decode the user data
let access_token = Uuid::from_str(bearer.token()).map_err(|_| {
AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid access token").into(),
)
})?;
let access_token = Uuid::from_str(bearer.token()).map_err(|_| AppError::Unauthorized)?;

let pubkey = state
.access_tokens
Expand All @@ -70,10 +54,7 @@ impl FromRequestParts<SharedState> for User {
current_token: access_token,
})
} else {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("user not found").into(),
));
return Err(AppError::Unauthorized);
}
}
}
22 changes: 19 additions & 3 deletions server/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use server::{
};

use frost_core as frost;
use uuid::Uuid;
use xeddsa::{xed25519, Sign, Verify};

#[tokio::test]
Expand Down Expand Up @@ -411,7 +412,7 @@ async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
.send()
.await?;
if r.status() != reqwest::StatusCode::OK {
panic!("{}", r.text().await?)
panic!("{:?}", r.json::<server::Error>().await?)
}
let r = r.json::<server::ChallengeOutput>().await?;
let alice_challenge = r.challenge;
Expand All @@ -430,7 +431,7 @@ async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
.send()
.await?;
if r.status() != reqwest::StatusCode::OK {
panic!("{}", r.text().await?)
panic!("{:?}", r.json::<server::Error>().await?)
}
let r = r.json::<server::KeyLoginOutput>().await?;
let access_token = r.access_token;
Expand All @@ -447,12 +448,27 @@ async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
.send()
.await?;
if r.status() != reqwest::StatusCode::OK {
panic!("{}", r.text().await?)
panic!("{:?}", r.json::<server::Error>().await?)
}
let r = r.json::<server::CreateNewSessionOutput>().await?;
let session_id = r.session_id;
println!("Session ID: {}", session_id);

// Error test

let wrong_session_id = Uuid::new_v4();
let r = client
.post("http://127.0.0.1:2744/get_session_info")
.bearer_auth(access_token)
.json(&server::GetSessionInfoArgs {
session_id: wrong_session_id,
})
.send()
.await?;
assert_eq!(r.status(), reqwest::StatusCode::INTERNAL_SERVER_ERROR);
let r = r.json::<server::Error>().await?;
assert_eq!(r.code, server::SESSION_NOT_FOUND);

Ok(())
}

Expand Down