Skip to content

Commit

Permalink
Add SessionStore and SessionGetter
Browse files Browse the repository at this point in the history
  • Loading branch information
sugyan committed Nov 21, 2024
1 parent a109ef5 commit a395d90
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 74 deletions.
11 changes: 6 additions & 5 deletions atrium-common/src/store/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use super::Store;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::Mutex;

#[derive(Error, Debug)]
#[error("memory store error")]
Expand All @@ -28,18 +29,18 @@ where
type Error = Error;

async fn get(&self, key: &K) -> Result<Option<V>, Self::Error> {
Ok(self.store.lock().unwrap().get(key).cloned())
Ok(self.store.lock().await.get(key).cloned())
}
async fn set(&self, key: K, value: V) -> Result<(), Self::Error> {
self.store.lock().unwrap().insert(key, value);
self.store.lock().await.insert(key, value);
Ok(())
}
async fn del(&self, key: &K) -> Result<(), Self::Error> {
self.store.lock().unwrap().remove(key);
self.store.lock().await.remove(key);
Ok(())
}
async fn clear(&self) -> Result<(), Self::Error> {
self.store.lock().unwrap().clear();
self.store.lock().await.clear();
Ok(())
}
}
4 changes: 3 additions & 1 deletion atrium-oauth/oauth-client/examples/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use atrium_api::agent::Agent;
use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL};
use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver};
use atrium_oauth_client::store::session::MemorySessionStore;
use atrium_oauth_client::store::state::MemoryStateStore;
use atrium_oauth_client::{
AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, KnownScope, OAuthClient,
Expand Down Expand Up @@ -58,6 +59,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
protected_resource_metadata: Default::default(),
},
state_store: MemoryStateStore::default(),
session_store: MemorySessionStore::default(),
};
let client = OAuthClient::new(config)?;
println!(
Expand All @@ -77,7 +79,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
);

// Click the URL and sign in,
// then copy and paste the URL like “http://127.0.0.1/?iss=...&code=...” after it is redirected.
// then copy and paste the URL like “http://127.0.0.1/callback?iss=...&code=...” after it is redirected.

print!("Redirected url: ");
stdout().lock().flush()?;
Expand Down
4 changes: 3 additions & 1 deletion atrium-oauth/oauth-client/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ pub enum Error {
Authorize(String),
#[error("callback error: {0}")]
Callback(String),
#[error("state store error: {0:?}")]
#[error("state store error: {0}")]
StateStore(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("session store error: {0}")]
SessionStore(Box<dyn std::error::Error + Send + Sync + 'static>),
}

pub type Result<T> = core::result::Result<T, Error>;
93 changes: 47 additions & 46 deletions atrium-oauth/oauth-client/src/oauth_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,72 +5,77 @@ use crate::{
oauth_session::OAuthSession,
resolver::{OAuthResolver, OAuthResolverConfig},
server_agent::{OAuthRequest, OAuthServerAgent},
store::state::{InternalStateData, StateStore},
store::{
session::{Session, SessionStore},
session_getter::SessionGetter,
state::{InternalStateData, StateStore},
},
types::{
AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions,
CallbackParams, OAuthAuthorizationServerMetadata, OAuthClientMetadata,
OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters, TokenSet,
OAuthPusehedAuthorizationRequestResponse, PushedAuthorizationRequestParameters,
TryIntoOAuthClientMetadata,
},
utils::{compare_algos, generate_key, generate_nonce, get_random_values},
utils::{compare_algos, generate_key, generate_nonce},
};
use atrium_common::resolver::Resolver;
use atrium_api::types::string::Did;
use atrium_common::{resolver::Resolver, store::Store};
use atrium_identity::{did::DidResolver, handle::HandleResolver};
use atrium_xrpc::HttpClient;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use jose_jwk::{Jwk, JwkSet, Key};
use rand::rngs::ThreadRng;
use serde::Serialize;
use sha2::{Digest, Sha256};
use std::sync::Arc;

#[cfg(feature = "default-client")]
pub struct OAuthClientConfig<S, M, D, H>
pub struct OAuthClientConfig<S0, S1, M, D, H>
where
M: TryIntoOAuthClientMetadata,
{
// Config
pub client_metadata: M,
pub keys: Option<Vec<Jwk>>,
// Stores
pub state_store: S,
pub state_store: S0,
pub session_store: S1,
// Services
pub resolver: OAuthResolverConfig<D, H>,
}

#[cfg(not(feature = "default-client"))]
pub struct OAuthClientConfig<S, T, M, D, H>
pub struct OAuthClientConfig<S0, S1, T, M, D, H>
where
M: TryIntoOAuthClientMetadata,
{
// Config
pub client_metadata: M,
pub keys: Option<Vec<Jwk>>,
// Stores
pub state_store: S,
pub state_store: S0,
pub session_store: S1,
// Services
pub resolver: OAuthResolverConfig<D, H>,
// Others
pub http_client: T,
}

#[cfg(feature = "default-client")]
pub struct OAuthClient<S, D, H, T = crate::http_client::default::DefaultHttpClient>
pub struct OAuthClient<S0, S1, D, H, T = crate::http_client::default::DefaultHttpClient>
where
S: StateStore,
T: HttpClient + Send + Sync + 'static,
{
pub client_metadata: OAuthClientMetadata,
keyset: Option<Keyset>,
resolver: Arc<OAuthResolver<T, D, H>>,
state_store: S,
state_store: S0,
session_getter: SessionGetter<S1>,
http_client: Arc<T>,
}

#[cfg(not(feature = "default-client"))]
pub struct OAuthClient<S, D, H, T>
pub struct OAuthClient<S0, S1, D, H, T>
where
S: StateStore,
T: HttpClient + Send + Sync + 'static,
{
pub client_metadata: OAuthClientMetadata,
Expand All @@ -81,11 +86,8 @@ where
}

#[cfg(feature = "default-client")]
impl<S, D, H> OAuthClient<S, D, H, crate::http_client::default::DefaultHttpClient>
where
S: StateStore,
{
pub fn new<M>(config: OAuthClientConfig<S, M, D, H>) -> Result<Self>
impl<S0, S1, D, H> OAuthClient<S0, S1, D, H, crate::http_client::default::DefaultHttpClient> {
pub fn new<M>(config: OAuthClientConfig<S0, S1, M, D, H>) -> Result<Self>
where
M: TryIntoOAuthClientMetadata<Error = crate::atproto::Error>,
{
Expand All @@ -97,6 +99,7 @@ where
keyset,
resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())),
state_store: config.state_store,
session_getter: SessionGetter::new(config.session_store),
http_client,
})
}
Expand Down Expand Up @@ -125,13 +128,15 @@ where
}
}

impl<S, D, H, T> OAuthClient<S, D, H, T>
impl<S0, S1, D, H, T> OAuthClient<S0, S1, D, H, T>
where
S: StateStore,
S0: StateStore + Send + Sync + 'static,
S1: SessionStore + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
T: HttpClient + Send + Sync + 'static,
S::Error: std::error::Error + Send + Sync + 'static,
S0::Error: std::error::Error + Send + Sync + 'static,
S1::Error: std::error::Error + Send + Sync + 'static,
{
pub fn jwks(&self) -> JwkSet {
self.keyset.as_ref().map(|keyset| keyset.public_jwks()).unwrap_or_default()
Expand Down Expand Up @@ -234,31 +239,28 @@ where
return Err(Error::Callback("missing `iss` parameter".into()));
}
let server = self.create_server_agent(state.dpop_key.clone(), metadata.clone())?;
let token_set = server.exchange_code(&params.code, &state.verifier).await?;
// TODO: store token_set to session store

let session =
self.create_session_from_metadata(state.dpop_key.clone(), metadata, token_set)?;
Ok((session, state.app_state))
}
pub async fn create_session(
&self,
dpop_key: Key,
issuer: String,
token_set: TokenSet,
) -> Result<OAuthSession<T, D, H>> {
let server_metadata = self.resolver.get_authorization_server_metadata(issuer).await?;
self.create_session_from_metadata(dpop_key, server_metadata, token_set)
match server.exchange_code(&params.code, &state.verifier).await {
Ok(token_set) => {
let sub = token_set.sub.clone();
self.session_getter
.set(sub.clone(), Session { dpop_key: state.dpop_key.clone(), token_set })
.await
.map_err(|e| Error::SessionStore(Box::new(e)))?;
Ok((self.create_session(server, sub).await?, state.app_state))
}
Err(_) => {
todo!()
}
}
}
fn create_session_from_metadata(
async fn create_session(
&self,
dpop_key: Key,
server_metadata: OAuthAuthorizationServerMetadata,
token_set: TokenSet,
server: OAuthServerAgent<T, D, H>,
sub: Did,
) -> Result<OAuthSession<T, D, H>> {
Ok(self
.create_server_agent(dpop_key, server_metadata)?
.create_session(self.http_client.clone(), token_set)?)
Ok(server
.create_session(sub, self.http_client.clone(), self.session_getter.clone())
.await?)
}
fn create_server_agent(
&self,
Expand All @@ -282,8 +284,7 @@ where
}
fn generate_pkce() -> (String, String) {
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
let verifier =
URL_SAFE_NO_PAD.encode(get_random_values::<_, 32>(&mut ThreadRng::default()));
let verifier = [generate_nonce(), generate_nonce()].join("");
(URL_SAFE_NO_PAD.encode(Sha256::digest(&verifier)), verifier)
}
}
53 changes: 34 additions & 19 deletions atrium-oauth/oauth-client/src/server_agent.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
use crate::constants::FALLBACK_ALG;
use crate::http_client::dpop::DpopClient;
use crate::jose::jwt::{RegisteredClaims, RegisteredClaimsAud};
use crate::keyset::Keyset;
use crate::oauth_session::OAuthSession;
use crate::resolver::OAuthResolver;
use crate::types::{
OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthTokenResponse,
PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType,
TokenRequestParameters, TokenSet,
use crate::{
constants::FALLBACK_ALG,
http_client::dpop::DpopClient,
jose::jwt::{RegisteredClaims, RegisteredClaimsAud},
keyset::Keyset,
oauth_session::OAuthSession,
resolver::OAuthResolver,
store::{session::SessionStore, session_getter::SessionGetter},
types::{
OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthTokenResponse,
PushedAuthorizationRequestParameters, RefreshRequestParameters, TokenGrantType,
TokenRequestParameters, TokenSet,
},
utils::{compare_algos, generate_nonce},
};
use crate::utils::{compare_algos, generate_nonce};
use atrium_api::types::string::Datetime;
use atrium_api::types::string::{Datetime, Did};
use atrium_common::store::Store;
use atrium_identity::{did::DidResolver, handle::HandleResolver};
use atrium_xrpc::http::{Method, Request, StatusCode};
use atrium_xrpc::HttpClient;
use atrium_xrpc::{
http::{Method, Request, StatusCode},
HttpClient,
};
use chrono::{TimeDelta, Utc};
use jose_jwk::Key;
use serde::Serialize;
Expand All @@ -33,6 +39,8 @@ pub enum Error {
Token(String),
#[error("unsupported authentication method")]
UnsupportedAuthMethod,
#[error("failed to parse DID: {0}")]
InvalidDid(&'static str),
#[error(transparent)]
DpopClient(#[from] crate::http_client::dpop::Error),
#[error(transparent)]
Expand Down Expand Up @@ -154,7 +162,7 @@ where
.map(Datetime::new)
});
Ok(TokenSet {
sub: sub.clone(),
sub: sub.parse().map_err(Error::InvalidDid)?,
aud: identity.pds,
iss: metadata.issuer,
scope: token_response.scope,
Expand Down Expand Up @@ -296,12 +304,19 @@ where
}
}
}
pub(crate) fn create_session(
pub(crate) async fn create_session<S>(
self,
sub: Did,
http_client: Arc<T>,
token_set: TokenSet,
) -> Result<OAuthSession<T, D, H>> {
session_getter: SessionGetter<S>,
) -> Result<OAuthSession<T, D, H>>
where
S: SessionStore + Send + Sync + 'static,
S::Error: std::error::Error + Send + Sync + 'static,
{
let dpop_key = self.dpop_client.key.clone();
Ok(OAuthSession::new(self, dpop_key, http_client, token_set)?)
// TODO
let session = session_getter.get(&sub).await.expect("").unwrap();
Ok(OAuthSession::new(self, dpop_key, http_client, session.token_set)?)
}
}
2 changes: 2 additions & 0 deletions atrium-oauth/oauth-client/src/store.rs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
pub mod session;
pub mod session_getter;
pub mod state;
17 changes: 17 additions & 0 deletions atrium-oauth/oauth-client/src/store/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use crate::types::TokenSet;
use atrium_api::types::string::Did;
use atrium_common::store::{memory::MemoryStore, Store};
use jose_jwk::Key;
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
pub dpop_key: Key,
pub token_set: TokenSet,
}

pub trait SessionStore: Store<Did, Session> {}

pub type MemorySessionStore = MemoryStore<Did, Session>;

impl SessionStore for MemorySessionStore {}
Loading

0 comments on commit a395d90

Please sign in to comment.