Skip to content

Commit

Permalink
compiles
Browse files Browse the repository at this point in the history
  • Loading branch information
avdb13 committed Oct 29, 2024
1 parent 9b4aef1 commit 5653d79
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 42 deletions.
2 changes: 2 additions & 0 deletions atrium-oauth/oauth-client/examples/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL};
use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver};
use atrium_oauth_client::store::cached::CachedMemoryStore;
use atrium_oauth_client::store::state::MemoryStateStore;
use atrium_oauth_client::{
AtprotoLocalhostClientMetadata, AuthorizeOptions, DefaultHttpClient, OAuthClient,
Expand Down Expand Up @@ -53,6 +54,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
protected_resource_metadata: Default::default(),
},
state_store: MemoryStateStore::default(),
session_store: CachedMemoryStore::default(),
};
let client = OAuthClient::new(config)?;
println!(
Expand Down
2 changes: 2 additions & 0 deletions atrium-oauth/oauth-client/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pub enum Error {
Callback(String),
#[error("state store error: {0:?}")]
StateStore(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error(transparent)]
Session(#[from] crate::oauth_session::Error),
}

pub type Result<T> = core::result::Result<T, Error>;
1 change: 1 addition & 0 deletions atrium-oauth/oauth-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod server_agent;
pub mod store;
mod types;

Check warning on line 11 in atrium-oauth/oauth-client/src/lib.rs

View workflow job for this annotation

GitHub Actions / Rust (1.75.0)

Diff in /home/runner/work/atrium/atrium/atrium-oauth/oauth-client/src/lib.rs

Check warning on line 11 in atrium-oauth/oauth-client/src/lib.rs

View workflow job for this annotation

GitHub Actions / Rust (stable)

Diff in /home/runner/work/atrium/atrium/atrium-oauth/oauth-client/src/lib.rs
mod utils;
pub mod oauth_session;

pub use atproto::{
AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, GrantType, Scope,
Expand Down
67 changes: 43 additions & 24 deletions atrium-oauth/oauth-client/src/oauth_client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::constants::FALLBACK_ALG;
use crate::error::{Error, Result};
use crate::keyset::Keyset;
use crate::oauth_session::OAuthSession;
use crate::resolver::{OAuthResolver, OAuthResolverConfig};
use crate::server_agent::{OAuthRequest, OAuthServerAgent};
use crate::store::session::SessionStore;
use crate::server_agent::{OAuthRequest, OAuthServerAgent, OAuthServerFactory};
use crate::store::cached::Cached;
use crate::store::session::{Session, SessionStore};
use crate::store::state::{InternalStateData, StateStore};
use crate::types::{
AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptions, CallbackParams,
Expand All @@ -12,6 +14,7 @@ use crate::types::{
TryIntoOAuthClientMetadata,
};
use crate::utils::{compare_algos, generate_key, generate_nonce, get_random_values};
use atrium_api::types::string::Did;
use atrium_identity::{did::DidResolver, handle::HandleResolver, Resolver};
use atrium_xrpc::HttpClient;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
Expand Down Expand Up @@ -64,9 +67,10 @@ where
pub client_metadata: OAuthClientMetadata,
keyset: Option<Keyset>,
resolver: Arc<OAuthResolver<T, D, H>>,
server_factory: OAuthServerFactory<D, H, T>,
state_store: S,
session_store: N,
http_client: Arc<T>,
_http_client: Arc<T>,
}

#[cfg(not(feature = "default-client"))]
Expand All @@ -79,6 +83,7 @@ where
pub client_metadata: OAuthClientMetadata,
keyset: Option<Keyset>,
resolver: Arc<OAuthResolver<T, D, H>>,
server_factory: OAuthServerFactory<D, H, T>,
state_store: S,
session_store: N,
http_client: Arc<T>,
Expand All @@ -97,13 +102,15 @@ where
let keyset = if let Some(keys) = config.keys { Some(keys.try_into()?) } else { None };
let client_metadata = config.client_metadata.try_into_client_metadata(&keyset)?;
let http_client = Arc::new(crate::http_client::default::DefaultHttpClient::default());
let resolver = Arc::new(OAuthResolver::new(config.resolver, http_client.clone()));
Ok(Self {
client_metadata,
keyset,
resolver: Arc::new(OAuthResolver::new(config.resolver, http_client.clone())),
client_metadata: client_metadata.clone(),
keyset: keyset.clone(),
resolver: resolver.clone(),
server_factory: OAuthServerFactory::new(resolver, client_metadata, keyset),
state_store: config.state_store,
session_store: config.session_store,
http_client,
_http_client: http_client,
})
}
}
Expand Down Expand Up @@ -185,14 +192,8 @@ where
prompt: options.prompt.map(String::from),
};
if metadata.pushed_authorization_request_endpoint.is_some() {
let server = OAuthServerAgent::new(
dpop_key,
metadata.clone(),
self.client_metadata.clone(),
self.resolver.clone(),
self.http_client.clone(),
self.keyset.clone(),
)?;
let server =
self.server_factory.from_issuer(&metadata.issuer, dpop_key.clone()).await?;
let par_response = server
.request::<OAuthPusehedAuthorizationRequestResponse>(
OAuthRequest::PushedAuthorizationRequest(parameters),
Expand Down Expand Up @@ -244,17 +245,28 @@ where
} else if metadata.authorization_response_iss_parameter_supported == Some(true) {
return Err(Error::Callback("missing `iss` parameter".into()));
}
let server = OAuthServerAgent::new(
state.dpop_key.clone(),
metadata.clone(),
self.client_metadata.clone(),
self.resolver.clone(),
self.http_client.clone(),
self.keyset.clone(),
)?;
let server =
self.server_factory.from_issuer(&metadata.issuer, state.dpop_key.clone()).await?;
let token_set = server.exchange_code(&params.code, &state.verifier).await?;

// TODO: create session?
let sub: Did = token_set.sub.parse().unwrap();

if let Err(_error) = self
.session_store
.set(sub.clone(), Cached::new(Session::new(state.dpop_key.clone(), token_set.clone())))
.await
{
let _ = server
.revoke(
token_set.refresh_token.as_deref().unwrap_or_else(|| &token_set.access_token),
)
.await;

todo!();
// return Err(error);
}
let _session = self.create_session(server, sub);

Ok(token_set)
}
fn generate_dpop_key(metadata: &OAuthAuthorizationServerMetadata) -> Option<Key> {
Expand All @@ -271,4 +283,11 @@ where
hasher.update(verifier.as_bytes());
(URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes())), verifier)
}
fn create_session(
&self,
server: OAuthServerAgent<T, D, H>,
sub: Did,
) -> OAuthSession<N, T, D, H> {
OAuthSession::new(server, sub, self.session_store.clone())
}
}
75 changes: 75 additions & 0 deletions atrium-oauth/oauth-client/src/oauth_session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use std::sync::Arc;

use atrium_api::types::string::Did;
use atrium_identity::{did::DidResolver, handle::HandleResolver};
use atrium_xrpc::HttpClient;
use thiserror::Error;

use crate::{
server_agent::OAuthServerAgent,
store::session::{Session, SessionStore},
types::TokenInfo,
Result, TokenSet,
};

#[derive(Clone, Debug, Error)]
pub enum Error {}

pub struct OAuthSession<S, T, D, H>
where
S: SessionStore,
T: HttpClient + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
{
session_store: S,
pub server: Arc<OAuthServerAgent<T, D, H>>,
pub sub: Did,
}
impl<S, T, D, H> OAuthSession<S, T, D, H>
where
S: SessionStore,
T: HttpClient + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
{
pub fn new(server: OAuthServerAgent<T, D, H>, sub: Did, session_store: S) -> Self {
Self { server: Arc::new(server), sub, session_store }
}

pub async fn get_token_set(&self, _refresh: Option<bool>) -> Result<TokenSet> {
let Some(value) = self.session_store.get(&self.sub).await.unwrap() else { todo!() };

let server = self.server.clone();

let get_cached = value.get_cached(|session| {
Box::pin(async move {
let Some(session) = session else { todo!() };

Ok(Session {
dpop_key: session.dpop_key,
token_set: server.refresh(session.token_set.clone()).await.unwrap(),
})
})
});
let session = get_cached.await.unwrap();

Ok(session.token_set)
}

pub async fn get_token_info(&self, refresh: Option<bool>) -> Result<TokenInfo> {
let TokenSet { iss, sub, aud, scope, expires_at, .. } = self.get_token_set(refresh).await?;
let expires_at = expires_at.as_ref().map(AsRef::as_ref).cloned();

Ok(TokenInfo::new(iss, sub.parse().expect("valid Did"), aud, scope, expires_at))
}

pub async fn logout(&self, _refresh: Option<bool>) -> Result<()> {
let token_set = self.get_token_set(Some(false)).await?;

self.server.revoke(&token_set.access_token).await?;

let _ = self.session_store.del(&self.sub).await;
Ok(())
}
}
45 changes: 40 additions & 5 deletions atrium-oauth/oauth-client/src/server_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use crate::keyset::Keyset;
use crate::resolver::OAuthResolver;
use crate::types::{
AuthorizationCodeParameters, OAuthAuthorizationServerMetadata, OAuthClientMetadata,
OAuthTokenResponse, PushedAuthorizationRequestParameters, TokenRequestParameters, TokenSet,
OAuthTokenResponse, PushedAuthorizationRequestParameters, RefreshTokenParameters,
RevocationRequestParameters, TokenRequestParameters, TokenSet,
};
use crate::utils::{compare_algos, generate_nonce};
use atrium_api::types::string::Datetime;
Expand Down Expand Up @@ -56,7 +57,7 @@ pub type Result<T> = core::result::Result<T, Error>;
#[allow(dead_code)]
pub enum OAuthRequest {
Token(TokenRequestParameters),
Revocation,
Revocation(RevocationRequestParameters),
Introspection,
PushedAuthorizationRequest(PushedAuthorizationRequestParameters),
}
Expand All @@ -65,14 +66,14 @@ impl OAuthRequest {
fn name(&self) -> String {
String::from(match self {
Self::Token(_) => "token",
Self::Revocation => "revocation",
Self::Revocation(_) => "revocation",
Self::Introspection => "introspection",
Self::PushedAuthorizationRequest(_) => "pushed_authorization_request",
})
}
fn expected_status(&self) -> StatusCode {
match self {
Self::Token(_) => StatusCode::OK,
Self::Token(_) | Self::Revocation(_) => StatusCode::OK,
Self::PushedAuthorizationRequest(_) => StatusCode::CREATED,
_ => unimplemented!(),
}
Expand All @@ -96,6 +97,8 @@ where
pub struct OAuthServerAgent<T, D, H>
where
T: HttpClient + Send + Sync + 'static,
D: DidResolver + Send + Sync + 'static,
H: HandleResolver + Send + Sync + 'static,
{
server_metadata: OAuthAuthorizationServerMetadata,
client_metadata: OAuthClientMetadata,
Expand Down Expand Up @@ -173,6 +176,37 @@ where
)
.await
}
pub async fn revoke(&self, token: &str) -> Result<()> {
self.request(OAuthRequest::Revocation(RevocationRequestParameters { token: token.into() }))
.await
}
pub async fn refresh(&self, token_set: TokenSet) -> Result<TokenSet> {
let TokenSet { sub, scope, refresh_token, access_token, token_type, expires_at, .. } =
token_set;
let expires_in = expires_at.map(|expires_at| {
expires_at.as_ref().signed_duration_since(Datetime::now().as_ref()).num_seconds()
});
let token_response = OAuthTokenResponse {
access_token,
token_type,
expires_in,
refresh_token,
scope,
sub: Some(sub),
};
let TokenSet { scope, refresh_token: Some(refresh_token), .. } =
self.verify_token_response(token_response).await?
else {
todo!();
};
self.verify_token_response(
self.request(OAuthRequest::Token(TokenRequestParameters::RefreshToken(
RefreshTokenParameters { refresh_token, scope },
)))
.await?,
)
.await
}
pub async fn request<O>(&self, request: OAuthRequest) -> Result<O>
where
O: serde::de::DeserializeOwned,
Expand All @@ -183,6 +217,7 @@ where
let body = match &request {
OAuthRequest::Token(params) => self.build_body(params)?,
OAuthRequest::PushedAuthorizationRequest(params) => self.build_body(params)?,
OAuthRequest::Revocation(params) => self.build_body(params)?,
_ => unimplemented!(),
};
let req = Request::builder()
Expand Down Expand Up @@ -268,7 +303,7 @@ where
fn endpoint(&self, request: &OAuthRequest) -> Option<&String> {
match request {
OAuthRequest::Token(_) => Some(&self.server_metadata.token_endpoint),
OAuthRequest::Revocation => self.server_metadata.revocation_endpoint.as_ref(),
OAuthRequest::Revocation(_) => self.server_metadata.revocation_endpoint.as_ref(),
OAuthRequest::Introspection => self.server_metadata.introspection_endpoint.as_ref(),
OAuthRequest::PushedAuthorizationRequest(_) => {
self.server_metadata.pushed_authorization_request_endpoint.as_ref()
Expand Down
14 changes: 8 additions & 6 deletions atrium-oauth/oauth-client/src/store/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,27 @@ where

pub async fn get_cached<G>(&self, getter: G) -> Result<T, E>
where
G: FnOnce() -> Getter<'static, Result<T, E>> + Send + 'static,
G: FnOnce(Option<T>) -> Getter<'static, Result<T, E>> + Send + 'static,
{
let mut rx = {
let mut _self = self.0.lock().unwrap();

if let Some(value) = _self.inner.as_ref() {
if value.expires_at().map_or(true, |expires_at| Utc::now() <= expires_at.to_utc()) {
let value = match _self.inner.as_ref() {
Some(value)
if value.expires_at().map_or(true, |exp| Utc::now() <= exp.to_utc()) =>
{
return Ok(value.clone());
}
}

value => value.cloned(),
};
if let Some(pending) = _self.pending.as_ref() {
pending.subscribe()
} else {
let (tx, rx) = broadcast::channel::<Result<T, _>>(1);
_self.pending = Some(tx.clone());
let cloned = self.0.clone();

let fut = getter();
let fut = getter(value);

tokio::spawn(async move {
let response = fut.await;
Expand Down
1 change: 1 addition & 0 deletions atrium-oauth/oauth-client/src/store/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use thiserror::Error;
pub struct Error;

// TODO: LRU cache?
#[derive(Clone)]
pub struct MemorySimpleStore<K, V> {
store: Arc<Mutex<HashMap<K, V>>>,
}
Expand Down
Loading

0 comments on commit 5653d79

Please sign in to comment.