Skip to content

Commit

Permalink
use SessionStore in OAuthSession
Browse files Browse the repository at this point in the history
  • Loading branch information
avdb13 committed Nov 24, 2024
1 parent 82a9398 commit dcfd1a1
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 47 deletions.
4 changes: 2 additions & 2 deletions atrium-common/src/store/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ impl<K, V> Default for MemoryStore<K, V> {

impl<K, V> Store<K, V> for MemoryStore<K, V>
where
K: Debug + Eq + Hash + Send + Sync + 'static,
V: Debug + Clone + Send + Sync + 'static,
K: Eq + Hash + Send + Sync,
V: Clone + Send,
{
type Error = Error;

Expand Down
2 changes: 2 additions & 0 deletions atrium-oauth/oauth-client/src/http_client/dpop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ pub enum Error {
UnsupportedKey,
#[error("nonce store error: {0}")]
Nonces(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("session store error: {0}")]
SessionStore(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error(transparent)]
SerdeJson(#[from] serde_json::Error),
}
Expand Down
48 changes: 34 additions & 14 deletions atrium-oauth/oauth-client/src/oauth_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,35 @@ use atrium_xrpc::{
};
use jose_jwk::Key;

use crate::{http_client::dpop::Error, server_agent::OAuthServerAgent, DpopClient, TokenSet};
use crate::{
http_client::dpop::Error,
server_agent::OAuthServerAgent,
store::session::{MemorySessionStore, SessionStore},
DpopClient, TokenSet,
};

pub struct OAuthSession<T, D, H, S = MemoryStore<String, String>>
where
pub struct OAuthSession<
T,
D,
H,
S0 = MemoryStore<String, String>,
S1 = MemorySessionStore<(), TokenSet>,
> where
T: HttpClient + Send + Sync + 'static,
S: Store<String, String>,
S0: Store<String, String>,
S1: SessionStore<(), TokenSet>,
{
#[allow(dead_code)]
server_agent: OAuthServerAgent<T, D, H>,
dpop_client: DpopClient<T, S>,
token_set: TokenSet,
dpop_client: DpopClient<T, S0>,
session_store: S1,
}

impl<T, D, H> OAuthSession<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
{
pub(crate) fn new(
pub(crate) async fn new(
server_agent: OAuthServerAgent<T, D, H>,
dpop_key: Key,
http_client: Arc<T>,
Expand All @@ -38,13 +49,19 @@ where
false,
&server_agent.server_metadata.token_endpoint_auth_signing_alg_values_supported,
)?;
Ok(Self { server_agent, dpop_client, token_set })

let session_store = MemorySessionStore::default();
session_store.set((), token_set).await.map_err(|e| Error::SessionStore(Box::new(e)))?;

Ok(Self { server_agent, dpop_client, session_store })
}
pub fn dpop_key(&self) -> Key {
self.dpop_client.key.clone()
}
pub fn token_set(&self) -> TokenSet {
self.token_set.clone()
pub async fn token_set(&self) -> Result<TokenSet, Error> {
let token_set =
self.session_store.get(&()).await.map_err(|e| Error::SessionStore(Box::new(e)))?;
Ok(token_set.expect("session store can never be empty"))
}
// pub async fn get_session(&self, refresh: bool) -> crate::Result<Session> {
// let Some(session) = self
Expand Down Expand Up @@ -97,13 +114,15 @@ where
S::Error: std::error::Error + Send + Sync + 'static,
{
fn base_uri(&self) -> String {
self.token_set.aud.clone()
// self.token_set.aud.clone()
todo!()
}
async fn authorization_token(&self, is_refresh: bool) -> Option<AuthorizationToken> {
let token_set = self.session_store.get(&()).await.transpose().and_then(Result::ok)?;
if is_refresh {
self.token_set.refresh_token.as_ref().cloned().map(AuthorizationToken::Dpop)
token_set.refresh_token.as_ref().cloned().map(AuthorizationToken::Dpop)
} else {
Some(AuthorizationToken::Dpop(self.token_set.access_token.clone()))
Some(AuthorizationToken::Dpop(token_set.access_token.clone()))
}
}
}
Expand All @@ -117,6 +136,7 @@ where
S::Error: std::error::Error + Send + Sync + 'static,
{
async fn did(&self) -> Option<Did> {
Some(self.token_set.sub.clone())
let token_set = self.session_store.get(&()).await.transpose().and_then(Result::ok)?;
Some(token_set.sub.clone())
}
}
36 changes: 8 additions & 28 deletions atrium-oauth/oauth-client/src/server_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,39 +193,19 @@ where
.await;
}
#[allow(dead_code)]
pub async fn refresh(&self, token_set: &TokenSet) {
pub async fn refresh(&self, token_set: &TokenSet) -> Result<TokenSet> {
let Some(refresh_token) = token_set.refresh_token.as_ref() else {
// TODO
return;
return Err(Error::NoRefreshToken(token_set.sub.to_string()));
};
// TODO
let result = self
.request::<OAuthTokenResponse>(OAuthRequest::Refresh(RefreshRequestParameters {
self.verify_token_response(
self.request::<OAuthTokenResponse>(OAuthRequest::Refresh(RefreshRequestParameters {
grant_type: TokenGrantType::RefreshToken,
refresh_token: refresh_token.clone(),
scope: None,
}))
.await;
println!("{result:?}");

// let Some(refresh_token) = token_set.refresh_token else {
// return Err(Error::NoRefreshToken(token_set.sub.clone()));
// };
// let (metadata, atrium_identity::identity_resolver::ResolvedIdentity { pds: aud, .. }) =
// self.resolver.resolve_from_identity(&token_set.sub).await?;
// if metadata.issuer != self.server_metadata.issuer {
// let _ = self.revoke(&token_set.access_token).await;
// return Err(Error::Token("issuer mismatch".into()));
// }
// let token_set = self
// .verify_token_response(
// self.request(OAuthRequest::Token(TokenRequestParameters::RefreshToken(
// RefreshTokenParameters { refresh_token, scope: token_set.scope.clone() },
// )))
// .await?,
// )
// .await?;
// Ok(TokenSet { aud, ..token_set })
.await?,
)
.await
}
pub async fn request<O>(&self, request: OAuthRequest) -> Result<O>
where
Expand Down Expand Up @@ -345,6 +325,6 @@ where
let dpop_key = self.dpop_client.key.clone();
// TODO
let session = session_getter.get(&sub).await.expect("").unwrap();
Ok(OAuthSession::new(self, dpop_key, http_client, session.token_set)?)
OAuthSession::new(self, dpop_key, http_client, session.token_set).await.map_err(Into::into)
}
}
18 changes: 15 additions & 3 deletions atrium-oauth/oauth-client/src/store/session.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::hash::Hash;

use crate::types::TokenSet;
use atrium_api::types::string::{Datetime, Did};
use atrium_common::store::{memory::MemoryStore, Store};
Expand All @@ -19,8 +21,18 @@ impl Session {
}
}

pub trait SessionStore: Store<Did, Session> {}
pub trait SessionStore<K = Did, V = Session>: Store<K, V>
where
K: Eq + Hash,
V: Clone,
{
}

pub type MemorySessionStore = MemoryStore<Did, Session>;
pub type MemorySessionStore<K = Did, V = Session> = MemoryStore<K, V>;

impl SessionStore for MemorySessionStore {}
impl<K, V> SessionStore<K, V> for MemorySessionStore<K, V>
where
K: Eq + Hash + Send + Sync,
V: Clone + Send,
{
}

0 comments on commit dcfd1a1

Please sign in to comment.